import regex
from pathlib import Path
import pytest
from service.trunks_service import TrunksService
from model.trunks_model import Trunks
from sqlalchemy.exc import IntegrityError


@pytest.fixture(scope="module")
def trunks_service():
    return TrunksService()


@pytest.fixture
def test_trunk_data():
    return {
        "content": """测试""",
        "file_path": "test_path.pdf",
        "type": "default"
    }


class TestTrunksServiceCRUD:
    def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
        # Create a trunk and verify it was persisted with an id.
        created = trunks_service.create_trunk(test_trunk_data)
        assert created.id is not None

    def test_update_trunk(self, trunks_service, test_trunk_data):
        trunk = trunks_service.create_trunk(test_trunk_data)
        updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
        assert updated.content == "更新内容"

    def test_delete_trunk(self, trunks_service, test_trunk_data):
        trunk = trunks_service.create_trunk(test_trunk_data)
        assert trunks_service.delete_trunk(trunk.id)
        assert trunks_service.get_trunk_by_id(trunk.id) is None


class TestSearchOperations:
    def test_vector_search2(self, trunks_service):
        page = 1
        limit = 100
        file_path = '急诊医学(第2版'
        while True:
            results = trunks_service.paginated_search_by_type_and_filepath(
                {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
            if not results['data']:
                break
            for record in results['data']:
                print(f"{record['id']} {record['type']} {record['title']} {record['file_path']}")
                if record['type'] != 'trunk' or file_path not in record['file_path']:
                    print('Unexpected record type or file_path, skipping the rest of this page')
                    break
                page_no = self.get_page_no(record['content'], trunks_service, file_path)
                if page_no is None:
                    print(f"{record['id']} no page_no found")
                    continue
                trunks_service.update_trunk(record['id'], {'page_no': page_no})
            page += 1

    def test_vector_search(self, trunks_service):
        page = 1
        limit = 100
        file_path = 'trunk2'
        while True:
            results = trunks_service.paginated_search_by_type_and_filepath(
                {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
            if not results['data']:
                break
            for record in results['data']:
                print(f"{record['id']} {record['type']} {record['title']} {record['file_path']}")
                if record['type'] != 'trunk' or file_path not in record['file_path']:
                    print('Unexpected record type or file_path, skipping the rest of this page')
                    break
                page_no = self.get_page_no(record['content'], trunks_service, file_path)
                if page_no is None:
                    print(f"{record['id']} no page_no found")
                    continue
                trunks_service.update_trunk(record['id'], {'page_no': page_no})
            page += 1

    def get_page_no(self, text: str, trunks_service, file_path: str):
        results = trunks_service.search_by_vector(
            text, 1000, type='page', file_path=file_path, conversation_id="1111111aaaa")
        sentences = self.split_text(text)
        for r in results:
            # Strip everything except word characters and letters so punctuation and
            # whitespace differences do not break substring matching.
            content = regex.sub(r'[^\w\d\p{L}]', '', r["content"])
            match_count = 0
            for sentence in sentences:
                sentence = regex.sub(r'[^\w\d\p{L}]', '', sentence)
                if sentence in content:
                    match_count += 1
                    if match_count >= 2:
                        return r["page_no"]
        return None

    def test_match_trunk(self, trunks_service):
        must_matches = ['心肌梗死']
        keywords = ['概述']
        text = '''- 主要病因: 1. 冠状动脉粥样硬化(占90%以上) 2. 冠状动脉栓塞(如房颤血栓脱落) 3. 冠状动脉痉挛(可卡因滥用等) - 危险因素: 1. 吸烟(RR=2.87) 2. 高血压(RR=2.50) 3. 
LDL-C≥190mg/dL(RR=4.48) - 遗传因素: 家族性高胆固醇血症(OMIM#143890)'''
        text = regex.sub(r'[^\w\d\p{L}]', '', text)
        results = trunks_service.search_by_vector(text, 1000, distance=0.72, type='trunk')
        print(f"Top result meta_header: {results[0]['meta_header']}")
        print(results[0]["content"])
        max_match_count = 0
        best_match = None
        for r in results:
            if all(must_match in r["content"] or must_match in r["meta_header"] for must_match in must_matches):
                match_count = sum(keyword in r["content"] for keyword in keywords)
                if match_count > max_match_count:
                    max_match_count = match_count
                    best_match = r
            elif best_match is None and max_match_count == 0:
                best_match = r
        if best_match:
            print(f"Best match: {best_match['title']}")
            print(best_match["content"])
        return best_match

    def split_text(self, text):
        """Split the text into sentences."""
        print(text)
        # Split on common sentence-ending punctuation (ASCII and full-width variants).
        delimiters = ['!', '?', '。', '!', '?', '\n', ';', ';']
        sentences = [text]
        for delimiter in delimiters:
            new_sentences = []
            for sentence in sentences:
                parts = sentence.split(delimiter)
                new_sentences.extend(
                    [part + delimiter if i < len(parts) - 1 else part for i, part in enumerate(parts)])
            sentences = [s.strip() for s in new_sentences if s.strip()]

        # Merge fragments shorter than 10 characters into the following ones.
        merged_sentences = []
        buffer = ""
        for sentence in sentences:
            buffer += " " + sentence if buffer else sentence
            if len(buffer) >= 10:
                merged_sentences.append(buffer)
                buffer = ""
        if buffer:
            merged_sentences.append(buffer)

        # Print the final sentences with spaces and space-like characters removed.
        for i, sentence in enumerate(merged_sentences):
            cleaned = (sentence.replace(" ", "").replace("\u2003", "").replace("\u2002", "")
                       .replace("\u2009", "").replace("\n", "").replace("\r", ""))
            print(f"Sentence {i + 1}: {cleaned}")
        return merged_sentences


class TestExceptionCases:
    def test_duplicate_id(self, trunks_service, test_trunk_data):
        with pytest.raises(IntegrityError):
            trunk1 = trunks_service.create_trunk(test_trunk_data)
            test_trunk_data["id"] = trunk1.id
            trunks_service.create_trunk(test_trunk_data)

    def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
        with pytest.raises(ValueError):
            invalid_data = test_trunk_data.copy()
            invalid_data["embedding"] = [0.1] * 100
            trunks_service.create_trunk(invalid_data)


@pytest.fixture
def trunk_factory():
    class TrunkFactory:
        @staticmethod
        def create(**overrides):
            defaults = {
                "content": "工厂内容",
                "file_path": "factory_path.pdf",
                "type": "default"
            }
            return {**defaults, **overrides}
    return TrunkFactory()


class TestBatchCreateFromDirectory:
    def test_batch_create_from_directory(self, trunks_service):
        # Directory that already contains the pre-split text files.
        base_path = Path(r'E:\project\vscode\《急诊医学(第2版)》')

        # Walk the directory and create one trunk per split file.
        created_ids = []
        for txt_path in base_path.glob('**/*_split_*.txt'):
            relative_path = txt_path.relative_to(base_path.parent.parent)
            with open(txt_path, 'r', encoding='utf-8') as f:
                trunk_data = {
                    "content": f.read(),
                    "file_path": str(relative_path).replace('\\', '/')
                }
                trunk = trunks_service.create_trunk(trunk_data)
                created_ids.append(trunk.id)

        # Verify the database records.
        for trunk_id in created_ids:
            db_trunk = trunks_service.get_trunk_by_id(trunk_id)
            assert db_trunk is not None
            assert ".txt" in db_trunk.file_path
            assert "_split_" in db_trunk.file_path
            assert len(db_trunk.content) > 0
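

# Hedged usage sketch for the trunk_factory fixture defined above, which no existing
# test exercises. It assumes TrunksService.create_trunk accepts the factory's plain
# dict exactly like the literal dicts used in the CRUD tests; the class and test
# names below are illustrative additions, not part of the original suite.
class TestTrunkFactoryUsage:
    def test_create_with_factory_overrides(self, trunks_service, trunk_factory):
        # Build trunk data from the factory defaults, overriding only the content.
        trunk_data = trunk_factory.create(content="factory override content")
        created = trunks_service.create_trunk(trunk_data)
        assert created.id is not None
        assert trunks_service.get_trunk_by_id(created.id) is not None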
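

# Hedged sketch of split_text's behaviour: text is split on the punctuation
# delimiters, then fragments shorter than 10 characters are merged forward, so only
# the final merged sentence may end up shorter than 10 characters. The sample string
# and the test class below are illustrative, not taken from the original suite.
class TestSplitTextHeuristic:
    def test_short_fragments_are_merged(self):
        ops = TestSearchOperations()
        sentences = ops.split_text("第一句比较长,应该独立成句。短句。也短。")
        # Every merged sentence except possibly the last should reach the threshold.
        assert all(len(s) >= 10 for s in sentences[:-1])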