import networkx as nx
import json
from tabulate import tabulate
import leidenalg
import igraph as ig
import sys, os
from config.site import SiteConfig
from utils.es import ElasticsearchOperations
from typing import List

config = SiteConfig()
current_path = os.getcwd()
sys.path.append(current_path)

# Resolution parameter for Leiden/CPM community detection; smaller values
# produce fewer, larger communities.
RESOLUTION = 0.07

# Cache directory for graph data; the files are produced by dump_graph_data.py.
CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")


def load_entity_data():
    """Load the cached entity list from entities_med.json.

    Returns:
        list: items of the form [node_id, attr_dict], as dumped by
        dump_graph_data.py.
    """
    print("load entity data")
    # os.path.join is portable (the original hard-coded "\\" broke on POSIX).
    path = os.path.join(CACHED_DATA_PATH, "entities_med.json")
    with open(path, "r", encoding="utf-8") as f:
        entities = json.load(f)
    return entities


def load_relation_data(g):
    """Load cached relation shards and add them to graph *g* as edges.

    Scans relationship_med_0.json .. relationship_med_29.json; each item is
    [source, target, attr_dict]. Every edge gets weight=1 plus the stored
    attributes.

    Args:
        g: a networkx graph (mutated in place).
    """
    for i in range(30):
        path = os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")
        if os.path.exists(path):
            # Fixed log message (original said "load entity data" here).
            print("load relation data", path)
            with open(path, "r", encoding="utf-8") as f:
                relations = json.load(f)
            for item in relations:
                g.add_edge(item[0], item[1], weight=1, **item[2])


class GraphHelper:
    """Helper around a networkx DiGraph of entities/relations with
    Elasticsearch-backed search and Leiden community detection."""

    def __init__(self):
        self.graph = None
        self.build_graph()
        self.es = ElasticsearchOperations()

    def build_graph(self):
        """Build the relation graph from the cached entity/relation dumps."""
        self.graph = nx.DiGraph()
        # Load nodes: each entity is [node_id, attr_dict].
        entities = load_entity_data()
        for item in entities:
            node_id = item[0]
            attrs = item[1]
            self.graph.add_node(node_id, **attrs)
        # Load edges.
        load_relation_data(self.graph)

    def community_report_search(self, query):
        """Search community reports via the ES title index.

        Args:
            query: free-text query string.

        Returns:
            list[dict]: {'id', 'score', 'text'} per hit (up to 10).
        """
        es_result = self.es.search_title_index("graph_community_report_index", query, 10)
        return [
            {"id": item["title"], "score": item["score"], "text": item["text"]}
            for item in es_result
        ]

    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search graph nodes via the ES entity index.

        Only ``node_id`` and ``limit`` are currently used; ``node_type``,
        ``filters`` and ``min_degree`` are kept for interface compatibility.

        Returns:
            list[dict]: {'id', 'score', **node_attrs} for hits that exist
            in the in-memory graph.
        """
        es_result = self.es.search_title_index("graph_entity_index", node_id, limit)
        results = []
        for item in es_result:
            node_attrs = self.graph.nodes.get(item["title"])
            if node_attrs:
                results.append({"id": item["title"], "score": item["score"], **node_attrs})
        return results

    def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
        """Search edges by endpoint, type and minimum weight.

        Note: ``source``/``target`` each match *either* endpoint of an edge
        (the filter is symmetric), preserving the original semantics.

        Returns:
            list[dict]: {'source', 'target', **edge_attrs}.
        """
        results = []
        for u, v, data in self.graph.edges(data=True):
            if edge_type and data.get("type") != edge_type:
                continue
            if data.get("weight", 0) < min_weight:
                continue
            if (source and u != source and v != source) or \
               (target and u != target and v != target):
                continue
            results.append({"source": u, "target": v, **data})
        return results

    def neighbor_search(self, node_id, hops=1):
        """Return the neighborhood of *node_id* within *hops* hops.

        Returns:
            (entities, relations): entities excludes the center node;
            relations are the edges of the ego subgraph. Both empty if the
            node is unknown.
        """
        if node_id not in self.graph:
            return [], []
        # ego_graph gives the subgraph induced by nodes within `hops` hops.
        subgraph = nx.ego_graph(self.graph, node_id, radius=hops)
        entities = [
            {"id": n, **attrs}
            for n, attrs in subgraph.nodes(data=True)
            if n != node_id  # skip the center node itself
        ]
        relations = [
            {"src_name": u, "dest_name": v, **attrs}
            for u, v, attrs in subgraph.edges(data=True)
        ]
        return entities, relations

    def find_paths(self, source, target, max_depth=3):
        """Find the shortest path and all simple paths (up to *max_depth*).

        Returns:
            dict with 'shortest_path' and 'all_paths', or {'error': ...}
            when no path exists or a node is unknown.
        """
        try:
            shortest = nx.shortest_path(self.graph, source=source, target=target)
            all_paths = nx.all_simple_paths(
                self.graph, source=source, target=target, cutoff=max_depth
            )
            return {"shortest_path": shortest, "all_paths": list(all_paths)}
        except nx.NetworkXNoPath:
            return {"error": "No path found"}
        except nx.NodeNotFound as e:
            return {"error": f"Node not found: {e}"}

    def format_output(self, data, fmt='text'):
        """Format search/path results as JSON or a text table.

        Args:
            data: list of node/edge dicts, or a path-result dict.
            fmt: 'json' for pretty JSON, anything else for text.
        """
        if fmt == 'json':
            return json.dumps(data, indent=2, ensure_ascii=False)
        # Text / table format.
        if isinstance(data, list):
            rows = []
            headers = []
            if not data:
                return "No results found"
            if 'id' in data[0]:  # node results
                headers = ["ID", "Type", "Description"]
                rows = [[d['id'], d.get('type', ''), d.get('description', '')]
                        for d in data]
            elif 'source' in data[0]:  # edge results
                headers = ["Source", "Target", "Type", "Weight"]
                rows = [[d['source'], d['target'], d.get('type', ''), d.get('weight', 0)]
                        for d in data]
            return tabulate(rows, headers=headers, tablefmt="grid")
        # Path results.
        if isinstance(data, dict):
            if 'shortest_path' in data:
                output = [
                    "Shortest Path: " + " → ".join(data['shortest_path']),
                    "\nAll Paths:",
                ]
                for path in data['all_paths']:
                    output.append(" → ".join(path))
                return "\n".join(output)
            elif 'error' in data:
                return data['error']
        return str(data)

    def detect_communities(self):
        """Run Leiden (CPM) community detection and tag nodes in place.

        Returns:
            (graph, partition): the graph with a 'community' attribute per
            node, and the leidenalg partition object.
        """
        # Convert the networkx graph to igraph; from_networkx preserves the
        # node iteration order, so partition.membership[i] lines up with
        # enumerate(self.graph.nodes()).
        print("convert to igraph")
        ig_graph = ig.Graph.from_networkx(self.graph)
        partition = leidenalg.find_partition(
            ig_graph,
            leidenalg.CPMVertexPartition,
            resolution_parameter=RESOLUTION,
            n_iterations=2,
        )
        # Write the community label back onto the original graph.
        for i, node in enumerate(self.graph.nodes()):
            self.graph.nodes[node]['community'] = partition.membership[i]
        print("convert to igraph finished")
        return self.graph, partition

    def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
        """Filter nodes by type, minimum degree and exact attribute values.

        Returns:
            list[dict]: {'id', **node_attrs} for every matching node.
        """
        filtered = []
        for node, data in self.graph.nodes(data=True):
            if node_type and data.get('type') != node_type:
                continue
            if min_degree > 0 and self.graph.degree(node) < min_degree:
                continue
            if attributes:
                if not all(data.get(k) == v for k, v in attributes.items()):
                    continue
            filtered.append({'id': node, **data})
        return filtered

    def get_graph_statistics(self):
        """Return basic statistics about the graph.

        Returns:
            dict with node/edge counts, density, number of (weakly)
            connected components, and average degree.
        """
        node_count = self.graph.number_of_nodes()
        # Bug fix: number_connected_components raises NetworkXNotImplemented
        # on directed graphs; use the weakly-connected variant for a DiGraph.
        components = nx.number_weakly_connected_components(self.graph)
        # Guard the empty graph to avoid ZeroDivisionError.
        total_degree = sum(dict(self.graph.degree()).values())
        average_degree = total_degree / node_count if node_count else 0
        return {
            'node_count': node_count,
            'edge_count': self.graph.number_of_edges(),
            'density': nx.density(self.graph),
            'components': components,
            'average_degree': average_degree,
        }

    def get_community_details(self, community_id=None):
        """Summarize detected communities (requires detect_communities first).

        Args:
            community_id: restrict output to one community when given.

        Returns:
            dict: community id -> {'node_count', 'nodes', 'central_nodes'},
            where central_nodes are the top-3 by degree centrality.
        """
        communities = {}
        for node, data in self.graph.nodes(data=True):
            comm = data.get('community', -1)  # -1 = never assigned
            if community_id is not None and comm != community_id:
                continue
            if comm not in communities:
                communities[comm] = {
                    'node_count': 0,
                    'nodes': [],
                    'central_nodes': [],
                }
            communities[comm]['node_count'] += 1
            communities[comm]['nodes'].append(node)
        # Compute the central nodes of each community.
        for comm in communities:
            subgraph = self.graph.subgraph(communities[comm]['nodes'])
            centrality = nx.degree_centrality(subgraph)
            top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
            communities[comm]['central_nodes'] = [n[0] for n in top_nodes]
        return communities

    def find_relations(self, node_ids, relation_types=None):
        """Find edges touching any node in *node_ids*, optionally filtered
        by relation type.

        Returns:
            list[dict]: {'source', 'target', **edge_attrs}.
        """
        relations = []
        for u, v, data in self.graph.edges(data=True):
            if (u in node_ids or v in node_ids) and \
               (not relation_types or data.get('type') in relation_types):
                relations.append({'source': u, 'target': v, **data})
        return relations

    def semantic_search(self, query, top_k=5):
        """Semantic search placeholder (to be backed by text embeddings).

        Current implementation: case-insensitive substring match against the
        node 'name' attribute, every hit scored 1.0.

        Returns:
            list[dict]: up to *top_k* results of {'id', 'score', **attrs}.
        """
        results = []
        query_lower = query.lower()
        for node, data in self.graph.nodes(data=True):
            if query_lower in data.get('name', '').lower():
                results.append({'id': node, 'score': 1.0, **data})
        return sorted(results, key=lambda x: -x['score'])[:top_k]

    def get_node_details(self, node_id):
        """Return a node's attributes plus degree, neighbors and edges.

        Returns:
            dict or None when the node is unknown.
        """
        if node_id not in self.graph:
            return None
        details = dict(self.graph.nodes[node_id])
        details['degree'] = self.graph.degree(node_id)
        details['neighbors'] = list(self.graph.neighbors(node_id))
        details['edges'] = []
        for u, v, data in self.graph.edges(node_id, data=True):
            # Report the far endpoint regardless of edge direction.
            details['edges'].append({'target': v if u == node_id else u, **data})
        return details