graph_router.py

import sys
import os

current_path = os.getcwd()
sys.path.append(current_path)

from fastapi import APIRouter, Depends, Query
from typing import Optional, List
from libs.graph_helper import GraphHelper
from saas.models.response import StandardResponse
from saas.models.request import GraphFilterRequest

router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
graph_helper = GraphHelper()

@router.get("/nodes/search", response_model=StandardResponse)
async def search_nodes(
    keyword: str = Query(..., min_length=2),
    limit: int = Query(10, ge=1, le=100),
    node_type: Optional[str] = Query(None),
    min_degree: Optional[int] = Query(None)
):
    """
    Search graph nodes by keyword and attribute filters.
    """
    try:
        results = graph_helper.node_search(
            keyword,
            limit=limit,
            node_type=node_type,
            min_degree=min_degree
        )
        community_report_results = graph_helper.community_report_search(keyword)
        return StandardResponse(
            success=True,
            records={"nodes": results, "community_report": community_report_results},
            error_code=0,
            error_msg=f"Found {len(results)} nodes"
        )
    except Exception as e:
        # Return a structured error instead of re-raising, so the client
        # always receives a StandardResponse.
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e)
        )
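
# A hypothetical request against this endpoint (values illustrative,
# assuming the app mounts this router at the root):
#   GET /graph/nodes/search?keyword=fever&limit=5&node_type=Disease
# A successful response carries records["nodes"] and
# records["community_report"].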

@router.get("/nodes/neighbor_search", response_model=StandardResponse)
async def neighbor_search(
    keyword: str = Query(..., min_length=2),
    limit: int = Query(10, ge=1, le=100),
    node_type: Optional[str] = Query(None),
    neighbor_type: Optional[str] = Query(None),
    min_degree: Optional[int] = Query(None)
):
    """
    Search graph nodes by keyword and attribute filters, then aggregate
    their neighbors of the requested type.
    """
    try:
        scores_factor = 1.7
        results = []
        diseases = {}
        # results = graph_helper.node_search(
        #     keyword,
        #     limit=limit,
        #     node_type=node_type,
        #     min_degree=min_degree
        # )
        # If the direct search returns nothing, split the keyword
        # (symptoms only) and search each piece; return empty results
        # only if the split search also finds nothing.
        print("Search result count:", len(results))
        has_good_result = False
        # new_results = []
        # for item in results:
        #     if item["score"] > scores_factor:
        #         has_good_result = True
        #         new_results.append(item)
        # results = new_results
        print("Count remaining after similarity filtering:", len(results))
        if not has_good_result:
            keywords = keyword.split(" ")
            new_results = []
            for item in keywords:
                if len(item) > 1:
                    results = graph_helper.node_search(
                        item,
                        limit=limit,
                        node_type=node_type,
                        min_degree=min_degree
                    )
                    for result_item in results:
                        if result_item["score"] > scores_factor:
                            new_results.append(result_item)
                            if result_item["type"] == "Disease":
                                if result_item["id"] not in diseases:
                                    diseases[result_item["id"]] = {
                                        "id": result_item["id"],
                                        "type": result_item["type"],
                                        "count": 1
                                    }
                                else:
                                    diseases[result_item["id"]]["count"] += 1
                            has_good_result = True
            results = new_results
        print("Result count after expanded search:", len(results))
        neighbors_data = {}
        for item in results:
            entities, relations = graph_helper.neighbor_search(item["id"], 1)
            # Diseases like "fever" link to many other diseases, so cap how
            # many diseases we expand to keep the search bounded.
            max_diseases = 20
            for neighbor in entities:
                if neighbor["type"] == neighbor_type:
                    # The neighbor is exactly the requested node type.
                    if neighbor["id"] not in neighbors_data:
                        neighbors_data[neighbor["id"]] = {
                            "id": neighbor["id"],
                            "type": neighbor["type"],
                            "count": 1
                        }
                    else:
                        neighbors_data[neighbor["id"]]["count"] += 1
                elif neighbor["type"] == "Disease":
                    # The neighbor is a disease: record it, then search one
                    # more hop for nodes of the requested type.
                    if neighbor["id"] not in diseases:
                        diseases[neighbor["id"]] = {
                            "id": neighbor["id"],
                            "type": neighbor["type"],
                            "count": 1
                        }
                    else:
                        diseases[neighbor["id"]]["count"] += 1
                    disease_entities, relations = graph_helper.neighbor_search(neighbor["id"], 1)
                    for disease_neighbor in disease_entities:
                        # Edges sometimes point at nodes that are missing
                        # from the node data, so guard against absent fields.
                        if "type" in disease_neighbor:
                            if disease_neighbor["type"] == neighbor_type:
                                if disease_neighbor["id"] not in neighbors_data:
                                    neighbors_data[disease_neighbor["id"]] = {
                                        "id": disease_neighbor["id"],
                                        "type": disease_neighbor["type"],
                                        "count": 1
                                    }
                                else:
                                    neighbors_data[disease_neighbor["id"]]["count"] += 1
                    # Expand at most max_diseases diseases per seed node.
                    max_diseases -= 1
                    if max_diseases == 0:
                        break
        disease_data = sorted(diseases.values(), key=lambda x: x["count"], reverse=True)
        data = sorted(neighbors_data.values(), key=lambda x: x["count"], reverse=True)
        if len(data) > 10:
            data = data[:10]
        # Turn raw counts into normalized weights, decaying each rank
        # position by 0.9 so higher-ranked entries keep more weight.
        factor = 1.0
        total = 0.0
        for item in data:
            total += item["count"] * factor
            factor *= 0.9
        factor = 1.0
        for item in data:
            item["count"] = item["count"] * factor / total
            factor *= 0.9
        if len(disease_data) > 10:
            disease_data = disease_data[:10]
        factor = 1.0
        total = 0.0
        for item in disease_data:
            total += item["count"] * factor
            factor *= 0.9
        factor = 1.0
        for item in disease_data:
            item["count"] = item["count"] * factor / total
            factor *= 0.9
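        # Worked example of the decay weighting, assuming counts [3, 2, 1]:
        # weighted total = 3*1.0 + 2*0.9 + 1*0.81 = 5.61, so the entries
        # get weights 3.0/5.61 ≈ 0.535, 1.8/5.61 ≈ 0.321, 0.81/5.61 ≈ 0.144.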
        return StandardResponse(
            success=True,
            records={"nodes": disease_data, "neighbors": data},
            error_code=0,
            error_msg=f"Found {len(results)} nodes"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e)
        )
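
# A hypothetical call (values illustrative): aggregate Symptom neighbors of
# nodes matched by a space-separated keyword list:
#   GET /graph/nodes/neighbor_search?keyword=cough%20fever&neighbor_type=Symptom
# records["neighbors"] carries the decay-normalized neighbor weights and
# records["nodes"] the diseases encountered along the way.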

@router.post("/nodes/filter", response_model=StandardResponse)
async def filter_nodes(request: GraphFilterRequest):
    """
    Filter nodes by a combination of conditions.
    """
    try:
        results = graph_helper.filter_nodes(
            node_types=request.node_types,
            min_degree=request.min_degree,
            min_community_size=request.min_community_size,
            attributes=request.attributes
        )
        return StandardResponse(
            success=True,
            data={"nodes": results},
            message=f"Filtered {len(results)} nodes"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )
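
# A hypothetical request body for this endpoint; the field names mirror
# GraphFilterRequest as it is used above, the values are illustrative:
#   POST /graph/nodes/filter
#   {"node_types": ["Disease"], "min_degree": 2,
#    "min_community_size": 3, "attributes": {"source": "icd10"}}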

@router.get("/statistics", response_model=StandardResponse)
async def get_graph_statistics():
    """
    Retrieve graph-wide statistics.
    """
    try:
        stats = graph_helper.get_graph_statistics()
        return StandardResponse(
            success=True,
            data=stats,
            message="Graph statistics retrieved"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )

@router.get("/community/{community_id}", response_model=StandardResponse)
async def get_community_details(
    community_id: int,
    min_size: int = Query(3, ge=2)
):
    """
    Retrieve details of the specified community.
    """
    try:
        community_info = graph_helper.get_community_details(community_id, min_size)
        return StandardResponse(
            success=True,
            data=community_info,
            message="Community details retrieved"
        )
    except Exception as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )

# # New graph path-analysis endpoint
# @router.post("/path-analysis", response_model=StandardResponse)
# async def analyze_paths(request: GraphSearchRequest):
#     """
#     Analyze potential relation paths between nodes.
#     """
#     try:
#         paths = graph_helper.find_paths(
#             source_id=request.source_id,
#             target_id=request.target_id,
#             max_depth=request.max_depth
#         )
#         return StandardResponse(
#             success=True,
#             data={"paths": paths},
#             message=f"Found {len(paths)} possible paths"
#         )
#     except Exception as e:
#         return StandardResponse(
#             success=False,
#             message=str(e)
#         )

@router.get("/node/{node_id}", response_model=StandardResponse)
async def get_node_details(
    node_id: str,
    with_relations: bool = False,
    relation_types: List[str] = Query(None),
    relation_limit: int = Query(10, ge=1, le=100)
):
    """
    Retrieve details for a single node.
    - relation_types: only return relations of these types
    - relation_limit: cap on the number of relations returned
    """
    try:
        node_info = graph_helper.get_node_details(
            node_id,
            include_relations=with_relations,
            relation_types=relation_types,
            relation_limit=relation_limit
        )
        return StandardResponse(
            success=True,
            data=node_info,
            message="Node details retrieved"
        )
    except ValueError as e:
        return StandardResponse(
            success=False,
            message=str(e)
        )
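
# A hypothetical call (values illustrative). FastAPI fills a List[str]
# query parameter from repeated keys:
#   GET /graph/node/n123?with_relations=true&relation_types=HAS_SYMPTOM&relation_types=TREATED_BY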

graph_router = router
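
# A minimal sketch for serving this router locally, assuming uvicorn is
# installed; the app wiring, host, and port are illustrative and not part
# of this module's contract.
if __name__ == "__main__":
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()
    app.include_router(graph_router)
    uvicorn.run(app, host="0.0.0.0", port=8000)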