graph_helper2.bak 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. """
  2. 医疗知识图谱助手模块
  3. 本模块提供构建医疗知识图谱、执行社区检测、路径查找等功能
  4. 主要功能:
  5. 1. 构建医疗知识图谱
  6. 2. 支持节点/关系检索
  7. 3. 社区检测
  8. 4. 路径查找
  9. 5. 邻居分析
  10. """
  11. import networkx as nx
  12. import argparse
  13. import json
  14. from tabulate import tabulate
  15. import leidenalg
  16. import igraph as ig
  17. import sys,os
  18. from db.session import get_db
  19. from service.kg_node_service import KGNodeService
# Current working directory; it is also appended to sys.path below.
current_path = os.getcwd()
sys.path.append(current_path)
# Resolution parameter for Leiden community detection (controls how
# fine-grained the resulting communities are).
RESOLUTION = 0.07
# Directory holding the cached graph data
# (per the original note, produced by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
  27. def load_entity_data():
  28. """
  29. 加载实体数据
  30. 返回:
  31. list: 实体数据列表,每个元素格式为[node_id, attributes_dict]
  32. """
  33. print("load entity data")
  34. with open(os.path.join(CACHED_DATA_PATH,'entities_med.json'), "r", encoding="utf-8") as f:
  35. entities = json.load(f)
  36. return entities
  37. def load_relation_data(g):
  38. """
  39. 分块加载关系数据
  40. 参数:
  41. g (nx.Graph): 要添加边的NetworkX图对象
  42. 说明:
  43. 1. 支持分块加载多个关系文件(relationship_med_0.json ~ relationship_med_29.json)
  44. 2. 每个关系项格式为[source, target, {relation_attrs}]
  45. """
  46. for i in range(89):
  47. if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  48. print("load entity data", os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"))
  49. with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
  50. relations = json.load(f)
  51. for item in relations:
  52. # 添加带权重的边,并存储关系属性
  53. weight = int(item[2].pop('weight', '8').replace('权重:', ''))
  54. #如果item[0]或者item[1]为空或null,则跳过
  55. if item[0] is None or item[1] is None:
  56. continue
  57. g.add_edge(item[0], item[1], weight=weight, **item[2])
class GraphHelper:
    """
    Helper class for the medical knowledge graph.

    Features:
    - build the medical knowledge graph
    - node / relation retrieval
    - community detection
    - path finding
    - neighbor analysis

    Attributes:
        graph: NetworkX graph object that stores the knowledge graph
    """
    def __init__(self):
        """
        Initialize the helper.

        Sets the ``graph`` attribute to None and then calls ``build_graph()``,
        which replaces it with the populated graph.
        """
        self.graph = None
        self.build_graph()
  79. def build_graph(self):
  80. """构建知识图谱
  81. 步骤:
  82. 1. 初始化空图
  83. 2. 加载实体数据作为节点
  84. 3. 加载关系数据作为边
  85. """
  86. self.graph = nx.Graph()
  87. # 加载节点数据(疾病、症状等)
  88. entities = load_entity_data()
  89. for item in entities:
  90. node_id = item[0]
  91. attrs = item[1]
  92. self.graph.add_node(node_id, **attrs)
  93. # 加载边数据(疾病-症状关系等)
  94. load_relation_data(self.graph)
  95. def node_search2(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
  96. """节点检索功能"""
  97. kg_node_service = KGNodeService(next(get_db()))
  98. es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
  99. results = []
  100. for item in es_result:
  101. n = self.graph.nodes.get(item["title"])
  102. score = item["score"]
  103. if n:
  104. results.append({
  105. 'id': item["title"],
  106. 'score': score,
  107. **n
  108. })
  109. return results
  110. def node_search(self, node_id=None, node_type=None, filters=None):
  111. """节点检索
  112. 参数:
  113. node_id (str): 精确匹配节点ID
  114. node_type (str): 按节点类型过滤
  115. filters (dict): 自定义属性过滤,格式为{属性名: 期望值}
  116. 返回:
  117. list: 匹配的节点列表,每个节点包含id和所有属性
  118. """
  119. results = []
  120. # 遍历所有节点进行多条件过滤
  121. for n in self.graph.nodes(data=True):
  122. match = True
  123. if node_id and n[0] != node_id:
  124. continue
  125. if node_type and n[1].get('type') != node_type:
  126. continue
  127. if filters:
  128. for k, v in filters.items():
  129. if n[1].get(k) != v:
  130. match = False
  131. break
  132. if match:
  133. results.append({
  134. 'id': n[0],
  135. **n[1]
  136. })
  137. return results
  138. def neighbor_search(self, center_node, hops=2):
  139. """邻居节点检索
  140. 参数:
  141. center_node (str): 中心节点ID
  142. hops (int): 跳数(默认2跳)
  143. 返回:
  144. tuple: (邻居实体列表, 关联关系列表)
  145. 算法说明:
  146. 使用BFS算法进行层级遍历,时间复杂度O(k^d),其中k为平均度数,d为跳数
  147. """
  148. # 执行BFS遍历
  149. visited = {center_node: 0}
  150. queue = [center_node]
  151. relations = []
  152. while queue:
  153. try:
  154. current = queue.pop(0)
  155. current_hop = visited[current]
  156. if current_hop >= hops:
  157. continue
  158. # 遍历相邻节点
  159. for neighbor in self.graph.neighbors(current):
  160. if neighbor not in visited:
  161. visited[neighbor] = current_hop + 1
  162. queue.append(neighbor)
  163. # 记录边关系
  164. edge_data = self.graph.get_edge_data(current, neighbor)
  165. relations.append({
  166. 'src_name': current,
  167. 'dest_name': neighbor,
  168. **edge_data
  169. })
  170. except Exception as e:
  171. print(f"Error processing node {current}: {str(e)}")
  172. continue
  173. # 提取邻居实体(排除中心节点)
  174. entities = [
  175. {'id': n, **self.graph.nodes[n]}
  176. for n in visited if n != center_node
  177. ]
  178. return entities, relations
  179. def detect_communities(self):
  180. """使用Leiden算法进行社区检测
  181. 返回:
  182. tuple: (添加社区属性的图对象, 社区划分结果)
  183. 算法说明:
  184. 1. 将NetworkX图转换为igraph格式
  185. 2. 使用Leiden算法(分辨率参数RESOLUTION=0.07)
  186. 3. 将社区标签添加回原始图
  187. 4. 时间复杂度约为O(n log n)
  188. """
  189. # 转换图格式
  190. ig_graph = ig.Graph.from_networkx(self.graph)
  191. # 执行Leiden算法
  192. partition = leidenalg.find_partition(
  193. ig_graph,
  194. leidenalg.CPMVertexPartition,
  195. resolution_parameter=RESOLUTION,
  196. n_iterations=2
  197. )
  198. # 添加社区属性
  199. for i, node in enumerate(self.graph.nodes()):
  200. self.graph.nodes[node]['community'] = partition.membership[i]
  201. return self.graph, partition
  202. def find_paths(self, source, target, max_paths=5):
  203. """查找所有简单路径
  204. 参数:
  205. source (str): 起始节点
  206. target (str): 目标节点
  207. max_paths (int): 最大返回路径数
  208. 返回:
  209. dict: 包含最短路径和所有路径的结果字典
  210. 注意:
  211. 使用Yen算法寻找top k最短路径,时间复杂度O(kn(m + n log n))
  212. """
  213. result = {'shortest_path': [], 'all_paths': []}
  214. try:
  215. # 使用Dijkstra算法找最短路径
  216. shortest_path = nx.shortest_path(self.graph, source, target, weight='weight')
  217. result['shortest_path'] = shortest_path
  218. # 使用Yen算法找top k路径
  219. all_paths = list(nx.shortest_simple_paths(self.graph, source, target, weight='weight'))[:max_paths]
  220. result['all_paths'] = all_paths
  221. except nx.NetworkXNoPath:
  222. pass
  223. return result