"""Integration tests for ``TrunksService``.

The service is constructed directly (no test doubles), so these tests run
against whatever backing store ``TrunksService`` is configured to use.
"""
import os
from pathlib import Path

import pytest
from sqlalchemy.exc import IntegrityError

from model.trunks_model import Trunks
from service.trunks_service import TrunksService

# Directory scanned by TestBatchCreateFromDirectory. Overridable via the
# environment so the test is not tied to one developer's machine.
_FILES_DIR_ENV_VAR = "TRUNKS_TEST_FILES_DIR"
_DEFAULT_FILES_DIR = r"E:\project\vscode\files"


@pytest.fixture(scope="module")
def trunks_service():
    """Provide one ``TrunksService`` instance shared by all tests in this module."""
    return TrunksService()


@pytest.fixture
def test_trunk_data():
    """Return a fresh payload dict for ``create_trunk`` (new dict per test)."""
    return {
        "content": "测试",
        "file_path": "test_path.pdf",
        "type": "default",
    }


class TestTrunksServiceCRUD:
    """Basic create / read / update / delete round trips."""

    def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
        # Create a trunk, then read it back by id and check the stored fields.
        created = trunks_service.create_trunk(test_trunk_data)
        assert created.id is not None

        fetched = trunks_service.get_trunk_by_id(created.id)
        assert fetched is not None
        assert fetched.content == test_trunk_data["content"]
        assert fetched.file_path == test_trunk_data["file_path"]

    def test_update_trunk(self, trunks_service, test_trunk_data):
        trunk = trunks_service.create_trunk(test_trunk_data)
        updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
        assert updated.content == "更新内容"

    def test_delete_trunk(self, trunks_service, test_trunk_data):
        trunk = trunks_service.create_trunk(test_trunk_data)
        assert trunks_service.delete_trunk(trunk.id)
        assert trunks_service.get_trunk_by_id(trunk.id) is None


class TestSearchOperations:
    """Search APIs of ``TrunksService``."""

    def test_vector_search(self, trunks_service):
        # NOTE(review): no trunk is created here, so this relies on data that
        # already exists in the configured store.
        results = trunks_service.search_by_vector(
            "各种致病因素作用引起的有效循环血容量急剧减少",
            10,
            conversation_id="1111111",
        )
        print("搜索结果:", results)
        assert len(results) > 0

    @pytest.mark.skip(reason="fulltext search test is currently disabled")
    def test_fulltext_search(self, trunks_service, test_trunk_data):
        trunks_service.create_trunk(test_trunk_data)
        results = trunks_service.fulltext_search("测试")
        assert len(results) > 0


class TestExceptionCases:
    """Inputs that ``create_trunk`` is expected to reject."""

    def test_duplicate_id(self, trunks_service, test_trunk_data):
        # Setup happens outside pytest.raises: an IntegrityError from the
        # first insert must fail the test, not satisfy it.
        trunk1 = trunks_service.create_trunk(test_trunk_data)
        duplicate = {**test_trunk_data, "id": trunk1.id}
        with pytest.raises(IntegrityError):
            trunks_service.create_trunk(duplicate)

    def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
        # A 100-element embedding is assumed to be the wrong dimension for the
        # model's vector column -- TODO confirm the expected dimension.
        invalid_data = {**test_trunk_data, "embedding": [0.1] * 100}
        with pytest.raises(ValueError):
            trunks_service.create_trunk(invalid_data)


@pytest.fixture
def trunk_factory():
    """Provide a factory that builds trunk payload dicts with overridable fields."""

    class TrunkFactory:
        @staticmethod
        def create(**overrides):
            """Return the default payload updated with ``overrides``."""
            defaults = {
                "content": "工厂内容",
                "file_path": "factory_path.pdf",
                "type": "default",
            }
            return {**defaults, **overrides}

    return TrunkFactory()


class TestBatchCreateFromDirectory:
    """Bulk-import ``*_split_*.txt`` files from a directory tree."""

    def test_batch_create_from_directory(self, trunks_service):
        # The original signature also requested a `test_data_dir` fixture that
        # is not defined in this file and was never used; it has been dropped.
        base_path = Path(os.environ.get(_FILES_DIR_ENV_VAR, _DEFAULT_FILES_DIR))
        if not base_path.is_dir():
            pytest.skip(
                f"source directory {base_path} not found "
                f"(set {_FILES_DIR_ENV_VAR} to override)"
            )

        # Walk the directory tree and create one trunk per split text file.
        created_ids = []
        for txt_path in base_path.glob("**/*_split_*.txt"):
            relative_path = txt_path.relative_to(base_path.parent.parent)
            trunk_data = {
                "content": txt_path.read_text(encoding="utf-8"),
                # Store forward-slash paths regardless of the host OS.
                "file_path": relative_path.as_posix(),
            }
            trunk = trunks_service.create_trunk(trunk_data)
            created_ids.append(trunk.id)

        # Without this the verification loop below would pass vacuously.
        assert created_ids, f"no *_split_*.txt files found under {base_path}"

        # Verify every created record round-trips through the service.
        for trunk_id in created_ids:
            db_trunk = trunks_service.get_trunk_by_id(trunk_id)
            assert db_trunk is not None
            assert ".txt" in db_trunk.file_path
            assert "_split_" in db_trunk.file_path
            assert len(db_trunk.content) > 0