test_trunks_service.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
import os
from pathlib import Path

import pytest
from sqlalchemy.exc import IntegrityError

from model.trunks_model import Trunks
from service.trunks_service import TrunksService
  5. @pytest.fixture(scope="module")
  6. def trunks_service():
  7. return TrunksService()
  8. @pytest.fixture
  9. def test_trunk_data():
  10. return {
  11. "content": """测试""",
  12. "file_path": "test_path.pdf",
  13. "type": "default"
  14. }
  15. class TestTrunksServiceCRUD:
  16. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  17. # 测试创建和查询
  18. created = trunks_service.create_trunk(test_trunk_data)
  19. assert created.id is not None
  20. def test_update_trunk(self, trunks_service, test_trunk_data):
  21. trunk = trunks_service.create_trunk(test_trunk_data)
  22. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  23. assert updated.content == "更新内容"
  24. def test_delete_trunk(self, trunks_service, test_trunk_data):
  25. trunk = trunks_service.create_trunk(test_trunk_data)
  26. assert trunks_service.delete_trunk(trunk.id)
  27. assert trunks_service.get_trunk_by_id(trunk.id) is None
  28. class TestSearchOperations:
  29. def test_vector_search(self, trunks_service, test_trunk_data):
  30. results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少",10,conversation_id="1111111")
  31. print("搜索结果:", results)
  32. assert len(results) > 0
  33. # def test_fulltext_search(self, trunks_service, test_trunk_data):
  34. # trunks_service.create_trunk(test_trunk_data)
  35. # results = trunks_service.fulltext_search("测试")
  36. # assert len(results) > 0
  37. class TestExceptionCases:
  38. def test_duplicate_id(self, trunks_service, test_trunk_data):
  39. with pytest.raises(IntegrityError):
  40. trunk1 = trunks_service.create_trunk(test_trunk_data)
  41. test_trunk_data["id"] = trunk1.id
  42. trunks_service.create_trunk(test_trunk_data)
  43. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  44. with pytest.raises(ValueError):
  45. invalid_data = test_trunk_data.copy()
  46. invalid_data["embedding"] = [0.1]*100
  47. trunks_service.create_trunk(invalid_data)
  48. @pytest.fixture
  49. def trunk_factory():
  50. class TrunkFactory:
  51. @staticmethod
  52. def create(**overrides):
  53. defaults = {
  54. "content": "工厂内容",
  55. "file_path": "factory_path.pdf",
  56. "type": "default"
  57. }
  58. return {**defaults, **overrides}
  59. return TrunkFactory()
  60. class TestBatchCreateFromDirectory:
  61. def test_batch_create_from_directory(self, trunks_service, test_data_dir):
  62. # 使用现有目录路径
  63. base_path = Path(r'E:\project\vscode\files')
  64. # 遍历目录并创建trunk
  65. created_ids = []
  66. for txt_path in base_path.glob('**/*_split_*.txt'):
  67. relative_path = txt_path.relative_to(base_path.parent.parent)
  68. with open(txt_path, 'r', encoding='utf-8') as f:
  69. trunk_data = {
  70. "content": f.read(),
  71. "file_path": str(relative_path).replace('\\', '/')
  72. }
  73. trunk = trunks_service.create_trunk(trunk_data)
  74. created_ids.append(trunk.id)
  75. # 验证数据库记录
  76. for trunk_id in created_ids:
  77. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  78. assert db_trunk is not None
  79. assert ".txt" in db_trunk.file_path
  80. assert "_split_" in db_trunk.file_path
  81. assert len(db_trunk.content) > 0