import unittest import json import os from community.graph_helper2 import GraphHelper class TestGraphHelper(unittest.TestCase): """ 图谱助手测试套件 测试数据路径:web/cached_data 下的实际医疗知识图谱数据 """ @classmethod def setUpClass(cls): # 初始化图谱助手并构建图谱 cls.helper = GraphHelper() cls.test_node = "感染性发热" # 使用实际存在的测试节点 cls.test_community_node = "糖尿病" # 用于社区检测的测试节点 def test_graph_construction(self): """ 测试图谱构建完整性 验证点:节点和边数量应大于0 """ node_count = len(self.helper.graph.nodes) edge_count = len(self.helper.graph.edges) self.assertGreater(node_count, 0, "节点数量应大于0") self.assertGreater(edge_count, 0, "边数量应大于0") def test_node_search(self): """ 测试节点搜索功能 场景:1.按节点ID精确搜索 2.按类型过滤 3.自定义属性过滤 """ # 精确ID搜索 result = self.helper.node_search(node_id=self.test_node) self.assertEqual(len(result), 1, "应找到唯一匹配节点") # 类型过滤搜索 type_results = self.helper.node_search(node_type="症状") self.assertTrue(all(item['type'] == "症状" for item in type_results)) # 自定义属性过滤 filter_results = self.helper.node_search(filters={"description": "发热病因"}) self.assertTrue(len(filter_results) >= 1) def test_neighbor_search(self): """ 测试邻居检索功能 验证点:1.跳数限制 2.不包含中心节点 3.关系完整性 """ entities, relations = self.helper.neighbor_search(self.test_node, hops=2) self.assertFalse(any(e['name'] == self.test_node for e in entities), "结果不应包含中心节点") # 验证关系双向连接 for rel in relations: self.assertTrue( any(e['name'] == rel['src_name'] for e in entities) or any(e['name'] == rel['dest_name'] for e in entities) ) def test_path_finding(self): """ 测试路径查找功能 使用已知存在路径的节点对进行验证 """ target_node = "肺炎" result = self.helper.find_paths(self.test_node, target_node) self.assertIn('shortest_path', result) self.assertTrue(len(result['shortest_path']) >= 2) # 验证所有路径都包含起始节点 for path in result['all_paths']: self.assertEqual(path[0], self.test_node) self.assertEqual(path[-1], target_node) def test_community_detection(self): """ 测试社区检测功能 验证点:1.社区标签存在 2.同类节点聚集 3.社区数量合理 """ graph, partition = self.helper.detect_communities() # 验证节点社区属性 test_node_community = graph.nodes[self.test_community_node]['community'] self.assertIsInstance(test_node_community, int) # 验证同类节点聚集(例如糖尿病相关节点) diabetes_nodes = [n for n, attr in graph.nodes(data=True) if attr.get('type') == "代谢疾病"] communities = set(graph.nodes[n]['community'] for n in diabetes_nodes) self.assertLessEqual(len(communities), 3, "同类节点应集中在少数社区") if __name__ == '__main__': unittest.main()