1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- import pytest
- from service.trunks_service import TrunksService
- from model.trunks_model import Trunks
- from sqlalchemy.exc import IntegrityError
@pytest.fixture(scope="module")
def trunks_service():
    """Provide a single TrunksService instance shared by this module's tests.

    Because the fixture is module-scoped, every test below receives the same
    service object. Any state it carries is therefore shared across tests
    (presumably a DB session/connection — not visible here, confirm in
    TrunksService).
    """
    service = TrunksService()
    return service
@pytest.fixture
def test_trunk_data():
    """Return a fresh payload dict used to create a trunk record.

    The fixture is function-scoped (pytest's default), so each test gets its
    own dict and may modify it without affecting other tests.
    """
    payload = dict(
        content="测试",  # Chinese for "test"
        file_path="test_path.pdf",
        type="default",
    )
    return payload
class TestTrunksServiceCRUD:
    """Create / read / update / delete round-trips against TrunksService.

    The create and update tests delete the row they created, so repeated
    runs do not keep piling test records into the backing store.
    """

    def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
        """Creating a trunk assigns an id, and the row can be read back by id."""
        created = trunks_service.create_trunk(test_trunk_data)
        # Checked before the try block so cleanup never runs with a None id.
        assert created.id is not None
        try:
            # The test is named (and was commented) "create and query", but
            # originally never queried; exercise the read path as well.
            fetched = trunks_service.get_trunk_by_id(created.id)
            assert fetched is not None
            assert fetched.id == created.id
            assert fetched.content == test_trunk_data["content"]
        finally:
            trunks_service.delete_trunk(created.id)

    def test_update_trunk(self, trunks_service, test_trunk_data):
        """Updating content is reflected both in the return value and on re-read."""
        new_content = "更新内容"  # Chinese for "updated content"
        trunk = trunks_service.create_trunk(test_trunk_data)
        try:
            updated = trunks_service.update_trunk(trunk.id, {"content": new_content})
            assert updated.content == new_content
            # Re-read so the test checks what was stored, not just the
            # object returned by update_trunk.
            reloaded = trunks_service.get_trunk_by_id(trunk.id)
            assert reloaded is not None
            assert reloaded.content == new_content
        finally:
            trunks_service.delete_trunk(trunk.id)

    def test_delete_trunk(self, trunks_service, test_trunk_data):
        """delete_trunk reports success and the row is no longer retrievable."""
        trunk = trunks_service.create_trunk(test_trunk_data)
        assert trunks_service.delete_trunk(trunk.id)
        assert trunks_service.get_trunk_by_id(trunk.id) is None
class TestSearchOperations:
    """Retrieval tests for the TrunksService search APIs."""

    def test_vector_search(self, trunks_service):
        """A vector search for a known passage returns at least one hit.

        NOTE(review): this test creates no data of its own. It only passes if
        the target database already holds trunks matching the query text
        below and the service accepts ``conversation_id="1111111"``. It will
        fail against an empty database.
        """
        # Query text (Chinese): "effective circulating blood volume decreases
        # sharply under the action of various pathogenic factors".
        results = trunks_service.search_by_vector(
            "各种致病因素作用引起的有效循环血容量急剧减少",
            10,  # presumably the result limit (top-k) — confirm in TrunksService
            conversation_id="1111111",
        )
        # No unconditional print; pytest already reports the failing value.
        assert len(results) > 0

    # TODO: this full-text test was disabled in the original for an unstated
    # reason. Restore it (or delete it) once TrunksService.fulltext_search is
    # confirmed to exist and behave as expected.
    # def test_fulltext_search(self, trunks_service, test_trunk_data):
    #     trunks_service.create_trunk(test_trunk_data)
    #     results = trunks_service.fulltext_search("测试")
    #     assert len(results) > 0
class TestExceptionCases:
    """Error paths of TrunksService.create_trunk."""

    def test_duplicate_id(self, trunks_service, test_trunk_data):
        """Inserting a trunk whose id already exists raises IntegrityError."""
        # The first insert happens OUTSIDE pytest.raises. Originally it sat
        # inside, so an IntegrityError from this call (e.g. leftover data)
        # would satisfy the check without ever testing the duplicate insert.
        trunk1 = trunks_service.create_trunk(test_trunk_data)
        # Build a copy rather than mutating the fixture's dict in place.
        duplicate = {**test_trunk_data, "id": trunk1.id}
        with pytest.raises(IntegrityError):
            trunks_service.create_trunk(duplicate)

    def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
        """An embedding of the wrong length is rejected with ValueError.

        Assumes 100 does not match the embedding dimension the service
        expects — confirm against the model/schema.
        """
        invalid_data = {**test_trunk_data, "embedding": [0.1] * 100}
        # Keep only the call under test inside pytest.raises.
        with pytest.raises(ValueError):
            trunks_service.create_trunk(invalid_data)
@pytest.fixture
def trunk_factory():
    """Return a helper whose ``create(**overrides)`` builds trunk payload dicts.

    Each call starts from the same default fields. Keyword arguments replace
    matching keys or add new ones, and a brand-new dict is returned every
    time.
    """

    class TrunkFactory:
        # Baseline payload; create() copies it and never mutates it.
        _DEFAULTS = {
            "content": "工厂内容",
            "file_path": "factory_path.pdf",
            "type": "default",
        }

        @classmethod
        def create(cls, **overrides):
            payload = dict(cls._DEFAULTS)
            payload.update(overrides)
            return payload

    factory = TrunkFactory()
    return factory
class TestBatchCreateFromDirectory:
    """Bulk-create trunks from ``*_split_*.txt`` chunk files in a local directory.

    NOTE(review): the rows created here are intentionally not deleted. The
    test may double as a step that populates the database; confirm before
    adding cleanup.
    """

    def test_batch_create_from_directory(self, trunks_service):
        """Create one trunk per chunk file, then verify every row was stored."""
        # Path was used here but never imported anywhere in the file.
        from pathlib import Path

        # Existing, machine-specific source directory (a Windows path). Skip
        # rather than fail on machines where it does not exist.
        base_path = Path(r'E:\project\vscode\files')
        if not base_path.is_dir():
            pytest.skip(f"test data directory not found: {base_path}")

        # Stored file_path is relative to base_path's grandparent and always
        # uses forward slashes.
        path_root = base_path.parent.parent

        # Walk the directory tree and create a trunk for each chunk file.
        created_ids = []
        for txt_path in base_path.glob('**/*_split_*.txt'):
            trunk_data = {
                "content": txt_path.read_text(encoding='utf-8'),
                "file_path": txt_path.relative_to(path_root).as_posix(),
            }
            trunk = trunks_service.create_trunk(trunk_data)
            created_ids.append(trunk.id)

        # Without this, an empty glob would make the checks below pass vacuously.
        assert created_ids, f"no '*_split_*.txt' files found under {base_path}"

        # Verify the database records.
        for trunk_id in created_ids:
            db_trunk = trunks_service.get_trunk_by_id(trunk_id)
            assert db_trunk is not None
            assert ".txt" in db_trunk.file_path
            assert "_split_" in db_trunk.file_path
            assert len(db_trunk.content) > 0
|