test_kg_node_service.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import pytest
  2. from service.kg_node_service import KGNodeService
  3. from model.kg_node import KGNode
  4. from sqlalchemy.exc import IntegrityError
  5. @pytest.fixture(scope="module")
  6. def kg_node_service():
  7. from db.session import get_db
  8. return KGNodeService(next(get_db()))
  9. @pytest.fixture
  10. def test_node_data():
  11. return {
  12. "name": "测试节点",
  13. "category": "测试类别",
  14. "version": "1.0"
  15. }
  16. class TestKGNodeServiceCRUD:
  17. def test_create_and_get_node(self, kg_node_service, test_node_data):
  18. created = kg_node_service.create_node(test_node_data)
  19. assert created.id is not None
  20. retrieved = kg_node_service.get_node(created.id)
  21. assert retrieved.name == test_node_data['name']
  22. def test_update_node(self, kg_node_service, test_node_data):
  23. node = kg_node_service.create_node(test_node_data)
  24. updated = kg_node_service.update_node(node.id, {"name": "更新后的节点"})
  25. assert updated.name == "更新后的节点"
  26. def test_delete_node(self, kg_node_service, test_node_data):
  27. node = kg_node_service.create_node(test_node_data)
  28. assert kg_node_service.delete_node(node.id) is None
  29. with pytest.raises(ValueError):
  30. kg_node_service.get_node(node.id)
  31. def test_duplicate_node(self, kg_node_service, test_node_data):
  32. kg_node_service.create_node(test_node_data)
  33. with pytest.raises(ValueError):
  34. kg_node_service.create_node(test_node_data)
  35. class TestPaginatedSearch:
  36. def test_paginated_search(self, kg_node_service, test_node_data):
  37. results = kg_node_service.paginated_search({
  38. 'keyword': '感染性',
  39. 'pageNo': 1,
  40. 'limit': 10
  41. })
  42. assert len(results['records']) > 0
  43. assert results['pagination']['pageNo'] == 1
  44. class TestBatchProcess:
  45. def test_batch_process_er_nodes(self, kg_node_service, test_node_data):
  46. kg_node_service.batch_process_er_nodes()