# kb_router.py - FastAPI routes for knowledge-graph queries (summary, schemas, nodes, graph data).
  1. import sys,os
  2. current_path = os.getcwd()
  3. sys.path.append(current_path)
  4. from config.site import SiteConfig
  5. from fastapi import APIRouter, Depends, Query
  6. from db.database import get_db
  7. from sqlalchemy.orm import Session
  8. from sqlalchemy import text
  9. from agent.models.web.response import StandardResponse,FAILED,SUCCESS
  10. from agent.models.web.request import BasicRequest
  11. from agent.libs.graph import GraphBusiness
  12. from agent.libs.auth import verify_session_id, SessionValues
  13. import logging
  14. from typing import Optional, List
  15. router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
  16. logger = logging.getLogger(__name__)
  17. config = SiteConfig()
  18. LOG_DIR = config.get_config("TASK_LOG_DIR", current_path)
  19. # job_category = Column(String(64), nullable=False)
  20. # job_name = Column(String(64))
  21. # job_details = Column(Text, nullable=False)
  22. # job_creator = Column(String(64), nullable=False)
  23. # job_logs = Column(Text, nullable=True)
  24. # job_files = Column(String(300), nullable=True)
  25. @router.post('/summary', response_model=StandardResponse)
  26. def summary_func(request:BasicRequest, db: Session = Depends(get_db), sess:SessionValues = Depends(verify_session_id))->StandardResponse:
  27. if request.action != "get_summary":
  28. return StandardResponse(code=FAILED, message="invalid action")
  29. graph_id = request.get_param("graph_id",0)
  30. biz = GraphBusiness(db)
  31. summary = biz.get_graph_summary(graph_id=graph_id)
  32. if summary:
  33. logger.info(summary)
  34. return StandardResponse(code=SUCCESS, message="summary found", records=[summary])
  35. else:
  36. return StandardResponse(code=FAILED, message="summary not found",records=[])
  37. @router.post('/schemas', response_model=StandardResponse)
  38. def schemas_func(request:BasicRequest, db: Session = Depends(get_db), sess:SessionValues = Depends(verify_session_id))->StandardResponse:
  39. if request.action== "get_nodes_schemas":
  40. graph_id = request.get_param("graph_id",0)
  41. biz = GraphBusiness(db)
  42. schemas = biz.get_nodes_categories(graph_id=graph_id)
  43. if schemas:
  44. return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)
  45. if request.action== "get_edges_schemas":
  46. graph_id = request.get_param("graph_id",0)
  47. biz = GraphBusiness(db)
  48. schemas = biz.get_edges_categories(graph_id=graph_id)
  49. if schemas:
  50. return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)
  51. return StandardResponse(code=FAILED, message="invalid action")
  52. @router.post('/nodes', response_model=StandardResponse)
  53. def nodes_func(request:BasicRequest, db: Session = Depends(get_db), sess:SessionValues = Depends(verify_session_id))->StandardResponse:
  54. if (request.action == "search_nodes"):
  55. node_name = request.get_param("name","")
  56. category = request.get_param("category","")
  57. graph_id = request.get_param("graph_id",0)
  58. biz = GraphBusiness(db)
  59. if node_name == "":
  60. return StandardResponse(code=FAILED, message="node name is empty", records=[])
  61. if category == "":
  62. return StandardResponse(code=FAILED, message="category is empty", records=[])
  63. if graph_id == 0:
  64. return StandardResponse(code=FAILED, message="graph id is empty", records=[])
  65. nodes = biz.search_like_node_by_name(graph_id=graph_id, category=category, name=node_name)
  66. if nodes:
  67. return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
  68. else:
  69. return StandardResponse(code=FAILED, message="search job failed")
  70. elif (request.action == "get_nodes"):
  71. graph_id = request.get_param("graph_id",0)
  72. page = request.get_param("page",1)
  73. page_size = request.get_param("page_size",1)
  74. biz = GraphBusiness(db)
  75. nodes = biz.get_nodes_by_page(graph_id=graph_id, page=page, page_size=page_size)
  76. if nodes:
  77. return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
  78. elif (request.action == "neighbors"):
  79. graph_id = request.get_param("graph_id",0)
  80. node_id = request.get_param("node_id",0)
  81. if node_id>0:
  82. biz = GraphBusiness(db)
  83. node = biz.get_node_by_id(graph_id=graph_id, node_id=node_id)
  84. if node is None:
  85. return StandardResponse(code=FAILED, message="node not found", records=[])
  86. nodes_in = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="in")
  87. nodes_out = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="out")
  88. nodes_all = []
  89. nodes_all.append({"id": node.id, "name":node.name, "category":node.category, "direction":"self"})
  90. for node in nodes_in:
  91. nodes_all.append({"id": node.id, "name":node.name, "category":node.category, "direction":"in"})
  92. for node in nodes_out:
  93. nodes_all.append({"id": node.id, "name":node.name, "category":node.category, "direction":"out"})
  94. return StandardResponse(code=SUCCESS, message="nodes found", records=nodes_all)
  95. return StandardResponse(code=FAILED, message="invalid action")
  96. async def get_node_properties(node_id: int, db: Session) -> dict:
  97. """
  98. 查询节点的属性
  99. :param node_id: 节点ID
  100. :param db: 数据库会话
  101. :return: 属性字典
  102. """
  103. prop_sql = text("SELECT prop_title, prop_value FROM kg_props WHERE ref_id = :node_id")
  104. result = db.execute(prop_sql, {'node_id': node_id}).fetchall()
  105. properties = {}
  106. for row in result:
  107. properties[row._mapping['prop_title']] = row._mapping['prop_value']
  108. return properties
  109. @router.get("/graph_data", response_model=StandardResponse)
  110. async def get_graph_data(
  111. label_name: Optional[str] = Query(None),
  112. input_str: Optional[str] = Query(None),
  113. db: Session = Depends(get_db),
  114. sess:SessionValues = Depends(verify_session_id)
  115. ):
  116. """
  117. 获取用户关联的图谱数据
  118. - 从session_id获取user_id
  119. - 查询DbUserDataRelation获取用户关联的数据
  120. - 返回与Java端一致的数据结构
  121. """
  122. try:
  123. # 1. 从session获取user_id
  124. user_id = sess.user_id
  125. if not user_id:
  126. return StandardResponse(code=FAILED, message="user not found", records=[])
  127. # 2. 使用JOIN查询用户关联的图谱数据
  128. sql = text("""
  129. WITH RankedRelations AS (
  130. SELECT
  131. e.name as rType,
  132. m.id as target_id,
  133. m.name as target_name,
  134. m.category as target_label,
  135. (SELECT COUNT(*) FROM kg_edges WHERE src_id = m.id) as pCount,
  136. ROW_NUMBER() OVER(PARTITION BY e.name ORDER BY m.id) as rn
  137. FROM user_data_relations udr
  138. JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
  139. JOIN kg_edges e ON n.id = e.src_id
  140. JOIN kg_nodes m ON e.dest_id = m.id
  141. WHERE udr.user_id = :user_id
  142. AND n.category = :label_name
  143. AND n.name = :input_str
  144. AND n.status = '0'
  145. )
  146. SELECT rType, target_id, target_name, target_label, pCount
  147. FROM RankedRelations
  148. WHERE rn <= 50
  149. ORDER BY rType
  150. """)
  151. # 3. 组装返回数据结构
  152. categories = ["中心词", "关系"]
  153. nodes = []
  154. links = []
  155. # 查询中心节点
  156. center_sql = text("""
  157. SELECT n.id, n.name, n.category
  158. FROM user_data_relations udr
  159. JOIN kg_nodes n ON udr.data_id = n.id AND udr.data_category = 'DbKgNode'
  160. WHERE udr.user_id = :user_id
  161. AND n.category = :label_name
  162. AND n.name = :input_str
  163. AND n.status = '0'
  164. """)
  165. # 执行查询并处理结果
  166. center_node = None
  167. rtype_map = {}
  168. # 1. 查询中心节点
  169. center_result = db.execute(center_sql, {
  170. 'user_id': user_id,
  171. 'label_name': label_name,
  172. 'input_str': input_str
  173. }).fetchall()
  174. if center_result:
  175. for row in center_result:
  176. center_node = {
  177. "id": 0,
  178. "name": row._mapping['name'],
  179. "label": row._mapping['category'],
  180. "symbolSize": 50,
  181. "symbol": "circle",
  182. "properties": await get_node_properties(row._mapping['id'], db)
  183. }
  184. nodes.append(center_node)
  185. break
  186. # 2. 查询关联的边和目标节点
  187. relation_result = db.execute(sql, {
  188. 'user_id': user_id,
  189. 'label_name': label_name,
  190. 'input_str': input_str
  191. }).fetchall()
  192. if relation_result:
  193. for row in relation_result:
  194. r_type = row._mapping['rtype']
  195. target_node = {
  196. "id": row._mapping['target_id'],
  197. "name": row._mapping['target_name'],
  198. "label": row._mapping['target_label'],
  199. "symbolSize": 28,
  200. "symbol": "circle",
  201. "properties": await get_node_properties(row._mapping['target_id'], db)
  202. }
  203. if r_type not in rtype_map:
  204. rtype_map[r_type] = []
  205. rtype_map[r_type].append(target_node)
  206. if r_type not in categories:
  207. categories.append(r_type)
  208. # 3. 组装返回结果
  209. if center_node:
  210. for r_type, targets in rtype_map.items():
  211. # 添加关系节点
  212. relation_node = {
  213. "id": len(nodes),
  214. "name": "",
  215. "label": center_node['label'],
  216. "symbolSize": 10,
  217. "symbol": "diamond",
  218. "properties": {}
  219. }
  220. nodes.append(relation_node)
  221. # 添加中心节点到关系节点的链接
  222. links.append({
  223. "source": center_node['name'],
  224. "target": "",
  225. "name": r_type,
  226. "category": r_type
  227. })
  228. # 添加关系节点到目标节点的链接
  229. for target in targets:
  230. links.append({
  231. "source": "",
  232. "target": target['name'],
  233. "name": "",
  234. "category": r_type
  235. })
  236. final_data={
  237. "categories": categories,
  238. "nodes": nodes,
  239. "links": links
  240. }
  241. return StandardResponse(
  242. records=[{"records": final_data}],
  243. message="Graph data retrieved"
  244. )
  245. except Exception as e:
  246. return StandardResponse(
  247. code=500,
  248. message=str(e)
  249. )
  250. kb_router = router