12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import unittest
- import json
- import os
- from community.graph_helper2 import GraphHelper
- class TestGraphHelper(unittest.TestCase):
- """
- 图谱助手测试套件
- 测试数据路径:web/cached_data 下的实际医疗知识图谱数据
- """
- @classmethod
- def setUpClass(cls):
- # 初始化图谱助手并构建图谱
- cls.helper = GraphHelper()
- cls.test_node = "感染性发热" # 使用实际存在的测试节点
- cls.test_community_node = "糖尿病" # 用于社区检测的测试节点
- def test_graph_construction(self):
- """
- 测试图谱构建完整性
- 验证点:节点和边数量应大于0
- """
- node_count = len(self.helper.graph.nodes)
- edge_count = len(self.helper.graph.edges)
- self.assertGreater(node_count, 0, "节点数量应大于0")
- self.assertGreater(edge_count, 0, "边数量应大于0")
- def test_node_search(self):
- """
- 测试节点搜索功能
- 场景:1.按节点ID精确搜索 2.按类型过滤 3.自定义属性过滤
- """
- # 精确ID搜索
- result = self.helper.node_search(node_id=self.test_node)
- self.assertEqual(len(result), 1, "应找到唯一匹配节点")
- # 类型过滤搜索
- type_results = self.helper.node_search(node_type="症状")
- self.assertTrue(all(item['type'] == "症状" for item in type_results))
- # 自定义属性过滤
- filter_results = self.helper.node_search(filters={"description": "发热病因"})
- self.assertTrue(len(filter_results) >= 1)
- def test_neighbor_search(self):
- """
- 测试邻居检索功能
- 验证点:1.跳数限制 2.不包含中心节点 3.关系完整性
- """
- entities, relations = self.helper.neighbor_search(self.test_node, hops=2)
-
- self.assertFalse(any(e['name'] == self.test_node for e in entities),
- "结果不应包含中心节点")
-
- # 验证关系双向连接
- for rel in relations:
- self.assertTrue(
- any(e['name'] == rel['src_name'] for e in entities) or
- any(e['name'] == rel['dest_name'] for e in entities)
- )
- def test_path_finding(self):
- """
- 测试路径查找功能
- 使用已知存在路径的节点对进行验证
- """
- target_node = "肺炎"
- result = self.helper.find_paths(self.test_node, target_node)
-
- self.assertIn('shortest_path', result)
- self.assertTrue(len(result['shortest_path']) >= 2)
-
- # 验证所有路径都包含起始节点
- for path in result['all_paths']:
- self.assertEqual(path[0], self.test_node)
- self.assertEqual(path[-1], target_node)
- def test_community_detection(self):
- """
- 测试社区检测功能
- 验证点:1.社区标签存在 2.同类节点聚集 3.社区数量合理
- """
- graph, partition = self.helper.detect_communities()
-
- # 验证节点社区属性
- test_node_community = graph.nodes[self.test_community_node]['community']
- self.assertIsInstance(test_node_community, int)
-
- # 验证同类节点聚集(例如糖尿病相关节点)
- diabetes_nodes = [n for n, attr in graph.nodes(data=True)
- if attr.get('type') == "代谢疾病"]
- communities = set(graph.nodes[n]['community'] for n in diabetes_nodes)
- self.assertLessEqual(len(communities), 3, "同类节点应集中在少数社区")
- if __name__ == '__main__':
- unittest.main()
|