import json
import os
import sys

import igraph as ig
import leidenalg
import networkx as nx
from tabulate import tabulate

# Make project-local packages importable BEFORE importing them.
# (Originally the path append ran after the db/service imports, which only
# worked when the script was already launched from the project root.)
current_path = os.getcwd()
sys.path.append(current_path)

from db.session import get_db
from service.kg_node_service import KGNodeService

# Resolution parameter for Leiden community detection
# (smaller values produce fewer, larger communities).
RESOLUTION = 0.07

# Graph data cache path (generated by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')


def load_entity_data():
    """Load cached entity records.

    Returns:
        list: records of the form ``[node_id, attr_dict, ...]`` as stored in
        ``entities_med.json``.
    """
    print("load entity data")
    with open(os.path.join(CACHED_DATA_PATH, 'entities_med.json'), "r", encoding="utf-8") as f:
        entities = json.load(f)
    return entities


def load_relation_data(g, max_files=30):
    """Load cached relation shards ``relationship_med_<i>.json`` into *g*.

    Args:
        g: networkx graph to receive the edges.
        max_files: upper bound on shard index to probe (default 30, matching
            the original hard-coded limit).

    Each record is ``[source, target, attrs]``; records with a missing
    endpoint or missing attrs are skipped.  Any cached ``weight`` attribute
    is discarded and every edge weight is normalized to 1.
    """
    for i in range(max_files):
        shard_path = os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")
        if not os.path.exists(shard_path):
            continue
        print("load relation data", shard_path)
        with open(shard_path, "r", encoding="utf-8") as f:
            relations = json.load(f)
        for item in relations:
            if item[0] is None or item[1] is None or item[2] is None:
                continue
            # Drop the cached 'weight' attribute so the uniform weight=1 below
            # is not overridden by the **item[2] expansion.
            item[2].pop('weight', None)
            g.add_edge(item[0], item[1], weight=1, **item[2])


class GraphHelper:
    """Query helper around a networkx graph built from cached graph data."""

    def __init__(self):
        self.graph = None
        self.build_graph()

    def build_graph(self):
        """Build the relationship graph from cached entities and relations."""
        self.graph = nx.Graph()
        # Load node data: each record is [node_id, attr_dict].
        for item in load_entity_data():
            self.graph.add_node(item[0], **item[1])
        # Load edge data.
        load_relation_data(self.graph)

    def community_report_search(self, query):
        """Search community reports via the ES title index.

        NOTE(review): ``self.es`` is never initialized anywhere in this file,
        so this method raises AttributeError as written — confirm where the
        ES client is supposed to be injected.
        """
        es_result = self.es.search_title_index("graph_community_report_index", query, 10)
        results = []
        for item in es_result:
            results.append({
                'id': item["title"],
                'score': item["score"],
                'text': item["text"],
            })
        return results

    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search nodes via the ES title index and enrich hits with graph attrs.

        Args:
            node_id: query string passed to the title index.
            node_type, filters, min_degree: currently unused — kept for
                interface compatibility with callers.
            limit: maximum number of ES hits to request.

        Returns:
            list[dict]: one dict per hit that also exists in the graph,
            containing 'id', 'score' and the node's graph attributes.
        """
        kg_node_service = KGNodeService(next(get_db()))
        es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
        results = []
        for item in es_result:
            node_attrs = self.graph.nodes.get(item["title"])
            if node_attrs:
                results.append({
                    'id': item["title"],
                    'score': item["score"],
                    **node_attrs,
                })
        return results

    def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
        """Search edges, optionally filtered by endpoint, type and weight."""
        results = []
        for u, v, data in self.graph.edges(data=True):
            if edge_type and data.get('type') != edge_type:
                continue
            if data.get('weight', 0) < min_weight:
                continue
            # An endpoint filter matches if the node appears on either end.
            if (source and u != source and v != source) or \
               (target and u != target and v != target):
                continue
            results.append({
                'source': u,
                'target': v,
                **data,
            })
        return results

    def neighbor_search(self, node_id, hops=1):
        """Return (entities, relations) within *hops* of *node_id*.

        The center node itself is excluded from the entity list.  Returns
        ``([], [])`` when the node is not in the graph.
        """
        if node_id not in self.graph:
            return [], []
        # ego_graph yields the subgraph within the given radius.
        subgraph = nx.ego_graph(self.graph, node_id, radius=hops)
        entities = []
        for n in subgraph.nodes(data=True):
            if n[0] == node_id:  # skip the center node
                continue
            entities.append({'id': n[0], **n[1]})
        relations = []
        for edge in subgraph.edges(data=True):
            relations.append({
                'src_name': edge[0],
                'dest_name': edge[1],
                **edge[2],
            })
        return entities, relations

    def find_paths(self, source, target, max_depth=3):
        """Find the shortest path and all simple paths up to *max_depth*.

        Returns a dict with 'shortest_path' and 'all_paths', or an 'error'
        dict when no path exists or a node is unknown.
        """
        try:
            shortest = nx.shortest_path(self.graph, source=source, target=target)
            all_paths = nx.all_simple_paths(self.graph, source=source, target=target, cutoff=max_depth)
            return {
                'shortest_path': shortest,
                'all_paths': list(all_paths),
            }
        except nx.NetworkXNoPath:
            return {'error': 'No path found'}
        except nx.NodeNotFound as e:
            return {'error': f'Node not found: {e}'}

    def format_output(self, data, fmt='text'):
        """Format node/edge/path results as JSON or a text table."""
        if fmt == 'json':
            return json.dumps(data, indent=2, ensure_ascii=False)
        # Text table format.
        if isinstance(data, list):
            rows = []
            headers = []
            if not data:
                return "No results found"
            # Node results.
            if 'id' in data[0]:
                headers = ["ID", "Type", "Description"]
                rows = [[d['id'], d.get('type', ''), d.get('description', '')] for d in data]
            # Edge results.
            elif 'source' in data[0]:
                headers = ["Source", "Target", "Type", "Weight"]
                rows = [[d['source'], d['target'], d.get('type', ''), d.get('weight', 0)] for d in data]
            return tabulate(rows, headers=headers, tablefmt="grid")
        # Path results.
        if isinstance(data, dict):
            if 'shortest_path' in data:
                output = [
                    "Shortest Path: " + " → ".join(data['shortest_path']),
                    "\nAll Paths:",
                ]
                for path in data['all_paths']:
                    output.append(" → ".join(path))
                return "\n".join(output)
            elif 'error' in data:
                return data['error']
        return str(data)

    def detect_communities(self):
        """Run Leiden community detection and tag nodes with a 'community' id.

        Returns:
            tuple: (the annotated networkx graph, the leidenalg partition).
        """
        # Convert the networkx graph to igraph format.
        print("convert to igraph")
        ig_graph = ig.Graph.from_networkx(self.graph)
        # Run the Leiden algorithm.
        partition = leidenalg.find_partition(
            ig_graph,
            leidenalg.CPMVertexPartition,
            resolution_parameter=RESOLUTION,
            n_iterations=2,
        )
        # Write community labels back onto the original graph; relies on
        # from_networkx preserving node iteration order.
        for i, node in enumerate(self.graph.nodes()):
            self.graph.nodes[node]['community'] = partition.membership[i]
        print("convert to igraph finished")
        return self.graph, partition

    def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
        """Filter nodes by type, minimum degree and exact attribute matches."""
        filtered = []
        for node, data in self.graph.nodes(data=True):
            if node_type and data.get('type') != node_type:
                continue
            if min_degree > 0 and self.graph.degree(node) < min_degree:
                continue
            if attributes:
                if not all(data.get(k) == v for k, v in attributes.items()):
                    continue
            filtered.append({'id': node, **data})
        return filtered

    def get_graph_statistics(self):
        """Return basic graph statistics (counts, density, components, avg degree)."""
        node_count = self.graph.number_of_nodes()
        return {
            'node_count': node_count,
            'edge_count': self.graph.number_of_edges(),
            'density': nx.density(self.graph),
            'components': nx.number_connected_components(self.graph),
            # Guard against division by zero on an empty graph.
            'average_degree': (
                sum(dict(self.graph.degree()).values()) / node_count
                if node_count else 0
            ),
        }

    def get_community_details(self, community_id=None):
        """Summarize communities: size, members and top-3 central nodes.

        Nodes without a 'community' attribute fall into community -1.
        """
        communities = {}
        for node, data in self.graph.nodes(data=True):
            comm = data.get('community', -1)
            if community_id is not None and comm != community_id:
                continue
            if comm not in communities:
                communities[comm] = {
                    'node_count': 0,
                    'nodes': [],
                    'central_nodes': [],
                }
            communities[comm]['node_count'] += 1
            communities[comm]['nodes'].append(node)
        # Compute the central nodes of each community via degree centrality.
        for comm in communities:
            subgraph = self.graph.subgraph(communities[comm]['nodes'])
            centrality = nx.degree_centrality(subgraph)
            top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
            communities[comm]['central_nodes'] = [n[0] for n in top_nodes]
        return communities

    def find_relations(self, node_ids, relation_types=None):
        """Find edges touching any node in *node_ids*, optionally by type."""
        relations = []
        for u, v, data in self.graph.edges(data=True):
            if (u in node_ids or v in node_ids) and \
               (not relation_types or data.get('type') in relation_types):
                relations.append({
                    'source': u,
                    'target': v,
                    **data,
                })
        return relations

    def semantic_search(self, query, top_k=5):
        """Semantic search (to be combined with text embeddings).

        Placeholder implementation: case-insensitive substring match against
        the node 'name' attribute; every hit gets score 1.0.
        """
        results = []
        query_lower = query.lower()
        for node, data in self.graph.nodes(data=True):
            if query_lower in data.get('name', '').lower():
                results.append({
                    'id': node,
                    'score': 1.0,
                    **data,
                })
        return sorted(results, key=lambda x: -x['score'])[:top_k]

    def get_node_details(self, node_id):
        """Return a node's attributes plus degree, neighbors and incident edges.

        Returns None when the node is not in the graph.
        """
        if node_id not in self.graph:
            return None
        details = dict(self.graph.nodes[node_id])
        details['degree'] = self.graph.degree(node_id)
        details['neighbors'] = list(self.graph.neighbors(node_id))
        details['edges'] = []
        for u, v, data in self.graph.edges(node_id, data=True):
            details['edges'].append({
                # Report the far endpoint regardless of edge orientation.
                'target': v if u == node_id else u,
                **data,
            })
        return details