graph_helper.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. import networkx as nx
  2. import json
  3. from tabulate import tabulate
  4. import leidenalg
  5. import igraph as ig
  6. import sys,os
  7. from config.site import SiteConfig
  8. from utils.es import ElasticsearchOperations
  9. from typing import List
  10. config = SiteConfig()
  11. current_path = os.getcwd()
  12. sys.path.append(current_path)
  13. RESOLUTION = 0.07
  14. #图谱数据的缓存路径,数据从dump_graph_data.py生成
  15. CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
  16. def load_entity_data():
  17. print("load entity data")
  18. with open(f"{CACHED_DATA_PATH}\\entities_med.json", "r", encoding="utf-8") as f:
  19. entities = json.load(f)
  20. return entities
  21. def load_relation_data(g):
  22. for i in range(30):
  23. if os.path.exists(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json"):
  24. print("load entity data", f"{CACHED_DATA_PATH}\\relationship_med_{i}.json")
  25. with open(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json", "r", encoding="utf-8") as f:
  26. relations = json.load(f)
  27. for item in relations:
  28. g.add_edge(item[0], item[1], weight=1, **item[2])
  29. class GraphHelper:
  30. def __init__(self):
  31. self.graph = None
  32. self.build_graph()
  33. self.es = ElasticsearchOperations()
  34. def build_graph(self):
  35. """构建企业关系图谱"""
  36. self.graph = nx.DiGraph()
  37. # 加载节点数据
  38. entities = load_entity_data()
  39. for item in entities:
  40. node_id = item[0]
  41. attrs = item[1]
  42. self.graph.add_node(node_id, **attrs)
  43. # 加载边数据
  44. load_relation_data(self.graph)
  45. def community_report_search(self, query):
  46. """社区报告检索功能"""
  47. es_result = self.es.search_title_index("graph_community_report_index", query, 10)
  48. results = []
  49. for item in es_result:
  50. results.append({
  51. 'id': item["title"],
  52. 'score': item["score"],
  53. 'text': item["text"]})
  54. return results
  55. def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
  56. """节点检索功能"""
  57. es_result = self.es.search_title_index("graph_entity_index", node_id, limit)
  58. results = []
  59. for item in es_result:
  60. n = self.graph.nodes.get(item["title"])
  61. score = item["score"]
  62. if n:
  63. results.append({
  64. 'id': item["title"],
  65. 'score': score,
  66. **n
  67. })
  68. # for n in self.graph.nodes(data=True):
  69. # match = True
  70. # if node_id and n[0] != node_id:
  71. # continue
  72. # if node_type and n[1].get('type') != node_type:
  73. # continue
  74. # if filters:
  75. # for k, v in filters.items():
  76. # if n[1].get(k) != v:
  77. # match = False
  78. # break
  79. # if match:
  80. # results.append({
  81. # 'id': n[0],
  82. # **n[1]
  83. # })
  84. return results
  85. def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
  86. """边检索功能"""
  87. results = []
  88. for u, v, data in self.graph.edges(data=True):
  89. if edge_type and data.get('type') != edge_type:
  90. continue
  91. if data.get('weight', 0) < min_weight:
  92. continue
  93. if (source and u != source and v != source) or \
  94. (target and u != target and v != target):
  95. continue
  96. results.append({
  97. 'source': u,
  98. 'target': v,
  99. **data
  100. })
  101. return results
  102. def neighbor_search(self, node_id, hops=1):
  103. """近邻检索功能"""
  104. if node_id not in self.graph:
  105. return [],[]
  106. # 使用ego_graph获取指定跳数的子图
  107. subgraph = nx.ego_graph(self.graph, node_id, radius=hops)
  108. entities = []
  109. for n in subgraph.nodes(data=True):
  110. if n[0] == node_id: # 跳过中心节点
  111. continue
  112. entities.append({
  113. 'id': n[0],
  114. **n[1]
  115. })
  116. relations = []
  117. for edge in subgraph.edges(data=True):
  118. relations.append({
  119. 'src_name': edge[0],
  120. 'dest_name': edge[1],
  121. **edge[2]
  122. })
  123. return entities, relations
  124. def find_paths(self, source, target, max_depth=3):
  125. """路径查找功能"""
  126. try:
  127. shortest = nx.shortest_path(self.graph, source=source, target=target)
  128. all_paths = nx.all_simple_paths(self.graph, source=source, target=target, cutoff=max_depth)
  129. return {
  130. 'shortest_path': shortest,
  131. 'all_paths': list(all_paths)
  132. }
  133. except nx.NetworkXNoPath:
  134. return {'error': 'No path found'}
  135. except nx.NodeNotFound as e:
  136. return {'error': f'Node not found: {e}'}
  137. def format_output(self, data, fmt='text'):
  138. """格式化输出结果"""
  139. if fmt == 'json':
  140. return json.dumps(data, indent=2, ensure_ascii=False)
  141. # 文本表格格式
  142. if isinstance(data, list):
  143. rows = []
  144. headers = []
  145. if not data:
  146. return "No results found"
  147. # 节点结果
  148. if 'id' in data[0]:
  149. headers = ["ID", "Type", "Description"]
  150. rows = [[d['id'], d.get('type',''), d.get('description','')] for d in data]
  151. # 边结果
  152. elif 'source' in data[0]:
  153. headers = ["Source", "Target", "Type", "Weight"]
  154. rows = [[d['source'], d['target'], d.get('type',''), d.get('weight',0)] for d in data]
  155. return tabulate(rows, headers=headers, tablefmt="grid")
  156. # 路径结果
  157. if isinstance(data, dict):
  158. if 'shortest_path' in data:
  159. output = [
  160. "Shortest Path: " + " → ".join(data['shortest_path']),
  161. "\nAll Paths:"
  162. ]
  163. for path in data['all_paths']:
  164. output.append(" → ".join(path))
  165. return "\n".join(output)
  166. elif 'error' in data:
  167. return data['error']
  168. return str(data)
  169. def detect_communities(self):
  170. """使用Leiden算法进行社区检测"""
  171. # 转换networkx图到igraph格式
  172. print("convert to igraph")
  173. ig_graph = ig.Graph.from_networkx(self.graph)
  174. # 执行Leiden算法
  175. partition = leidenalg.find_partition(
  176. ig_graph,
  177. leidenalg.CPMVertexPartition,
  178. resolution_parameter=RESOLUTION,
  179. n_iterations=2
  180. )
  181. # 将社区标签添加到原始图
  182. for i, node in enumerate(self.graph.nodes()):
  183. self.graph.nodes[node]['community'] = partition.membership[i]
  184. print("convert to igraph finished")
  185. return self.graph, partition
  186. def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
  187. """根据条件过滤节点"""
  188. filtered = []
  189. for node, data in self.graph.nodes(data=True):
  190. if node_type and data.get('type') != node_type:
  191. continue
  192. if min_degree > 0 and self.graph.degree(node) < min_degree:
  193. continue
  194. if attributes:
  195. if not all(data.get(k) == v for k, v in attributes.items()):
  196. continue
  197. filtered.append({'id': node, **data})
  198. return filtered
  199. def get_graph_statistics(self):
  200. """获取图谱统计信息"""
  201. return {
  202. 'node_count': self.graph.number_of_nodes(),
  203. 'edge_count': self.graph.number_of_edges(),
  204. 'density': nx.density(self.graph),
  205. 'components': nx.number_connected_components(self.graph),
  206. 'average_degree': sum(dict(self.graph.degree()).values()) / self.graph.number_of_nodes()
  207. }
  208. def get_community_details(self, community_id=None):
  209. """获取社区详细信息"""
  210. communities = {}
  211. for node, data in self.graph.nodes(data=True):
  212. comm = data.get('community', -1)
  213. if community_id is not None and comm != community_id:
  214. continue
  215. if comm not in communities:
  216. communities[comm] = {
  217. 'node_count': 0,
  218. 'nodes': [],
  219. 'central_nodes': []
  220. }
  221. communities[comm]['node_count'] += 1
  222. communities[comm]['nodes'].append(node)
  223. # 计算每个社区的中心节点
  224. for comm in communities:
  225. subgraph = self.graph.subgraph(communities[comm]['nodes'])
  226. centrality = nx.degree_centrality(subgraph)
  227. top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
  228. communities[comm]['central_nodes'] = [n[0] for n in top_nodes]
  229. return communities
  230. def find_relations(self, node_ids, relation_types=None):
  231. """查找指定节点集之间的关系"""
  232. relations = []
  233. for u, v, data in self.graph.edges(data=True):
  234. if (u in node_ids or v in node_ids) and \
  235. (not relation_types or data.get('type') in relation_types):
  236. relations.append({
  237. 'source': u,
  238. 'target': v,
  239. **data
  240. })
  241. return relations
  242. def semantic_search(self, query, top_k=5):
  243. """语义搜索(需要与文本嵌入结合)"""
  244. # 这里需要调用文本处理模块的embedding功能
  245. # 示例实现:简单名称匹配
  246. results = []
  247. query_lower = query.lower()
  248. for node, data in self.graph.nodes(data=True):
  249. if query_lower in data.get('name', '').lower():
  250. results.append({
  251. 'id': node,
  252. 'score': 1.0,
  253. **data
  254. })
  255. return sorted(results, key=lambda x: -x['score'])[:top_k]
  256. def get_node_details(self, node_id):
  257. """获取节点详细信息及其关联"""
  258. if node_id not in self.graph:
  259. return None
  260. details = dict(self.graph.nodes[node_id])
  261. details['degree'] = self.graph.degree(node_id)
  262. details['neighbors'] = list(self.graph.neighbors(node_id))
  263. details['edges'] = []
  264. for u, v, data in self.graph.edges(node_id, data=True):
  265. details['edges'].append({
  266. 'target': v if u == node_id else u,
  267. **data
  268. })
  269. return details