# graph_router.py - FastAPI router for the knowledge-graph recommendation and neighbor-search endpoints.

  1. import sys,os
  2. from community.graph_helper2 import GraphHelper
  3. from model.response import StandardResponse
  4. current_path = os.getcwd()
  5. sys.path.append(current_path)
  6. import time
  7. from fastapi import APIRouter, Depends, Query
  8. from typing import Optional, List
  9. import sys
  10. sys.path.append('..')
  11. from utils.agent import call_chat_api,get_conversation_id
  12. import json
  13. router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
  14. graph_helper = GraphHelper()
  15. @router.get("/nodes/recommend", response_model=StandardResponse)
  16. async def recommend(
  17. chief: str
  18. ):
  19. app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
  20. conversation_id = get_conversation_id(app_id)
  21. result = call_chat_api(app_id, conversation_id, chief)
  22. json_data = json.loads(result)
  23. keyword = " ".join(json_data["chief_complaint"])
  24. return await neighbor_search(keyword=keyword, neighbor_type='Check',limit=10)
  25. @router.get("/nodes/neighbor_search", response_model=StandardResponse)
  26. async def neighbor_search(
  27. keyword: str = Query(..., min_length=2),
  28. limit: int = Query(10, ge=1, le=100),
  29. node_type: Optional[str] = Query(None),
  30. neighbor_type: Optional[str] = Query(None),
  31. min_degree: Optional[int] = Query(None)
  32. ):
  33. """
  34. 根据关键词和属性过滤条件搜索图谱节点
  35. """
  36. try:
  37. start_time = time.time()
  38. print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
  39. scores_factor = 1.7
  40. results = []
  41. diseases = {}
  42. has_good_result = False
  43. if not has_good_result:
  44. keywords = keyword.split(" ")
  45. new_results = []
  46. for item in keywords:
  47. if len(item) > 1:
  48. results = graph_helper.node_search2(
  49. item,
  50. limit=limit,
  51. node_type=node_type,
  52. min_degree=min_degree
  53. )
  54. for result_item in results:
  55. if result_item["score"] > scores_factor:
  56. new_results.append(result_item)
  57. if result_item["type"] == "Disease":
  58. if result_item["id"] not in diseases:
  59. diseases[result_item["id"]] = {
  60. "id":result_item["id"],
  61. "type":1,
  62. "count":1
  63. }
  64. else:
  65. diseases[result_item["id"]]["count"] = diseases[result_item["id"]]["count"] + 1
  66. has_good_result = True
  67. results = new_results
  68. print("扩展搜索的结果数量:",len(results))
  69. neighbors_data = {}
  70. for item in results:
  71. entities, relations = graph_helper.neighbor_search(item["id"], 1)
  72. max = 20 #因为类似发热这种疾病会有很多关联的疾病,所以需要防止检索范围过大,设置了上限
  73. for neighbor in entities:
  74. #有时候数据中会出现链接有的节点,但是节点数据中缺失的情况,所以这里要检查
  75. if "type" not in neighbor.keys():
  76. continue
  77. if neighbor["type"] == neighbor_type:
  78. #如果这里正好找到了要求检索的节点类型
  79. if neighbor["id"] not in neighbors_data:
  80. neighbors_data[neighbor["id"]] = {
  81. "id":neighbor["id"],
  82. "type":neighbor["type"],
  83. "count":1
  84. }
  85. else:
  86. neighbors_data[neighbor["id"]]["count"] = neighbors_data[neighbor["id"]]["count"] + 1
  87. else:
  88. #如果这里找到的节点是个疾病,那么就再检索一层,看看是否有符合要求的节点类型
  89. if neighbor["type"] == "Disease":
  90. if neighbor["id"] not in diseases:
  91. diseases[neighbor["id"]] = {
  92. "id":neighbor["id"],
  93. "type":"Disease",
  94. "count":1
  95. }
  96. else:
  97. diseases[neighbor["id"]]["count"] = diseases[neighbor["id"]]["count"] + 1
  98. disease_entities, relations = graph_helper.neighbor_search(neighbor["id"], 1)
  99. for disease_neighbor in disease_entities:
  100. #有时候数据中会出现链接有的节点,但是节点数据中缺失的情况,所以这里要检查
  101. if "type" in disease_neighbor.keys():
  102. if disease_neighbor["type"] == neighbor_type:
  103. if disease_neighbor["id"] not in neighbors_data:
  104. neighbors_data[disease_neighbor["id"]] = {
  105. "id":disease_neighbor["id"],
  106. "type":disease_neighbor["type"],
  107. "count":1
  108. }
  109. else:
  110. neighbors_data[disease_neighbor["id"]]["count"] = neighbors_data[disease_neighbor["id"]]["count"] + 1
  111. #最多搜索的范围是max个疾病
  112. max = max - 1
  113. if max == 0:
  114. break
  115. disease_data = [diseases[k] for k in diseases]
  116. disease_data = sorted(disease_data, key=lambda x:x["count"],reverse=True)
  117. data = [neighbors_data[k] for k in neighbors_data if neighbors_data[k]["type"] == "Check"]
  118. data = sorted(data, key=lambda x:x["count"],reverse=True)
  119. if len(data) > 10:
  120. data = data[:10]
  121. factor = 1.0
  122. total = 0.0
  123. for item in data:
  124. total = item["count"] * factor + total
  125. for item in data:
  126. item["count"] = item["count"] / total
  127. factor = factor * 0.9
  128. if len(disease_data) > 10:
  129. disease_data = disease_data[:10]
  130. factor = 1.0
  131. total = 0.0
  132. for item in disease_data:
  133. total = item["count"] * factor + total
  134. for item in disease_data:
  135. item["count"] = item["count"] / total
  136. factor = factor * 0.9
  137. for item in data:
  138. item["type"] = 3
  139. item["name"] = item["id"]
  140. item["rate"] = round(item["count"] * 100, 2)
  141. for item in disease_data:
  142. item["type"] = 1
  143. item["name"] = item["id"]
  144. item["rate"] = round(item["count"] * 100, 2)
  145. end_time = time.time()
  146. print(f"neighbor_search执行完成,耗时:{end_time - start_time:.2f}秒")
  147. return StandardResponse(
  148. success=True,
  149. records={"可能诊断":disease_data,"推荐检验":data},
  150. error_code = 0,
  151. error_msg=f"Found {len(results)} nodes"
  152. )
  153. except Exception as e:
  154. print(e)
  155. raise e
  156. return StandardResponse(
  157. success=False,
  158. error_code=500,
  159. error_msg=str(e)
  160. )
  161. graph_router = router