"""
Medical knowledge graph helper module.

This module provides functionality for building a medical knowledge graph,
running community detection, path finding, and related queries.

Main features:
1. Build the medical knowledge graph
2. Node/relation retrieval
3. Community detection
4. Path finding
5. Neighbor analysis
"""
import argparse
import json
import os
import sys
from collections import deque
from itertools import islice

import igraph as ig
import leidenalg
import networkx as nx
from tabulate import tabulate

from db.session import get_db
from service.kg_node_service import KGNodeService
# Current working directory; appended to sys.path so that modules located
# under it can be imported.
# NOTE(review): the project imports at the top of this file have already run
# by the time this append executes, so it only affects later imports.
current_path = os.getcwd()
sys.path.append(current_path)

# Resolution parameter for Leiden community detection; controls the
# granularity of the community partition.
RESOLUTION = 0.07

# Directory holding the cached graph data (generated by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
def load_entity_data(data_dir=None):
    """Load the cached entity (node) records from ``entities_med.json``.

    Args:
        data_dir: Directory containing ``entities_med.json``. Defaults to the
            module-level ``CACHED_DATA_PATH`` when omitted.

    Returns:
        The parsed JSON content. The file is expected to hold a list of
        entities, each formatted as ``[node_id, attributes_dict]``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    base_dir = CACHED_DATA_PATH if data_dir is None else data_dir
    print("load entity data")
    with open(os.path.join(base_dir, 'entities_med.json'), "r", encoding="utf-8") as f:
        return json.load(f)
def load_relation_data(g, data_dir=None, max_chunks=89):
    """Load chunked relation files and add them to ``g`` as weighted edges.

    Files named ``relationship_med_<i>.json`` for ``i`` in
    ``range(max_chunks)`` are read from ``data_dir``; chunk indices whose file
    does not exist are skipped. Each relation item is expected to be
    ``[source, target, {relation_attrs}]``.

    The ``weight`` attribute is removed from the relation attributes and
    converted to an int (an optional ``'权重:'`` label is stripped first); it
    defaults to 8 when absent. The remaining attributes are stored on the edge.
    Items whose source or target is ``None`` are skipped.

    Args:
        g: Graph to add edges to; must expose ``add_edge(u, v, **attrs)``
            (an ``nx.Graph`` in this module).
        data_dir: Directory containing the relation chunk files. Defaults to
            the module-level ``CACHED_DATA_PATH`` when omitted.
        max_chunks: Number of chunk indices to probe, starting at 0.

    Raises:
        ValueError: If a weight cannot be converted to an int.
    """
    base_dir = CACHED_DATA_PATH if data_dir is None else data_dir
    for i in range(max_chunks):
        chunk_path = os.path.join(base_dir, f"relationship_med_{i}.json")
        if not os.path.exists(chunk_path):
            continue
        print("load relation data", chunk_path)
        with open(chunk_path, "r", encoding="utf-8") as f:
            relations = json.load(f)
        for item in relations:
            source, target, attrs = item[0], item[1], item[2]
            # Check the endpoints first so skipped items are left untouched.
            if source is None or target is None:
                continue
            # str() lets the weight be stored either as a number or as text.
            raw_weight = attrs.pop('weight', '8')
            weight = int(str(raw_weight).replace('权重:', ''))
            g.add_edge(source, target, weight=weight, **attrs)
class GraphHelper:
    """Helper around a NetworkX medical knowledge graph.

    Features:
        - Build the medical knowledge graph
        - Node/relation retrieval
        - Community detection
        - Path finding
        - Neighbor analysis

    Attributes:
        graph: NetworkX graph object storing the knowledge graph.
    """

    def __init__(self):
        """Initialise ``graph`` to None, then populate it via build_graph()."""
        self.graph = None
        self.build_graph()

    def build_graph(self):
        """Build the knowledge graph.

        Creates an empty undirected graph, adds every loaded entity as a node
        with its attributes, then adds the relation data as edges.
        """
        self.graph = nx.Graph()

        # Nodes: each entity is [node_id, attributes_dict].
        for item in load_entity_data():
            self.graph.add_node(item[0], **item[1])

        # Edges: relation chunks are added directly onto the graph.
        load_relation_data(self.graph)

    def node_search2(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search nodes through the title index and attach graph attributes.

        Queries ``"graph_entity_index"`` through
        ``KGNodeService.search_title_index`` with ``node_id`` and ``limit``,
        then keeps the hits whose title is a node in the graph.

        Args:
            node_id: Query passed to the title index.
            node_type: Unused; accepted for signature compatibility.
            filters: Unused; accepted for signature compatibility.
            limit: Passed through to the index search.
            min_degree: Unused; accepted for signature compatibility.

        Returns:
            list: One dict per retained hit with ``id``, ``score`` and the
            node's attributes. Node attributes named ``id`` or ``score``
            override those two keys. Hits whose node is missing from the graph
            or has no attributes are dropped.
        """
        # NOTE(review): the iterator returned by get_db() is advanced once and
        # never closed here; confirm that session cleanup happens elsewhere.
        kg_node_service = KGNodeService(next(get_db()))
        hits = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
        results = []
        for hit in hits:
            attrs = self.graph.nodes.get(hit["title"])
            score = hit["score"]
            if attrs:
                results.append({'id': hit["title"], 'score': score, **attrs})
        return results

    def node_search(self, node_id=None, node_type=None, filters=None):
        """Retrieve nodes matching all of the given conditions.

        Args:
            node_id (str): Exact node ID to match.
            node_type (str): Required value of the node's ``type`` attribute.
            filters (dict): Custom attribute filter, ``{attribute: expected}``.

        Returns:
            list: Matching nodes, each a dict with ``id`` plus all attributes.
        """
        if node_id:
            # An exact-ID query can match at most one node, so look it up
            # directly instead of scanning the whole graph.
            if node_id in self.graph:
                candidates = [(node_id, self.graph.nodes[node_id])]
            else:
                candidates = []
        else:
            candidates = self.graph.nodes(data=True)

        results = []
        for name, attrs in candidates:
            if node_type and attrs.get('type') != node_type:
                continue
            if filters and any(attrs.get(k) != v for k, v in filters.items()):
                continue
            results.append({'id': name, **attrs})
        return results

    def neighbor_search(self, center_node, hops=2):
        """Collect the nodes and edges within ``hops`` of ``center_node``.

        Performs a breadth-first traversal from ``center_node``; nodes that
        are ``hops`` away are recorded but not expanded further.

        Args:
            center_node (str): ID of the center node.
            hops (int): Maximum hop distance (default 2).

        Returns:
            tuple: ``(entities, relations)``. ``entities`` lists every reached
            node except the center as ``{'id': ..., **attributes}``.
            ``relations`` lists ``{'src_name', 'dest_name', **edge_data}`` for
            each (expanded node, neighbor) pair, so an edge whose endpoints
            are both expanded appears once per direction.
        """
        visited = {center_node: 0}
        queue = deque([center_node])
        relations = []

        while queue:
            current = queue.popleft()
            current_hop = visited[current]
            if current_hop >= hops:
                continue

            try:
                for neighbor in self.graph.neighbors(current):
                    if neighbor not in visited:
                        visited[neighbor] = current_hop + 1
                        queue.append(neighbor)

                    # NOTE(review): the source listing had lost its
                    # indentation; relations are assumed to be recorded for
                    # every neighbor, not only newly visited ones -- confirm.
                    edge_data = self.graph.get_edge_data(current, neighbor)
                    relations.append({
                        'src_name': current,
                        'dest_name': neighbor,
                        **edge_data
                    })
            except Exception as e:
                # Best effort: report the failing node (e.g. a center node
                # that is not in the graph) and keep going.
                print(f"Error processing node {current}: {str(e)}")
                continue

        # Neighbor entities, excluding the center node itself.
        entities = [
            {'id': n, **self.graph.nodes[n]}
            for n in visited if n != center_node
        ]

        return entities, relations

    def detect_communities(self):
        """Run Leiden community detection and label every node.

        Converts the NetworkX graph to igraph, runs
        ``leidenalg.find_partition`` with ``CPMVertexPartition``,
        ``resolution_parameter=RESOLUTION`` and ``n_iterations=2``, and stores
        each node's community index in its ``community`` attribute.

        Returns:
            tuple: ``(graph with the community attribute set, partition)``.
        """
        ig_graph = ig.Graph.from_networkx(self.graph)

        partition = leidenalg.find_partition(
            ig_graph,
            leidenalg.CPMVertexPartition,
            resolution_parameter=RESOLUTION,
            n_iterations=2
        )

        # Assumes from_networkx keeps vertices in the graph's node iteration
        # order, so membership[i] belongs to the i-th node -- TODO confirm.
        for node, community in zip(self.graph.nodes(), partition.membership):
            self.graph.nodes[node]['community'] = community

        return self.graph, partition

    def find_paths(self, source, target, max_paths=5):
        """Find the weighted shortest path and the top simple paths.

        Args:
            source (str): Start node.
            target (str): Target node.
            max_paths (int): Maximum number of simple paths to return.

        Returns:
            dict: ``{'shortest_path': [...], 'all_paths': [[...], ...]}``.
            Both lists are empty when no path connects the two nodes.
        """
        result = {'shortest_path': [], 'all_paths': []}

        try:
            result['shortest_path'] = nx.shortest_path(
                self.graph, source, target, weight='weight'
            )

            # shortest_simple_paths yields paths lazily from shortest to
            # longest; take only the first max_paths instead of enumerating
            # every simple path in the graph.
            paths = nx.shortest_simple_paths(self.graph, source, target, weight='weight')
            result['all_paths'] = list(islice(paths, max_paths))
        except nx.NetworkXNoPath:
            # Disconnected nodes are reported as empty results, not an error.
            pass

        return result
|