graph_helper.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import networkx as nx
  2. import json
  3. from tabulate import tabulate
  4. import leidenalg
  5. import igraph as ig
  6. import sys,os
  7. from db.session import get_db
  8. from service.kg_node_service import KGNodeService
  9. current_path = os.getcwd()
  10. sys.path.append(current_path)
  11. RESOLUTION = 0.07
  12. # 图谱数据缓存路径(由dump_graph_data.py生成)
  13. CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
  14. def load_entity_data():
  15. print("load entity data")
  16. with open(os.path.join(CACHED_DATA_PATH,'entities_med.json'), "r", encoding="utf-8") as f:
  17. entities = json.load(f)
  18. return entities
  19. def load_relation_data(g):
  20. for i in range(30):
  21. if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  22. print("load entity data", os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"))
  23. with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
  24. relations = json.load(f)
  25. for item in relations:
  26. if item[0] is None or item[1] is None or item[2] is None:
  27. continue
  28. #删除item[2]['weight']属性
  29. if 'weight' in item[2]:
  30. del item[2]['weight']
  31. g.add_edge(item[0], item[1], weight=1, **item[2])
  32. class GraphHelper:
  33. def __init__(self):
  34. self.graph = None
  35. self.build_graph()
  36. def build_graph(self):
  37. """构建企业关系图谱"""
  38. self.graph = nx.Graph()
  39. # 加载节点数据
  40. entities = load_entity_data()
  41. for item in entities:
  42. node_id = item[0]
  43. attrs = item[1]
  44. self.graph.add_node(node_id, **attrs)
  45. # 加载边数据
  46. load_relation_data(self.graph)
  47. def community_report_search(self, query):
  48. """社区报告检索功能"""
  49. es_result = self.es.search_title_index("graph_community_report_index", query, 10)
  50. results = []
  51. for item in es_result:
  52. results.append({
  53. 'id': item["title"],
  54. 'score': item["score"],
  55. 'text': item["text"]})
  56. return results
  57. def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
  58. """节点检索功能"""
  59. kg_node_service = KGNodeService(next(get_db()))
  60. es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
  61. results = []
  62. for item in es_result:
  63. n = self.graph.nodes.get(item["title"])
  64. score = item["score"]
  65. if n:
  66. results.append({
  67. 'id': item["title"],
  68. 'score': score,
  69. **n
  70. })
  71. return results
  72. # for n in self.graph.nodes(data=True):
  73. # match = True
  74. # if node_id and n[0] != node_id:
  75. # continue
  76. # if node_type and n[1].get('type') != node_type:
  77. # continue
  78. # if filters:
  79. # for k, v in filters.items():
  80. # if n[1].get(k) != v:
  81. # match = False
  82. # break
  83. # if match:
  84. # results.append({
  85. # 'id': n[0],
  86. # **n[1]
  87. # })
  88. return results
  89. def edge_search(self, source=None, target=None, edge_type=None, min_weight=0):
  90. """边检索功能"""
  91. results = []
  92. for u, v, data in self.graph.edges(data=True):
  93. if edge_type and data.get('type') != edge_type:
  94. continue
  95. if data.get('weight', 0) < min_weight:
  96. continue
  97. if (source and u != source and v != source) or \
  98. (target and u != target and v != target):
  99. continue
  100. results.append({
  101. 'source': u,
  102. 'target': v,
  103. **data
  104. })
  105. return results
  106. def neighbor_search(self, node_id, hops=1):
  107. """近邻检索功能"""
  108. if node_id not in self.graph:
  109. return [],[]
  110. # 使用ego_graph获取指定跳数的子图
  111. subgraph = nx.ego_graph(self.graph, node_id, radius=hops)
  112. entities = []
  113. for n in subgraph.nodes(data=True):
  114. if n[0] == node_id: # 跳过中心节点
  115. continue
  116. entities.append({
  117. 'id': n[0],
  118. **n[1]
  119. })
  120. relations = []
  121. for edge in subgraph.edges(data=True):
  122. relations.append({
  123. 'src_name': edge[0],
  124. 'dest_name': edge[1],
  125. **edge[2]
  126. })
  127. return entities, relations
  128. def find_paths(self, source, target, max_depth=3):
  129. """路径查找功能"""
  130. try:
  131. shortest = nx.shortest_path(self.graph, source=source, target=target)
  132. all_paths = nx.all_simple_paths(self.graph, source=source, target=target, cutoff=max_depth)
  133. return {
  134. 'shortest_path': shortest,
  135. 'all_paths': list(all_paths)
  136. }
  137. except nx.NetworkXNoPath:
  138. return {'error': 'No path found'}
  139. except nx.NodeNotFound as e:
  140. return {'error': f'Node not found: {e}'}
  141. def format_output(self, data, fmt='text'):
  142. """格式化输出结果"""
  143. if fmt == 'json':
  144. return json.dumps(data, indent=2, ensure_ascii=False)
  145. # 文本表格格式
  146. if isinstance(data, list):
  147. rows = []
  148. headers = []
  149. if not data:
  150. return "No results found"
  151. # 节点结果
  152. if 'id' in data[0]:
  153. headers = ["ID", "Type", "Description"]
  154. rows = [[d['id'], d.get('type',''), d.get('description','')] for d in data]
  155. # 边结果
  156. elif 'source' in data[0]:
  157. headers = ["Source", "Target", "Type", "Weight"]
  158. rows = [[d['source'], d['target'], d.get('type',''), d.get('weight',0)] for d in data]
  159. return tabulate(rows, headers=headers, tablefmt="grid")
  160. # 路径结果
  161. if isinstance(data, dict):
  162. if 'shortest_path' in data:
  163. output = [
  164. "Shortest Path: " + " → ".join(data['shortest_path']),
  165. "\nAll Paths:"
  166. ]
  167. for path in data['all_paths']:
  168. output.append(" → ".join(path))
  169. return "\n".join(output)
  170. elif 'error' in data:
  171. return data['error']
  172. return str(data)
  173. def detect_communities(self):
  174. """使用Leiden算法进行社区检测"""
  175. # 转换networkx图到igraph格式
  176. print("convert to igraph")
  177. ig_graph = ig.Graph.from_networkx(self.graph)
  178. # 执行Leiden算法
  179. partition = leidenalg.find_partition(
  180. ig_graph,
  181. leidenalg.CPMVertexPartition,
  182. resolution_parameter=RESOLUTION,
  183. n_iterations=2
  184. )
  185. # 将社区标签添加到原始图
  186. for i, node in enumerate(self.graph.nodes()):
  187. self.graph.nodes[node]['community'] = partition.membership[i]
  188. print("convert to igraph finished")
  189. return self.graph, partition
  190. def filter_nodes(self, node_type=None, min_degree=0, attributes=None):
  191. """根据条件过滤节点"""
  192. filtered = []
  193. for node, data in self.graph.nodes(data=True):
  194. if node_type and data.get('type') != node_type:
  195. continue
  196. if min_degree > 0 and self.graph.degree(node) < min_degree:
  197. continue
  198. if attributes:
  199. if not all(data.get(k) == v for k, v in attributes.items()):
  200. continue
  201. filtered.append({'id': node, **data})
  202. return filtered
  203. def get_graph_statistics(self):
  204. """获取图谱统计信息"""
  205. return {
  206. 'node_count': self.graph.number_of_nodes(),
  207. 'edge_count': self.graph.number_of_edges(),
  208. 'density': nx.density(self.graph),
  209. 'components': nx.number_connected_components(self.graph),
  210. 'average_degree': sum(dict(self.graph.degree()).values()) / self.graph.number_of_nodes()
  211. }
  212. def get_community_details(self, community_id=None):
  213. """获取社区详细信息"""
  214. communities = {}
  215. for node, data in self.graph.nodes(data=True):
  216. comm = data.get('community', -1)
  217. if community_id is not None and comm != community_id:
  218. continue
  219. if comm not in communities:
  220. communities[comm] = {
  221. 'node_count': 0,
  222. 'nodes': [],
  223. 'central_nodes': []
  224. }
  225. communities[comm]['node_count'] += 1
  226. communities[comm]['nodes'].append(node)
  227. # 计算每个社区的中心节点
  228. for comm in communities:
  229. subgraph = self.graph.subgraph(communities[comm]['nodes'])
  230. centrality = nx.degree_centrality(subgraph)
  231. top_nodes = sorted(centrality.items(), key=lambda x: -x[1])[:3]
  232. communities[comm]['central_nodes'] = [n[0] for n in top_nodes]
  233. return communities
  234. def find_relations(self, node_ids, relation_types=None):
  235. """查找指定节点集之间的关系"""
  236. relations = []
  237. for u, v, data in self.graph.edges(data=True):
  238. if (u in node_ids or v in node_ids) and \
  239. (not relation_types or data.get('type') in relation_types):
  240. relations.append({
  241. 'source': u,
  242. 'target': v,
  243. **data
  244. })
  245. return relations
  246. def semantic_search(self, query, top_k=5):
  247. """语义搜索(需要与文本嵌入结合)"""
  248. # 这里需要调用文本处理模块的embedding功能
  249. # 示例实现:简单名称匹配
  250. results = []
  251. query_lower = query.lower()
  252. for node, data in self.graph.nodes(data=True):
  253. if query_lower in data.get('name', '').lower():
  254. results.append({
  255. 'id': node,
  256. 'score': 1.0,
  257. **data
  258. })
  259. return sorted(results, key=lambda x: -x['score'])[:top_k]
  260. def get_node_details(self, node_id):
  261. """获取节点详细信息及其关联"""
  262. if node_id not in self.graph:
  263. return None
  264. details = dict(self.graph.nodes[node_id])
  265. details['degree'] = self.graph.degree(node_id)
  266. details['neighbors'] = list(self.graph.neighbors(node_id))
  267. details['edges'] = []
  268. for u, v, data in self.graph.edges(node_id, data=True):
  269. details['edges'].append({
  270. 'target': v if u == node_id else u,
  271. **data
  272. })
  273. return details