# graph_network_router.py — FastAPI routes for querying the knowledge graph
# (node search, neighbor expansion, filtering, statistics, community details).
import sys,os

# NOTE(review): sys.path hack so the `agent.*` / `web.*` imports below resolve
# when the app is launched from the project root — presumably a stand-in for
# proper packaging; confirm before removing.
current_path = os.getcwd()
sys.path.append(current_path)
from fastapi import APIRouter, Depends, Query
from typing import Optional, List
from agent.cdss.libs.cdss_helper import CDSSHelper
from web.models.response import StandardResponse
from web.models.request import GraphFilterRequest

# All endpoints in this module are mounted under the /graph prefix.
router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
# Module-level singleton shared by every request handler below.
graph_helper = CDSSHelper()
  11. @router.get("/nodes/search", response_model=StandardResponse)
  12. async def search_nodes(
  13. keyword: str = Query(..., min_length=2),
  14. limit: int = Query(10, ge=1, le=100),
  15. node_type: Optional[str] = Query(None),
  16. min_degree: Optional[int] = Query(None)
  17. ):
  18. """
  19. 根据关键词和属性过滤条件搜索图谱节点
  20. """
  21. try:
  22. results = graph_helper.node_search(
  23. keyword,
  24. limit=limit,
  25. node_type=node_type,
  26. min_degree=min_degree
  27. )
  28. community_report_results = graph_helper.community_report_search(keyword)
  29. return StandardResponse(
  30. success=True,
  31. records={"nodes":results,"community_report":community_report_results},
  32. code = 0,
  33. error_msg=f"Found {len(results)} nodes"
  34. )
  35. except Exception as e:
  36. return StandardResponse(
  37. success=False,
  38. message=str(e)
  39. )
@router.get("/nodes/neighbor_search", response_model=StandardResponse)
async def neighbor_search(
    keyword: str = Query(..., min_length=2),
    limit: int = Query(10, ge=1, le=100),
    node_type: Optional[str] = Query(None),
    neighbor_type: Optional[str] = Query(None),
    min_degree: Optional[int] = Query(None)
):
    """
    Search graph nodes by keyword, then walk each hit's 1-hop neighborhood
    collecting neighbors of `neighbor_type`; when a neighbor is a "Disease"
    node, expand one extra hop through it and collect matches there too.

    Returns a StandardResponse with two ranked, normalized lists:
    - records["nodes"]: disease hits ranked by occurrence count
    - records["neighbors"]: neighbor_type hits ranked by occurrence count
    """
    try:
        # Minimum similarity score a node-search hit must exceed to be kept.
        scores_factor = 1.7
        results = []
        diseases = {}
        # Direct whole-keyword search is currently disabled, so the
        # split-keyword fallback below always runs (has_good_result stays False).
        # results = graph_helper.node_search(
        #     keyword,
        #     limit=limit,
        #     node_type=node_type,
        #     min_degree=min_degree
        # )
        # Log the hit count; if empty, fall back to searching each whitespace-
        # separated token of the keyword (intended for symptom keywords).
        print("搜索的结果数量:",len(results))
        has_good_result = False
        # new_results = []
        # for item in results:
        #     if item["score"] > scores_factor:
        #         has_good_result = True
        #         new_results.append(item)
        # results = new_results
        # Log the count remaining after similarity filtering.
        print("通过相似度过滤之后剩余的数量:",len(results))
        if not has_good_result:
            # Fallback: split the keyword on spaces and search each token.
            keywords = keyword.split(" ")
            new_results = []
            for item in keywords:
                if len(item) > 1:
                    results = graph_helper.node_search(
                        item,
                        limit=limit,
                        node_type=node_type,
                        min_degree=min_degree
                    )
                    for result_item in results:
                        if result_item["score"] > scores_factor:
                            new_results.append(result_item)
                            if result_item["type"] == "Disease":
                                # Tally how many tokens hit each disease.
                                if result_item["id"] not in diseases:
                                    diseases[result_item["id"]] = {
                                        "id":result_item["id"],
                                        "type":result_item["type"],
                                        "count":1
                                    }
                                else:
                                    diseases[result_item["id"]]["count"] = diseases[result_item["id"]]["count"] + 1
                            has_good_result = True
            results = new_results
        # Log the result count after the expanded (per-token) search.
        print("扩展搜索的结果数量:",len(results))
        neighbors_data = {}
        for item in results:
            entities, relations = graph_helper.neighbor_search(item["id"], 1)
            # Cap the disease expansion: hub-like nodes (e.g. "fever") link to
            # very many diseases, so bound the scan per search hit.
            # NOTE(review): shadows the builtin `max` for the rest of the loop.
            max = 20
            for neighbor in entities:
                if neighbor["type"] == neighbor_type:
                    # Direct neighbor already has the requested type — count it.
                    if neighbor["id"] not in neighbors_data:
                        neighbors_data[neighbor["id"]] = {
                            "id":neighbor["id"],
                            "type":neighbor["type"],
                            "count":1
                        }
                    else:
                        neighbors_data[neighbor["id"]]["count"] = neighbors_data[neighbor["id"]]["count"] + 1
                else:
                    # If the neighbor is a Disease, expand one more hop and look
                    # for the requested type among ITS neighbors.
                    if neighbor["type"] == "Disease":
                        if neighbor["id"] not in diseases:
                            diseases[neighbor["id"]] = {
                                "id":neighbor["id"],
                                "type":neighbor["type"],
                                "count":1
                            }
                        else:
                            diseases[neighbor["id"]]["count"] = diseases[neighbor["id"]]["count"] + 1
                        disease_entities, relations = graph_helper.neighbor_search(neighbor["id"], 1)
                        for disease_neighbor in disease_entities:
                            # The data sometimes links to nodes whose record is
                            # missing fields, so guard the "type" lookup.
                            if "type" in disease_neighbor.keys():
                                if disease_neighbor["type"] == neighbor_type:
                                    if disease_neighbor["id"] not in neighbors_data:
                                        neighbors_data[disease_neighbor["id"]] = {
                                            "id":disease_neighbor["id"],
                                            "type":disease_neighbor["type"],
                                            "count":1
                                        }
                                    else:
                                        neighbors_data[disease_neighbor["id"]]["count"] = neighbors_data[disease_neighbor["id"]]["count"] + 1
                        # At most `max` diseases are expanded per search hit.
                        max = max - 1
                        if max == 0:
                            break
        # Rank diseases and neighbors by how often they were encountered.
        disease_data = [diseases[k] for k in diseases]
        disease_data = sorted(disease_data, key=lambda x:x["count"],reverse=True)
        data = [neighbors_data[k] for k in neighbors_data]
        data = sorted(data, key=lambda x:x["count"],reverse=True)
        # Keep the top 10 and normalize counts into relative weights.
        if len(data) > 10:
            data = data[:10]
        factor = 1.0
        total = 0.0
        for item in data:
            total = item["count"] * factor + total
        for item in data:
            item["count"] = item["count"] / total
            # NOTE(review): factor decays here but is never applied to the
            # division above, and it stayed 1.0 while computing `total` —
            # looks like it was meant to weight earlier (higher-ranked)
            # entries more heavily; confirm intended behavior.
            factor = factor * 0.9
        # Same top-10 + normalization pass for the disease list.
        if len(disease_data) > 10:
            disease_data = disease_data[:10]
        factor = 1.0
        total = 0.0
        for item in disease_data:
            total = item["count"] * factor + total
        for item in disease_data:
            item["count"] = item["count"] / total
            factor = factor * 0.9
        return StandardResponse(
            success=True,
            records={"nodes":disease_data,"neighbors":data},
            error_code = 0,
            error_msg=f"Found {len(results)} nodes"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e)
        )
  173. @router.post("/nodes/filter", response_model=StandardResponse)
  174. async def filter_nodes(request: GraphFilterRequest):
  175. """
  176. 根据复杂条件过滤节点
  177. """
  178. try:
  179. results = graph_helper.filter_nodes(
  180. node_types=request.node_types,
  181. min_degree=request.min_degree,
  182. min_community_size=request.min_community_size,
  183. attributes=request.attributes
  184. )
  185. return StandardResponse(
  186. success=True,
  187. data={"nodes": results},
  188. message=f"Filtered {len(results)} nodes"
  189. )
  190. except Exception as e:
  191. return StandardResponse(
  192. success=False,
  193. message=str(e)
  194. )
  195. @router.get("/statistics", response_model=StandardResponse)
  196. async def get_graph_statistics():
  197. """
  198. 获取图谱统计信息
  199. """
  200. try:
  201. stats = graph_helper.get_graph_statistics()
  202. return StandardResponse(
  203. success=True,
  204. data=stats,
  205. message="Graph statistics retrieved"
  206. )
  207. except Exception as e:
  208. return StandardResponse(
  209. success=False,
  210. message=str(e)
  211. )
  212. @router.get("/community/{community_id}", response_model=StandardResponse)
  213. async def get_community_details(
  214. community_id: int,
  215. min_size: int = Query(3, ge=2)
  216. ):
  217. """
  218. 获取指定社区的详细信息
  219. """
  220. try:
  221. community_info = graph_helper.get_community_details(community_id, min_size)
  222. return StandardResponse(
  223. success=True,
  224. data=community_info,
  225. message="Community details retrieved"
  226. )
  227. except Exception as e:
  228. return StandardResponse(
  229. success=False,
  230. message=str(e)
  231. )
  232. # # 新增图谱路径分析接口
  233. # @router.post("/path-analysis", response_model=StandardResponse)
  234. # async def analyze_paths(request: GraphSearchRequest):
  235. # """
  236. # 分析节点间的潜在关系路径
  237. # """
  238. # try:
  239. # paths = graph_helper.find_paths(
  240. # source_id=request.source_id,
  241. # target_id=request.target_id,
  242. # max_depth=request.max_depth
  243. # )
  244. # return StandardResponse(
  245. # success=True,
  246. # data={"paths": paths},
  247. # message=f"Found {len(paths)} possible paths"
  248. # )
  249. # except Exception as e:
  250. # return StandardResponse(
  251. # success=False,
  252. # message=str(e)
  253. # )
  254. @router.get("/node/{node_id}", response_model=StandardResponse)
  255. async def get_node_details(
  256. node_id: str,
  257. with_relations: bool = False,
  258. relation_types: List[str] = Query(None),
  259. relation_limit: int = Query(10, ge=1, le=100)
  260. ):
  261. """
  262. 获取节点详细信息
  263. - relation_types: 过滤指定类型的关系
  264. - relation_limit: 返回关系数量限制
  265. """
  266. try:
  267. node_info = graph_helper.get_node_details(
  268. node_id,
  269. include_relations=with_relations,
  270. relation_types=relation_types,
  271. relation_limit=relation_limit
  272. )
  273. return StandardResponse(
  274. success=True,
  275. data=node_info,
  276. message="Node details retrieved"
  277. )
  278. except ValueError as e:
  279. return StandardResponse(
  280. success=False,
  281. message=str(e)
  282. )
# Public alias imported by the application when mounting this router.
graph_router = router