graph_helper.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import networkx as nx
  2. import json
  3. from tabulate import tabulate
  4. import leidenalg
  5. import igraph as ig
  6. import sys,os
  7. from db.session import get_db
  8. from service.kg_node_service import KGNodeService
  9. current_path = os.getcwd()
  10. sys.path.append(current_path)
  11. RESOLUTION = 0.07
  12. # 图谱数据缓存路径(由dump_graph_data.py生成)
  13. CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
  14. def load_entity_data():
  15. print("load entity data")
  16. if not os.path.exists(os.path.join(CACHED_DATA_PATH,'entities_med.json')):
  17. return []
  18. with open(os.path.join(CACHED_DATA_PATH,'entities_med.json'), "r", encoding="utf-8") as f:
  19. entities = json.load(f)
  20. return entities
  21. def load_relation_data(g):
  22. for i in range(0):
  23. if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  24. continue
  25. if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  26. print("load entity data", os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"))
  27. with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
  28. relations = json.load(f)
  29. for item in relations:
  30. if item[0] is None or item[1] is None or item[2] is None:
  31. continue
  32. #删除item[2]['weight']属性
  33. if 'weight' in item[2]:
  34. del item[2]['weight']
  35. g.add_edge(item[0], item[1], weight=1, **item[2])
  36. class GraphHelper:
  37. def __init__(self):
  38. self.graph = None
  39. self.build_graph()
  40. def build_graph(self):
  41. """构建企业关系图谱"""
  42. self.graph = nx.Graph()
  43. # 加载节点数据
  44. entities = load_entity_data()
  45. for item in entities:
  46. node_id = item[0]
  47. attrs = item[1]
  48. self.graph.add_node(node_id, **attrs)
  49. # 加载边数据
  50. load_relation_data(self.graph)
  51. def community_report_search(self, query):
  52. """社区报告检索功能"""
  53. es_result = self.es.search_title_index("graph_community_report_index", query, 10)
  54. results = []
  55. for item in es_result:
  56. results.append({
  57. 'id': item["title"],
  58. 'score': item["score"],
  59. 'text': item["text"]})
  60. return results
  61. def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
  62. """节点检索功能"""
  63. kg_node_service = KGNodeService(next(get_db()))
  64. es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
  65. results = []
  66. for item in es_result:
  67. n = self.graph.nodes.get(item["title"])
  68. score = item["score"]
  69. if n:
  70. results.append({
  71. 'id': item["title"],
  72. 'score': score,
  73. **n
  74. })
  75. return results
  76. # for n in self.graph.nodes(data=True):
  77. # match = True
  78. # if node_id and n[0] != node_id:
  79. # continue
  80. # if node_type and n[1].get('type') != node_type:
  81. # continue
  82. # if filters:
  83. # for k, v in filters.items():
  84. # if n[1].get(k) != v:
  85. # match = False
  86. # break
  87. # if match:
  88. # results.append({
  89. # 'id': n[0],
  90. # **n[1]
  91. # })
  92. return results
  93. def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
  94. """边检索功能"""
  95. results = []
  96. for u, v, data in self.graph.edges(data=True):
  97. if edge_type and data.get('type') != edge_type:
  98. continue
  99. if data.get('weight', 0) < min_weight:
  100. continue
  101. if (source and u != source and v != source) or \
  102. (target and u != target and v != target):
  103. continue
  104. results.append({
  105. 'source': u,
  106. 'target': v,
  107. **data
  108. })
  109. return results
  110. def neighbor_search(self, node_id, hops=1):
  111. """近邻检索功能"""
  112. if node_id not in self.graph:
  113. return [],[]
  114. # 使用ego_graph获取指定跳数的子图
  115. subgraph = nx.ego_graph(self.graph, node_id, radius=hops)
  116. entities = []
  117. for n in subgraph.nodes(data=True):
  118. if n[0] == node_id: # 跳过中心节点
  119. continue
  120. entities.append({
  121. 'id': n[0],
  122. **n[1]
  123. })
  124. relations = []
  125. for edge in subgraph.edges(data=True):
  126. relations.append({
  127. 'src_name': edge[0],
  128. 'dest_name': edge[1],
  129. **edge[2]
  130. })
  131. return entities, relations
  132. def find_paths(self, source, target, max_depth=3):
  133. """路径查找功能"""
  134. try:
  135. shortest = nx.shortest_path(self.graph, source=source, target=target)
  136. all_paths = nx.all_simple_paths(self.graph, source=source, target=target, cutoff=max_depth)
  137. return {
  138. 'shortest_path': shortest,
  139. 'all_paths': list(all_paths)
  140. }
  141. except nx.NetworkXNoPath:
  142. return {'error': 'No path found'}
  143. except nx.NodeNotFound as e:
  144. return {'error': f'Node not found: {e}'}
  145. def format_output(self, data, fmt='text'):
  146. """格式化输出结果"""
  147. if fmt == 'json':
  148. return json.dumps(data, indent=2, ensure_ascii=False)
  149. # 文本表格格式
  150. if isinstance(data, list):
  151. rows = []
  152. headers = []
  153. if not data:
  154. return "No results found"
  155. # 节点结果
  156. if 'id' in data[0]:
  157. headers = ["ID", "Type", "Description"]
  158. rows = [[d['id'], d.get('type',''), d.get('description','')] for d in data]
  159. # 边结果
  160. elif 'source' in data[0]:
  161. headers = ["Source", "Target", "Type", "Weight"]
  162. rows = [[d['source'], d['target'], d.get('type',''), d.get('weight',0)] for d in data]
  163. return tabulate(rows, headers=headers, tablefmt="grid")
  164. # 路径结果
  165. if isinstance(data, dict):
  166. if 'shortest_path' in data:
  167. output = [
  168. "Shortest Path: " + " → ".join(data['shortest_path']),
  169. "\nAll Paths:"
  170. ]
  171. for path in data['all_paths']:
  172. output.append(" → ".join(path))
  173. return "\n".join(output)
  174. elif 'error' in data:
  175. return data['error']
  176. return str(data)
  177. def detect_communities(self):
  178. """使用Leiden算法进行社区检测"""
  179. # 转换networkx图到igraph格式
  180. print("convert to igraph")
  181. ig_graph = ig.Graph.from_networkx(self.graph)
  182. # 执行Leiden算法
  183. partition = leidenalg.find_partition(
  184. ig_graph,
  185. leidenalg.CPMVertexPartition,
  186. resolution_parameter=RESOLUTION,
  187. n_iterations=2
  188. )
  189. # 将社区标签添加到原始图
  190. for i, node in enumerate(self.graph.nodes()):
  191. self.graph.nodes[node]['community'] = partition.membership[i]
  192. print("convert to igraph finished")
  193. return self.graph, partition
  194. def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
  195. """根据条件过滤节点"""
  196. filtered = []
  197. for node, data in self.graph.nodes(data=True):
  198. if node_type and data.get('type') != node_type:
  199. continue
  200. if min_degree > 0 and self.graph.degree(node) < min_degree:
  201. continue
  202. if attributes:
  203. if not all(data.get(k) == v for k, v in attributes.items()):
  204. continue
  205. filtered.append({'id': node, **data})
  206. return filtered
  207. def get_graph_statistics(self):
  208. """获取图谱统计信息"""
  209. return {
  210. 'node_count': self.graph.number_of_nodes(),
  211. 'edge_count': self.graph.number_of_edges(),
  212. 'density': nx.density(self.graph),
  213. 'components': nx.number_connected_components(self.graph),
  214. 'average_degree': sum(dict(self.graph.degree()).values()) / self.graph.number_of_nodes()
  215. }
  216. def get_community_details(self, community_id=None):
  217. """获取社区详细信息"""
  218. communities = {}
  219. for node, data in self.graph.nodes(data=True):
  220. comm = data.get('community', -1)
  221. if community_id is not None and comm != community_id:
  222. continue
  223. if comm not in communities:
  224. communities[comm] = {
  225. 'node_count': 0,
  226. 'nodes': [],
  227. 'central_nodes': []
  228. }
  229. communities[comm]['node_count'] += 1
  230. communities[comm]['nodes'].append(node)
  231. # 计算每个社区的中心节点
  232. for comm in communities:
  233. subgraph = self.graph.subgraph(communities[comm]['nodes'])
  234. centrality = nx.degree_centrality(subgraph)
  235. top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
  236. communities[comm]['central_nodes'] = [n[0] for n in top_nodes]
  237. return communities
  238. def find_relations(self, node_ids, relation_types=None):
  239. """查找指定节点集之间的关系"""
  240. relations = []
  241. for u, v, data in self.graph.edges(data=True):
  242. if (u in node_ids or v in node_ids) and \
  243. (not relation_types or data.get('type') in relation_types):
  244. relations.append({
  245. 'source': u,
  246. 'target': v,
  247. **data
  248. })
  249. return relations
  250. def semantic_search(self, query, top_k=5):
  251. """语义搜索(需要与文本嵌入结合)"""
  252. # 这里需要调用文本处理模块的embedding功能
  253. # 示例实现:简单名称匹配
  254. results = []
  255. query_lower = query.lower()
  256. for node, data in self.graph.nodes(data=True):
  257. if query_lower in data.get('name', '').lower():
  258. results.append({
  259. 'id': node,
  260. 'score': 1.0,
  261. **data
  262. })
  263. return sorted(results, key=lambda x: -x['score'])[:top_k]
  264. def get_node_details(self, node_id):
  265. """获取节点详细信息及其关联"""
  266. if node_id not in self.graph:
  267. return None
  268. details = dict(self.graph.nodes[node_id])
  269. details['degree'] = self.graph.degree(node_id)
  270. details['neighbors'] = list(self.graph.neighbors(node_id))
  271. details['edges'] = []
  272. for u, v, data in self.graph.edges(node_id, data=True):
  273. details['edges'].append({
  274. 'target': v if u == node_id else u,
  275. **data
  276. })
  277. return details