# graph_helper.py — knowledge-graph construction and query helpers
  1. import networkx as nx
  2. import json
  3. from tabulate import tabulate
  4. import leidenalg
  5. import igraph as ig
  6. import sys,os
  7. from db.session import get_db
  8. from service.kg_node_service import KGNodeService
  9. current_path = os.getcwd()
  10. sys.path.append(current_path)
  11. RESOLUTION = 0.07
  12. # 图谱数据缓存路径(由dump_graph_data.py生成)
  13. CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
  14. def load_entity_data():
  15. print("load entity data")
  16. if not os.path.exists(os.path.join(CACHED_DATA_PATH,'entities_med.json')):
  17. return []
  18. with open(os.path.join(CACHED_DATA_PATH,'entities_med.json'), "r", encoding="utf-8") as f:
  19. entities = json.load(f)
  20. return entities
  21. def load_relation_data(g):
  22. for i in range(89):
  23. if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  24. continue
  25. if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  26. print("load entity data", os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"))
  27. with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
  28. relations = json.load(f)
  29. for item in relations:
  30. if item[0] is None or item[1] is None or item[2] is None:
  31. continue
  32. g.add_edge(item[0], item[1], weight=1, **item[2])
  33. class GraphHelper:
  34. def __init__(self):
  35. self.graph = None
  36. self.build_graph()
  37. def build_graph(self):
  38. """构建企业关系图谱"""
  39. self.graph = nx.Graph()
  40. # 加载节点数据
  41. entities = load_entity_data()
  42. for item in entities:
  43. node_id = item[0]
  44. attrs = item[1]
  45. self.graph.add_node(node_id, **attrs)
  46. # 加载边数据
  47. load_relation_data(self.graph)
  48. def community_report_search(self, query):
  49. """社区报告检索功能"""
  50. es_result = self.es.search_title_index("graph_community_report_index", query, 10)
  51. results = []
  52. for item in es_result:
  53. results.append({
  54. 'id': item["title"],
  55. 'score': item["score"],
  56. 'text': item["text"]})
  57. return results
  58. def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
  59. """节点检索功能"""
  60. kg_node_service = KGNodeService(next(get_db()))
  61. es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
  62. results = []
  63. for item in es_result:
  64. n = self.graph.nodes.get(item["title"])
  65. score = item["score"]
  66. if n:
  67. results.append({
  68. 'id': item["title"],
  69. 'score': score,
  70. **n
  71. })
  72. return results
  73. def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
  74. """边检索功能"""
  75. results = []
  76. for u, v, data in self.graph.edges(data=True):
  77. if edge_type and data.get('type') != edge_type:
  78. continue
  79. if data.get('weight', 0) < min_weight:
  80. continue
  81. if (source and u != source and v != source) or \
  82. (target and u != target and v != target):
  83. continue
  84. results.append({
  85. 'source': u,
  86. 'target': v,
  87. **data
  88. })
  89. return results
  90. def neighbor_search(self, node_id, hops=1):
  91. """近邻检索功能"""
  92. if node_id not in self.graph:
  93. return [],[]
  94. # 使用ego_graph获取指定跳数的子图
  95. subgraph = nx.ego_graph(self.graph, node_id, radius=hops)
  96. entities = []
  97. for n in subgraph.nodes(data=True):
  98. if n[0] == node_id: # 跳过中心节点
  99. continue
  100. entities.append({
  101. 'id': n[0],
  102. **n[1]
  103. })
  104. relations = []
  105. for edge in subgraph.edges(data=True):
  106. relations.append({
  107. 'src_name': edge[0],
  108. 'dest_name': edge[1],
  109. **edge[2]
  110. })
  111. return entities, relations
  112. def find_paths(self, source, target, max_depth=3):
  113. """路径查找功能"""
  114. try:
  115. shortest = nx.shortest_path(self.graph, source=source, target=target)
  116. all_paths = nx.all_simple_paths(self.graph, source=source, target=target, cutoff=max_depth)
  117. return {
  118. 'shortest_path': shortest,
  119. 'all_paths': list(all_paths)
  120. }
  121. except nx.NetworkXNoPath:
  122. return {'error': 'No path found'}
  123. except nx.NodeNotFound as e:
  124. return {'error': f'Node not found: {e}'}
  125. def format_output(self, data, fmt='text'):
  126. """格式化输出结果"""
  127. if fmt == 'json':
  128. return json.dumps(data, indent=2, ensure_ascii=False)
  129. # 文本表格格式
  130. if isinstance(data, list):
  131. rows = []
  132. headers = []
  133. if not data:
  134. return "No results found"
  135. # 节点结果
  136. if 'id' in data[0]:
  137. headers = ["ID", "Type", "Description"]
  138. rows = [[d['id'], d.get('type',''), d.get('description','')] for d in data]
  139. # 边结果
  140. elif 'source' in data[0]:
  141. headers = ["Source", "Target", "Type", "Weight"]
  142. rows = [[d['source'], d['target'], d.get('type',''), d.get('weight',0)] for d in data]
  143. return tabulate(rows, headers=headers, tablefmt="grid")
  144. # 路径结果
  145. if isinstance(data, dict):
  146. if 'shortest_path' in data:
  147. output = [
  148. "Shortest Path: " + " → ".join(data['shortest_path']),
  149. "\nAll Paths:"
  150. ]
  151. for path in data['all_paths']:
  152. output.append(" → ".join(path))
  153. return "\n".join(output)
  154. elif 'error' in data:
  155. return data['error']
  156. return str(data)
  157. def detect_communities(self):
  158. """使用Leiden算法进行社区检测"""
  159. # 转换networkx图到igraph格式
  160. print("convert to igraph")
  161. ig_graph = ig.Graph.from_networkx(self.graph)
  162. # 执行Leiden算法
  163. partition = leidenalg.find_partition(
  164. ig_graph,
  165. leidenalg.CPMVertexPartition,
  166. resolution_parameter=RESOLUTION,
  167. n_iterations=2
  168. )
  169. # 将社区标签添加到原始图
  170. for i, node in enumerate(self.graph.nodes()):
  171. self.graph.nodes[node]['community'] = partition.membership[i]
  172. print("convert to igraph finished")
  173. return self.graph, partition
  174. def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
  175. """根据条件过滤节点"""
  176. filtered = []
  177. for node, data in self.graph.nodes(data=True):
  178. if node_type and data.get('type') != node_type:
  179. continue
  180. if min_degree > 0 and self.graph.degree(node) < min_degree:
  181. continue
  182. if attributes:
  183. if not all(data.get(k) == v for k, v in attributes.items()):
  184. continue
  185. filtered.append({'id': node, **data})
  186. return filtered
  187. def get_graph_statistics(self):
  188. """获取图谱统计信息"""
  189. return {
  190. 'node_count': self.graph.number_of_nodes(),
  191. 'edge_count': self.graph.number_of_edges(),
  192. 'density': nx.density(self.graph),
  193. 'components': nx.number_connected_components(self.graph),
  194. 'average_degree': sum(dict(self.graph.degree()).values()) / self.graph.number_of_nodes()
  195. }
  196. def get_community_details(self, community_id=None):
  197. """获取社区详细信息"""
  198. communities = {}
  199. for node, data in self.graph.nodes(data=True):
  200. comm = data.get('community', -1)
  201. if community_id is not None and comm != community_id:
  202. continue
  203. if comm not in communities:
  204. communities[comm] = {
  205. 'node_count': 0,
  206. 'nodes': [],
  207. 'central_nodes': []
  208. }
  209. communities[comm]['node_count'] += 1
  210. communities[comm]['nodes'].append(node)
  211. # 计算每个社区的中心节点
  212. for comm in communities:
  213. subgraph = self.graph.subgraph(communities[comm]['nodes'])
  214. centrality = nx.degree_centrality(subgraph)
  215. top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
  216. communities[comm]['central_nodes'] = [n[0] for n in top_nodes]
  217. return communities
  218. def find_relations(self, node_ids, relation_types=None):
  219. """查找指定节点集之间的关系"""
  220. relations = []
  221. for u, v, data in self.graph.edges(data=True):
  222. if (u in node_ids or v in node_ids) and \
  223. (not relation_types or data.get('type') in relation_types):
  224. relations.append({
  225. 'source': u,
  226. 'target': v,
  227. **data
  228. })
  229. return relations
  230. def semantic_search(self, query, top_k=5):
  231. """语义搜索(需要与文本嵌入结合)"""
  232. # 这里需要调用文本处理模块的embedding功能
  233. # 示例实现:简单名称匹配
  234. results = []
  235. query_lower = query.lower()
  236. for node, data in self.graph.nodes(data=True):
  237. if query_lower in data.get('name', '').lower():
  238. results.append({
  239. 'id': node,
  240. 'score': 1.0,
  241. **data
  242. })
  243. return sorted(results, key=lambda x: -x['score'])[:top_k]
  244. def get_node_details(self, node_id):
  245. """获取节点详细信息及其关联"""
  246. if node_id not in self.graph:
  247. return None
  248. details = dict(self.graph.nodes[node_id])
  249. details['degree'] = self.graph.degree(node_id)
  250. details['neighbors'] = list(self.graph.neighbors(node_id))
  251. details['edges'] = []
  252. for u, v, data in self.graph.edges(node_id, data=True):
  253. details['edges'].append({
  254. 'target': v if u == node_id else u,
  255. **data
  256. })
  257. return details