"""
Medical knowledge-graph helper module.

Builds a medical knowledge graph from cached JSON dumps and provides node
lookup, neighbour analysis, community detection and path finding on it.

Main features:
1. Build the medical knowledge graph
2. Node / relation retrieval
3. Community detection
4. Path finding
5. Neighbour analysis
"""
import argparse
import json
import os
import sys
from collections import deque
from itertools import islice

# Current working directory. It is appended to sys.path *before* the
# project-local imports below so that it can actually help resolve them.
# Appending (rather than inserting) keeps every previously resolvable import
# resolving exactly as before.
current_path = os.getcwd()
sys.path.append(current_path)

import igraph as ig
import leidenalg
import networkx as nx
from tabulate import tabulate

from db.session import get_db
from service.kg_node_service import KGNodeService

# Resolution parameter of the Leiden community detection; controls how
# fine-grained the resulting communities are.
RESOLUTION = 0.07

# Directory holding the cached graph dumps (generated by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')

# Relation chunk files relationship_med_0.json .. relationship_med_88.json
# are probed; missing chunk files are skipped.
MAX_RELATION_CHUNKS = 89

# Weight used when a relation carries no (or an unparsable) weight.
DEFAULT_EDGE_WEIGHT = 8

# Textual prefix that may precede the numeric weight in the cached data.
WEIGHT_PREFIX = '权重:'


def load_entity_data():
    """Load the cached entity list.

    Returns:
        list: The parsed content of ``entities_med.json``; each element is
        consumed by the graph builder as ``[node_id, attributes_dict]``.
    """
    print("load entity data")
    entity_file = os.path.join(CACHED_DATA_PATH, 'entities_med.json')
    with open(entity_file, "r", encoding="utf-8") as f:
        return json.load(f)


def _parse_weight(raw):
    """Convert a raw relation weight into an int.

    Accepts numbers as well as strings such as ``'权重:5'``. Anything that
    cannot be interpreted as an integer yields ``DEFAULT_EDGE_WEIGHT``.
    """
    try:
        if isinstance(raw, (int, float)):
            return int(raw)
        return int(str(raw).replace(WEIGHT_PREFIX, '').strip())
    except (TypeError, ValueError, OverflowError):
        return DEFAULT_EDGE_WEIGHT


def load_relation_data(g):
    """Load the relation chunks and add them to ``g`` as weighted edges.

    Args:
        g (nx.Graph): Graph that receives the edges.

    Notes:
        1. Chunk files ``relationship_med_{i}.json`` for
           ``i`` in ``0 .. MAX_RELATION_CHUNKS - 1`` are read if they exist.
        2. Each relation item is indexed as ``[source, target, attrs]``.
           ``attrs['weight']`` (default 8) becomes the integer edge weight,
           the remaining attributes are stored on the edge unchanged.
        3. Items whose source or target is ``None`` are skipped.
    """
    for i in range(MAX_RELATION_CHUNKS):
        chunk_file = os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")
        if not os.path.exists(chunk_file):
            continue

        print("load relation data", chunk_file)
        with open(chunk_file, "r", encoding="utf-8") as f:
            relations = json.load(f)

        for item in relations:
            source, target = item[0], item[1]
            # Relations with a missing endpoint cannot become an edge.
            if source is None or target is None:
                continue
            # Work on a copy so 'weight' is not passed twice to add_edge.
            attrs = dict(item[2])
            weight = _parse_weight(attrs.pop('weight', DEFAULT_EDGE_WEIGHT))
            g.add_edge(source, target, weight=weight, **attrs)


class GraphHelper:
    """Helper around the medical knowledge graph.

    Capabilities:
        - Build the knowledge graph
        - Node / relation retrieval
        - Community detection
        - Path finding
        - Neighbour analysis

    Attributes:
        graph: The NetworkX graph holding the knowledge graph.
    """

    def __init__(self):
        """Create the helper and immediately build the graph from the cache."""
        self.graph = None
        self.build_graph()

    def build_graph(self):
        """Build the knowledge graph.

        Steps:
            1. Start from an empty undirected graph.
            2. Add every cached entity as a node with its attributes.
            3. Add every cached relation as an edge.
        """
        self.graph = nx.Graph()

        # Nodes (diseases, symptoms, ...).
        for item in load_entity_data():
            node_id, attrs = item[0], item[1]
            self.graph.add_node(node_id, **attrs)

        # Edges (disease-symptom relations, ...).
        load_relation_data(self.graph)

    def node_search2(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search nodes through the title index and enrich hits from the graph.

        Args:
            node_id: Query passed to ``search_title_index``.
            node_type: Unused; kept for interface compatibility.
            filters: Unused; kept for interface compatibility.
            limit (int): Passed to ``search_title_index``.
            min_degree: Unused; kept for interface compatibility.

        Returns:
            list: One dict per hit whose title is a graph node with a
            non-empty attribute dict, containing ``id``, ``score`` and the
            node attributes.
        """
        db_source = get_db()
        try:
            kg_node_service = KGNodeService(next(db_source))
            es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)

            results = []
            for item in es_result:
                title = item["title"]
                node_attrs = self.graph.nodes.get(title)
                # Hits that are unknown to the graph, or whose node has no
                # attributes, are dropped.
                if node_attrs:
                    results.append({
                        'id': title,
                        'score': item["score"],
                        **node_attrs
                    })
            return results
        finally:
            # NOTE(review): get_db() is assumed to be a generator that
            # releases its session after the yield; closing it here triggers
            # that cleanup instead of leaving it to garbage collection.
            close = getattr(db_source, "close", None)
            if callable(close):
                close()

    def node_search(self, node_id=None, node_type=None, filters=None):
        """Search nodes by id, type and/or attribute values.

        Args:
            node_id (str): Exact node id to match (ignored when falsy).
            node_type (str): Required value of the ``type`` attribute
                (ignored when falsy).
            filters (dict): ``{attribute: expected_value}`` pairs that must
                all match (ignored when falsy).

        Returns:
            list: Matching nodes, each as a dict with ``id`` plus all node
            attributes.
        """
        results = []
        for name, attrs in self.graph.nodes(data=True):
            if node_id and name != node_id:
                continue
            if node_type and attrs.get('type') != node_type:
                continue
            if filters and any(attrs.get(key) != expected for key, expected in filters.items()):
                continue
            results.append({
                'id': name,
                **attrs
            })
        return results

    def neighbor_search(self, center_node, hops=2):
        """Collect the neighbourhood of a node with a breadth-first search.

        Args:
            center_node (str): Id of the centre node.
            hops (int): Maximum distance from the centre (default 2).

        Returns:
            tuple: ``(entities, relations)``. ``entities`` holds every node
            reached within ``hops`` (centre excluded) as ``id`` plus its
            attributes. ``relations`` holds one dict per traversed edge with
            ``src_name``, ``dest_name`` and the edge attributes; an edge
            whose two endpoints are both expanded appears once per direction.
            Both lists are empty when the centre node is not in the graph.
        """
        if center_node not in self.graph:
            # Best effort: report and return nothing instead of raising.
            print(f"Error processing node {center_node}: The node {center_node} is not in the graph.")
            return [], []

        # node -> hop distance from the centre; doubles as the visited set.
        visited = {center_node: 0}
        queue = deque([center_node])
        relations = []

        while queue:
            current = queue.popleft()
            current_hop = visited[current]
            # Nodes at the hop limit are kept but not expanded further.
            if current_hop >= hops:
                continue

            for neighbor in self.graph.neighbors(current):
                if neighbor not in visited:
                    visited[neighbor] = current_hop + 1
                    queue.append(neighbor)

                # Record the traversed edge together with its attributes.
                edge_data = self.graph.get_edge_data(current, neighbor)
                relations.append({
                    'src_name': current,
                    'dest_name': neighbor,
                    **edge_data
                })

        # Neighbour entities, centre node excluded.
        entities = [
            {'id': n, **self.graph.nodes[n]}
            for n in visited
            if n != center_node
        ]
        return entities, relations

    def detect_communities(self):
        """Detect communities with the Leiden algorithm.

        Steps:
            1. Convert the NetworkX graph to an igraph graph.
            2. Run Leiden with ``CPMVertexPartition``, resolution
               ``RESOLUTION`` and two iterations.
            3. Store each node's community index in its ``community``
               attribute on the original graph.

        Returns:
            tuple: ``(graph, partition)`` - the annotated NetworkX graph and
            the leidenalg partition.
        """
        ig_graph = ig.Graph.from_networkx(self.graph)

        partition = leidenalg.find_partition(
            ig_graph,
            leidenalg.CPMVertexPartition,
            resolution_parameter=RESOLUTION,
            n_iterations=2
        )

        # Membership entries are matched to the graph's nodes by iteration order.
        for node, community in zip(self.graph.nodes(), partition.membership):
            self.graph.nodes[node]['community'] = community

        return self.graph, partition

    def find_paths(self, source, target, max_paths=5):
        """Find the shortest path and the top-k shortest simple paths.

        Args:
            source (str): Start node.
            target (str): End node.
            max_paths (int): Maximum number of simple paths to return; must
                be non-negative, or ``None`` for no limit.

        Returns:
            dict: ``{'shortest_path': [...], 'all_paths': [[...], ...]}``.
            Both entries stay empty when no path exists.

        Raises:
            networkx.NodeNotFound: If ``source`` or ``target`` is not in the
                graph (propagated from networkx).
        """
        result = {'shortest_path': [], 'all_paths': []}
        try:
            # Weighted shortest path.
            result['shortest_path'] = nx.shortest_path(
                self.graph, source, target, weight='weight'
            )

            # shortest_simple_paths yields paths lazily from shortest to
            # longest; taking only the first max_paths avoids enumerating
            # every simple path of the graph.
            ranked_paths = nx.shortest_simple_paths(
                self.graph, source, target, weight='weight'
            )
            result['all_paths'] = list(islice(ranked_paths, max_paths))
        except nx.NetworkXNoPath:
            pass
        return result