test_graph_helper.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import unittest
  2. import json
  3. import os
  4. from community.graph_helper2 import GraphHelper
  5. class TestGraphHelper(unittest.TestCase):
  6. """
  7. 图谱助手测试套件
  8. 测试数据路径:web/cached_data 下的实际医疗知识图谱数据
  9. """
  10. @classmethod
  11. def setUpClass(cls):
  12. # 初始化图谱助手并构建图谱
  13. cls.helper = GraphHelper()
  14. cls.test_node = "感染性发热" # 使用实际存在的测试节点
  15. cls.test_community_node = "糖尿病" # 用于社区检测的测试节点
  16. def test_graph_construction(self):
  17. """
  18. 测试图谱构建完整性
  19. 验证点:节点和边数量应大于0
  20. """
  21. node_count = len(self.helper.graph.nodes)
  22. edge_count = len(self.helper.graph.edges)
  23. self.assertGreater(node_count, 0, "节点数量应大于0")
  24. self.assertGreater(edge_count, 0, "边数量应大于0")
  25. def test_node_search(self):
  26. """
  27. 测试节点搜索功能
  28. 场景:1.按节点ID精确搜索 2.按类型过滤 3.自定义属性过滤
  29. """
  30. # 精确ID搜索
  31. result = self.helper.node_search(node_id=self.test_node)
  32. self.assertEqual(len(result), 1, "应找到唯一匹配节点")
  33. # 类型过滤搜索
  34. type_results = self.helper.node_search(node_type="症状")
  35. self.assertTrue(all(item['type'] == "症状" for item in type_results))
  36. # 自定义属性过滤
  37. filter_results = self.helper.node_search(filters={"description": "发热病因"})
  38. self.assertTrue(len(filter_results) >= 1)
  39. def test_neighbor_search(self):
  40. """
  41. 测试邻居检索功能
  42. 验证点:1.跳数限制 2.不包含中心节点 3.关系完整性
  43. """
  44. entities, relations = self.helper.neighbor_search(self.test_node, hops=2)
  45. self.assertFalse(any(e['name'] == self.test_node for e in entities),
  46. "结果不应包含中心节点")
  47. # 验证关系双向连接
  48. for rel in relations:
  49. self.assertTrue(
  50. any(e['name'] == rel['src_name'] for e in entities) or
  51. any(e['name'] == rel['dest_name'] for e in entities)
  52. )
  53. def test_path_finding(self):
  54. """
  55. 测试路径查找功能
  56. 使用已知存在路径的节点对进行验证
  57. """
  58. target_node = "肺炎"
  59. result = self.helper.find_paths(self.test_node, target_node)
  60. self.assertIn('shortest_path', result)
  61. self.assertTrue(len(result['shortest_path']) >= 2)
  62. # 验证所有路径都包含起始节点
  63. for path in result['all_paths']:
  64. self.assertEqual(path[0], self.test_node)
  65. self.assertEqual(path[-1], target_node)
  66. def test_community_detection(self):
  67. """
  68. 测试社区检测功能
  69. 验证点:1.社区标签存在 2.同类节点聚集 3.社区数量合理
  70. """
  71. graph, partition = self.helper.detect_communities()
  72. # 验证节点社区属性
  73. test_node_community = graph.nodes[self.test_community_node]['community']
  74. self.assertIsInstance(test_node_community, int)
  75. # 验证同类节点聚集(例如糖尿病相关节点)
  76. diabetes_nodes = [n for n, attr in graph.nodes(data=True)
  77. if attr.get('type') == "代谢疾病"]
  78. communities = set(graph.nodes[n]['community'] for n in diabetes_nodes)
  79. self.assertLessEqual(len(communities), 3, "同类节点应集中在少数社区")
  80. if __name__ == '__main__':
  81. unittest.main()