test_trunks_service.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
from pathlib import Path

import pytest
from sqlalchemy.exc import IntegrityError

from model.trunks_model import Trunks
from service.trunks_service import TrunksService
  5. @pytest.fixture(scope="module")
  6. def trunks_service():
  7. return TrunksService()
  8. @pytest.fixture
  9. def test_trunk_data():
  10. return {
  11. "content": """测试""",
  12. "file_path": "test_path.pdf",
  13. "type": "default"
  14. }
  15. class TestTrunksServiceCRUD:
  16. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  17. # 测试创建和查询
  18. created = trunks_service.create_trunk(test_trunk_data)
  19. assert created.id is not None
  20. def test_update_trunk(self, trunks_service, test_trunk_data):
  21. trunk = trunks_service.create_trunk(test_trunk_data)
  22. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  23. assert updated.content == "更新内容"
  24. def test_delete_trunk(self, trunks_service, test_trunk_data):
  25. trunk = trunks_service.create_trunk(test_trunk_data)
  26. assert trunks_service.delete_trunk(trunk.id)
  27. assert trunks_service.get_trunk_by_id(trunk.id) is None
  28. class TestSearchOperations:
  29. def test_vector_search(self, trunks_service, test_trunk_data):
  30. results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少",10,conversation_id="1111111aaaa")
  31. print("搜索结果:", results)
  32. results = trunks_service.get_cache("1111111aaaa")
  33. print("搜索结果:", results)
  34. assert len(results) > 0
  35. # def test_fulltext_search(self, trunks_service, test_trunk_data):
  36. # trunks_service.create_trunk(test_trunk_data)
  37. # results = trunks_service.fulltext_search("测试")
  38. # assert len(results) > 0
  39. class TestExceptionCases:
  40. def test_duplicate_id(self, trunks_service, test_trunk_data):
  41. with pytest.raises(IntegrityError):
  42. trunk1 = trunks_service.create_trunk(test_trunk_data)
  43. test_trunk_data["id"] = trunk1.id
  44. trunks_service.create_trunk(test_trunk_data)
  45. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  46. with pytest.raises(ValueError):
  47. invalid_data = test_trunk_data.copy()
  48. invalid_data["embedding"] = [0.1]*100
  49. trunks_service.create_trunk(invalid_data)
  50. @pytest.fixture
  51. def trunk_factory():
  52. class TrunkFactory:
  53. @staticmethod
  54. def create(**overrides):
  55. defaults = {
  56. "content": "工厂内容",
  57. "file_path": "factory_path.pdf",
  58. "type": "default"
  59. }
  60. return {**defaults, **overrides}
  61. return TrunkFactory()
  62. class TestBatchCreateFromDirectory:
  63. def test_batch_create_from_directory(self, trunks_service, test_data_dir):
  64. # 使用现有目录路径
  65. base_path = Path(r'E:\project\vscode\files')
  66. # 遍历目录并创建trunk
  67. created_ids = []
  68. for txt_path in base_path.glob('**/*_split_*.txt'):
  69. relative_path = txt_path.relative_to(base_path.parent.parent)
  70. with open(txt_path, 'r', encoding='utf-8') as f:
  71. trunk_data = {
  72. "content": f.read(),
  73. "file_path": str(relative_path).replace('\\', '/')
  74. }
  75. trunk = trunks_service.create_trunk(trunk_data)
  76. created_ids.append(trunk.id)
  77. # 验证数据库记录
  78. for trunk_id in created_ids:
  79. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  80. assert db_trunk is not None
  81. assert ".txt" in db_trunk.file_path
  82. assert "_split_" in db_trunk.file_path
  83. assert len(db_trunk.content) > 0