# test_trunks_service.py
import os
from pathlib import Path

import pytest
from sqlalchemy.exc import IntegrityError

from model.trunks_model import Trunks
from service.trunks_service import TrunksService
  6. @pytest.fixture(scope="module")
  7. def trunks_service():
  8. return TrunksService()
  9. @pytest.fixture
  10. def test_trunk_data():
  11. return {
  12. "content": """测试""",
  13. "file_path": "test_path.pdf",
  14. "type": "default"
  15. }
  16. class TestTrunksServiceCRUD:
  17. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  18. # 测试创建和查询
  19. created = trunks_service.create_trunk(test_trunk_data)
  20. assert created.id is not None
  21. def test_update_trunk(self, trunks_service, test_trunk_data):
  22. trunk = trunks_service.create_trunk(test_trunk_data)
  23. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  24. assert updated.content == "更新内容"
  25. def test_delete_trunk(self, trunks_service, test_trunk_data):
  26. trunk = trunks_service.create_trunk(test_trunk_data)
  27. assert trunks_service.delete_trunk(trunk.id)
  28. assert trunks_service.get_trunk_by_id(trunk.id) is None
  29. class TestSearchOperations:
  30. def test_vector_search(self, trunks_service, test_trunk_data):
  31. results = trunks_service.search_by_vector("急性胰腺炎是常见的急腹症之一,以突发上腹部剧痛伴恶心呕吐为特征。轻症预后良好,重症可并发多器官衰竭,死亡率高达20-30%。",10,conversation_id="1111111aaaa")
  32. print("搜索结果:", results[0])
  33. results = trunks_service.get_cache("1111111aaaa")
  34. print("搜索结果:", results)
  35. assert len(results) > 0
  36. # def test_fulltext_search(self, trunks_service, test_trunk_data):
  37. # trunks_service.create_trunk(test_trunk_data)
  38. # results = trunks_service.fulltext_search("测试")
  39. # assert len(results) > 0
  40. class TestExceptionCases:
  41. def test_duplicate_id(self, trunks_service, test_trunk_data):
  42. with pytest.raises(IntegrityError):
  43. trunk1 = trunks_service.create_trunk(test_trunk_data)
  44. test_trunk_data["id"] = trunk1.id
  45. trunks_service.create_trunk(test_trunk_data)
  46. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  47. with pytest.raises(ValueError):
  48. invalid_data = test_trunk_data.copy()
  49. invalid_data["embedding"] = [0.1]*100
  50. trunks_service.create_trunk(invalid_data)
  51. @pytest.fixture
  52. def trunk_factory():
  53. class TrunkFactory:
  54. @staticmethod
  55. def create(**overrides):
  56. defaults = {
  57. "content": "工厂内容",
  58. "file_path": "factory_path.pdf",
  59. "type": "default"
  60. }
  61. return {**defaults, **overrides}
  62. return TrunkFactory()
  63. class TestBatchCreateFromDirectory:
  64. def test_batch_create_from_directory(self, trunks_service):
  65. # 使用现有目录路径
  66. base_path = Path(r'E:\project\vscode\files1')
  67. # 遍历目录并创建trunk
  68. created_ids = []
  69. for txt_path in base_path.glob('**/*_split_*.txt'):
  70. relative_path = txt_path.relative_to(base_path.parent.parent)
  71. with open(txt_path, 'r', encoding='utf-8') as f:
  72. trunk_data = {
  73. "content": f.read(),
  74. "file_path": str(relative_path).replace('\\', '/')
  75. }
  76. trunk = trunks_service.create_trunk(trunk_data)
  77. created_ids.append(trunk.id)
  78. # 验证数据库记录
  79. for trunk_id in created_ids:
  80. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  81. assert db_trunk is not None
  82. assert ".txt" in db_trunk.file_path
  83. assert "_split_" in db_trunk.file_path
  84. assert len(db_trunk.content) > 0