123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 |
import json
import os
import sys

import igraph as ig
import leidenalg
import networkx as nx
from tabulate import tabulate

# Make the current working directory importable *before* the project-local
# imports below. Appending it after them cannot help them resolve.
current_path = os.getcwd()
if current_path not in sys.path:
    sys.path.append(current_path)

from db.session import get_db  # noqa: E402
from service.kg_node_service import KGNodeService  # noqa: E402

# Resolution parameter for the Leiden CPM partition in detect_communities.
RESOLUTION = 0.07
# Cache directory for the graph data (generated by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
def load_entity_data():
    """Load the cached entity records from ``entities_med.json``.

    The file is looked up in ``CACHED_DATA_PATH``.

    Returns:
        The parsed JSON content of the file, or an empty list when the file
        does not exist.
    """
    print("load entity data")
    entity_file = os.path.join(CACHED_DATA_PATH, 'entities_med.json')
    if os.path.exists(entity_file):
        with open(entity_file, "r", encoding="utf-8") as handle:
            return json.load(handle)
    # No cache dumped yet: behave as an empty entity set.
    return []
def load_relation_data(g, data_path=None, max_files=89):
    """Add every cached relation edge to the graph ``g``.

    Probes ``relationship_med_{i}.json`` for ``i`` in ``range(max_files)``
    inside ``data_path`` and silently skips shard files that do not exist.
    Each file is read as a JSON list of ``[source, target, attrs]`` items.
    Items with a null source, target or attrs are skipped.

    Args:
        g: graph-like object exposing ``add_edge(u, v, **attrs)``
            (e.g. a ``networkx.Graph``). It is mutated in place.
        data_path: directory holding the relation dumps. ``None`` (the
            default) uses ``CACHED_DATA_PATH``.
        max_files: number of shard indices to probe, starting at 0.

    Each edge gets ``weight=1`` unless its attrs already provide a ``weight``.
    """
    if data_path is None:
        data_path = CACHED_DATA_PATH
    for i in range(max_files):
        path = os.path.join(data_path, f"relationship_med_{i}.json")
        if not os.path.exists(path):
            continue
        print("load relation data", path)
        with open(path, "r", encoding="utf-8") as f:
            relations = json.load(f)
        for item in relations:
            source, target, attrs = item[0], item[1], item[2]
            if source is None or target is None or attrs is None:
                continue
            # Merge rather than passing ``weight=1, **attrs``: that form raised
            # TypeError whenever attrs itself carried a "weight" key.
            g.add_edge(source, target, **{"weight": 1, **attrs})
-
class GraphHelper:
    """In-memory knowledge graph built from cached JSON dumps, plus query helpers.

    At construction time the graph is an undirected ``networkx.Graph``. It is
    populated from ``load_entity_data`` (nodes) and ``load_relation_data``
    (edges).
    """

    def __init__(self):
        self.graph = None
        self.build_graph()

    def build_graph(self):
        """Build the relationship graph from the cached entity and relation files."""
        self.graph = nx.Graph()

        # Nodes: each entity record is indexed as [node_id, attrs].
        for item in load_entity_data():
            node_id = item[0]
            attrs = item[1]
            self.graph.add_node(node_id, **attrs)

        # Edges.
        load_relation_data(self.graph)

    def _kg_search(self, index_name, query, limit):
        """Run a KGNodeService title-index search on a short-lived DB session.

        The result is materialized as a list before the session is released.
        """
        # NOTE(review): get_db() is assumed to be a generator-style dependency
        # that releases its session when closed. Confirm against db.session.
        db_gen = get_db()
        try:
            service = KGNodeService(next(db_gen))
            return list(service.search_title_index(index_name, query, limit))
        finally:
            close = getattr(db_gen, "close", None)
            if close is not None:
                close()

    def community_report_search(self, query):
        """Search community reports by title and return the top 10 hits.

        If a search client has been attached as ``self.es``, it is used. It must
        expose ``search_title_index(index, query, limit)``. This class never sets
        ``self.es`` itself. Without it, the search falls back to KGNodeService,
        the backend ``node_search`` uses.

        Returns:
            A list of ``{'id': title, 'score': score, 'text': text}`` dicts.
        """
        searcher = getattr(self, "es", None)
        if searcher is not None:
            es_result = searcher.search_title_index("graph_community_report_index", query, 10)
        else:
            # NOTE(review): assumes KGNodeService can query the community report
            # index as well as the entity index. Confirm.
            es_result = self._kg_search("graph_community_report_index", query, 10)
        return [
            {'id': item["title"], 'score': item["score"], 'text': item["text"]}
            for item in es_result
        ]

    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Look up graph nodes whose title matches ``node_id`` in the entity index.

        Args:
            node_id: query text sent to the ``graph_entity_index`` title search.
            node_type, filters, min_degree: accepted for API compatibility but
                currently ignored.
            limit: maximum number of hits requested from the index.

        Returns:
            A list of ``{'id', 'score', **node_attrs}`` dicts. Only hits whose
            title is a node of the graph are included.
        """
        results = []
        for item in self._kg_search("graph_entity_index", node_id, limit):
            attrs = self.graph.nodes.get(item["title"])
            # ``is not None`` so attribute-less nodes (e.g. ones created only by
            # an edge) are still returned; a truthiness test dropped them.
            if attrs is not None:
                results.append({
                    'id': item["title"],
                    'score': item["score"],
                    **attrs
                })
        return results

    def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
        """Return edges matching optional type, weight and endpoint filters.

        ``source`` / ``target`` match either endpoint, since the graph is
        undirected. Edges without a ``weight`` attribute are treated as weight 0.
        """
        results = []
        for u, v, data in self.graph.edges(data=True):
            if edge_type and data.get('type') != edge_type:
                continue
            if data.get('weight', 0) < min_weight:
                continue
            if (source and u != source and v != source) or \
                    (target and u != target and v != target):
                continue
            results.append({
                'source': u,
                'target': v,
                **data
            })
        return results

    def neighbor_search(self, node_id, hops=1):
        """Return the neighbourhood of ``node_id`` within ``hops`` steps.

        Returns:
            ``(entities, relations)``. ``entities`` lists the nodes of the ego
            subgraph except ``node_id`` itself, as ``{'id', **attrs}`` dicts.
            ``relations`` lists every edge of that subgraph, as
            ``{'src_name', 'dest_name', **attrs}`` dicts. Both are empty when
            ``node_id`` is not in the graph.
        """
        if node_id not in self.graph:
            return [], []

        # ego_graph yields the subgraph induced by nodes within `hops` of node_id.
        subgraph = nx.ego_graph(self.graph, node_id, radius=hops)

        entities = []
        for n in subgraph.nodes(data=True):
            if n[0] == node_id:  # skip the centre node
                continue
            entities.append({
                'id': n[0],
                **n[1]
            })
        relations = []
        for edge in subgraph.edges(data=True):
            relations.append({
                'src_name': edge[0],
                'dest_name': edge[1],
                **edge[2]
            })
        return entities, relations

    def find_paths(self, source, target, max_depth=3):
        """Find the shortest path and all simple paths between two nodes.

        Returns:
            ``{'shortest_path': [...], 'all_paths': [[...], ...]}``.
            ``all_paths`` is limited by ``max_depth``, which is passed as
            networkx's ``cutoff``. Returns ``{'error': ...}`` when no path
            exists or a node is missing.
        """
        try:
            shortest = nx.shortest_path(self.graph, source=source, target=target)
            all_paths = nx.all_simple_paths(self.graph, source=source, target=target, cutoff=max_depth)
            return {
                'shortest_path': shortest,
                'all_paths': list(all_paths)
            }
        except nx.NetworkXNoPath:
            return {'error': 'No path found'}
        except nx.NodeNotFound as e:
            return {'error': f'Node not found: {e}'}

    def format_output(self, data, fmt='text'):
        """Render query results for display.

        Args:
            data: a list of result dicts, a ``find_paths`` result dict, or any
                other value (rendered with ``str``).
            fmt: ``'json'`` for pretty-printed JSON; anything else for text.

        Returns:
            The formatted string. A non-empty list becomes a grid table. The
            columns are chosen by the first element's keys: ``id`` (nodes),
            ``source`` (edges) or ``src_name`` (``neighbor_search`` relations).
        """
        if fmt == 'json':
            return json.dumps(data, indent=2, ensure_ascii=False)

        # Text table format.
        if isinstance(data, list):
            if not data:
                return "No results found"
            first = data[0]
            headers, rows = [], []
            if 'id' in first:  # node results
                headers = ["ID", "Type", "Description"]
                rows = [[d['id'], d.get('type', ''), d.get('description', '')] for d in data]
            elif 'source' in first:  # edge results
                headers = ["Source", "Target", "Type", "Weight"]
                rows = [[d['source'], d['target'], d.get('type', ''), d.get('weight', 0)] for d in data]
            elif 'src_name' in first:  # neighbor_search relations
                headers = ["Source", "Target", "Type", "Weight"]
                rows = [[d['src_name'], d['dest_name'], d.get('type', ''), d.get('weight', 0)] for d in data]
            return tabulate(rows, headers=headers, tablefmt="grid")

        # Path results.
        if isinstance(data, dict):
            if 'shortest_path' in data:
                # str() each node so non-string node ids (e.g. ints) can be joined.
                output = [
                    "Shortest Path: " + " → ".join(map(str, data['shortest_path'])),
                    "\nAll Paths:"
                ]
                output.extend(" → ".join(map(str, path)) for path in data['all_paths'])
                return "\n".join(output)
            if 'error' in data:
                return data['error']

        return str(data)

    def detect_communities(self):
        """Partition the graph into communities with the Leiden algorithm.

        The algorithm runs ``leidenalg.CPMVertexPartition`` with
        ``RESOLUTION`` for 2 iterations. Each node's community index is stored
        in its ``'community'`` attribute.

        Returns:
            ``(self.graph, partition)``.
        """
        # Convert the networkx graph to igraph format.
        print("convert to igraph")
        ig_graph = ig.Graph.from_networkx(self.graph)

        partition = leidenalg.find_partition(
            ig_graph,
            leidenalg.CPMVertexPartition,
            resolution_parameter=RESOLUTION,
            n_iterations=2
        )

        # Write community labels back onto the original graph.
        # NOTE(review): relies on from_networkx creating igraph vertices in
        # self.graph.nodes() order. Confirm for the installed igraph version.
        for i, node in enumerate(self.graph.nodes()):
            self.graph.nodes[node]['community'] = partition.membership[i]

        print("convert to igraph finished")
        return self.graph, partition

    def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
        """Return nodes matching a type, a minimum degree and exact attribute values.

        Returns:
            A list of ``{'id', **attrs}`` dicts.
        """
        filtered = []
        for node, data in self.graph.nodes(data=True):
            if node_type and data.get('type') != node_type:
                continue
            if min_degree > 0 and self.graph.degree(node) < min_degree:
                continue
            if attributes:
                if not all(data.get(k) == v for k, v in attributes.items()):
                    continue
            filtered.append({'id': node, **data})
        return filtered

    def get_graph_statistics(self):
        """Return basic graph statistics.

        The statistics are node/edge counts, density, number of connected
        components and average degree. The average degree is 0.0 for an empty
        graph instead of raising ZeroDivisionError.
        """
        node_count = self.graph.number_of_nodes()
        degree_total = sum(dict(self.graph.degree()).values())
        return {
            'node_count': node_count,
            'edge_count': self.graph.number_of_edges(),
            'density': nx.density(self.graph),
            'components': nx.number_connected_components(self.graph),
            'average_degree': degree_total / node_count if node_count else 0.0
        }

    def get_community_details(self, community_id=None):
        """Summarise communities, or a single one if ``community_id`` is given.

        Nodes without a ``'community'`` attribute (e.g. before
        ``detect_communities`` runs) are grouped under -1.

        Returns:
            A dict keyed by community id. Each value is
            ``{'node_count', 'nodes', 'central_nodes'}``, where
            ``central_nodes`` are the top 3 nodes by degree centrality within
            the community's subgraph.
        """
        communities = {}
        for node, data in self.graph.nodes(data=True):
            comm = data.get('community', -1)
            if community_id is not None and comm != community_id:
                continue
            if comm not in communities:
                communities[comm] = {
                    'node_count': 0,
                    'nodes': [],
                    'central_nodes': []
                }
            communities[comm]['node_count'] += 1
            communities[comm]['nodes'].append(node)

        # Pick each community's most central nodes.
        for comm in communities:
            subgraph = self.graph.subgraph(communities[comm]['nodes'])
            centrality = nx.degree_centrality(subgraph)
            top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
            communities[comm]['central_nodes'] = [n[0] for n in top_nodes]

        return communities

    def find_relations(self, node_ids, relation_types=None):
        """Return edges touching any node in ``node_ids``.

        An edge qualifies if either endpoint is in ``node_ids``, so edges to
        nodes outside the set are included. The result is optionally
        restricted to edges whose ``'type'`` is in ``relation_types``.

        Returns:
            A list of ``{'source', 'target', **attrs}`` dicts.
        """
        # Build the lookup set once instead of scanning node_ids for every edge.
        wanted = set(node_ids)
        relations = []
        for u, v, data in self.graph.edges(data=True):
            if u not in wanted and v not in wanted:
                continue
            if relation_types and data.get('type') not in relation_types:
                continue
            relations.append({
                'source': u,
                'target': v,
                **data
            })
        return relations

    def semantic_search(self, query, top_k=5):
        """Placeholder "semantic" search: case-insensitive substring match on ``name``.

        A real implementation would use text embeddings. Every hit currently
        scores 1.0, and at most ``top_k`` hits are returned.
        """
        results = []
        query_lower = query.lower()
        for node, data in self.graph.nodes(data=True):
            name = data.get('name')
            # Treat a missing/None name as empty and stringify non-str names,
            # which previously raised AttributeError.
            name_text = str(name).lower() if name is not None else ''
            if query_lower in name_text:
                results.append({
                    'id': node,
                    'score': 1.0,
                    **data
                })
        return sorted(results, key=lambda x: -x['score'])[:top_k]

    def get_node_details(self, node_id):
        """Return a node's attributes plus its degree, neighbours and incident edges.

        Returns ``None`` when ``node_id`` is not in the graph. Each entry of
        ``'edges'`` is ``{'target': other_endpoint, **edge_attrs}``.
        """
        if node_id not in self.graph:
            return None
        details = dict(self.graph.nodes[node_id])
        details['degree'] = self.graph.degree(node_id)
        details['neighbors'] = list(self.graph.neighbors(node_id))
        details['edges'] = []
        for u, v, data in self.graph.edges(node_id, data=True):
            details['edges'].append({
                'target': v if u == node_id else u,
                **data
            })
        return details
|