123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- import regex
- from pathlib import Path
- import pytest
- from service.trunks_service import TrunksService
- from model.trunks_model import Trunks
- from sqlalchemy.exc import IntegrityError
@pytest.fixture(scope="module")
def trunks_service():
    """Provide one ``TrunksService`` instance shared by every test in this module."""
    service = TrunksService()
    return service
@pytest.fixture
def test_trunk_data():
    """Return a minimal trunk payload; a new dict is built for each test."""
    payload = dict(
        content="测试",
        file_path="test_path.pdf",
        type="default",
    )
    return payload
class TestTrunksServiceCRUD:
    """Create / read / update / delete round trips through ``TrunksService``."""

    def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
        """A created trunk gets an id and can be read back with the same content."""
        created = trunks_service.create_trunk(test_trunk_data)
        try:
            assert created.id is not None
            fetched = trunks_service.get_trunk_by_id(created.id)
            assert fetched is not None
            assert fetched.content == test_trunk_data["content"]
        finally:
            # Remove the row so repeated runs don't accumulate test data.
            if created.id is not None:
                trunks_service.delete_trunk(created.id)

    def test_update_trunk(self, trunks_service, test_trunk_data):
        """``update_trunk`` returns the updated trunk and the change is persisted."""
        trunk = trunks_service.create_trunk(test_trunk_data)
        try:
            updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
            assert updated.content == "更新内容"
            # Re-read to make sure the update reached storage, not just the returned object.
            assert trunks_service.get_trunk_by_id(trunk.id).content == "更新内容"
        finally:
            trunks_service.delete_trunk(trunk.id)

    def test_delete_trunk(self, trunks_service, test_trunk_data):
        """``delete_trunk`` reports success and the trunk is no longer retrievable."""
        trunk = trunks_service.create_trunk(test_trunk_data)
        assert trunks_service.delete_trunk(trunk.id)
        assert trunks_service.get_trunk_by_id(trunk.id) is None
class TestSearchOperations:
    """Back-fill ``page_no`` on trunk records by matching them against page records.

    NOTE(review): the ``test_vector_search*`` methods behave like data-migration
    scripts: apart from a consistency check on the returned records they make no
    assertions, and they write ``page_no`` back through ``update_trunk``.
    """

    # Strips every character that is not a word character / letter, so that
    # whitespace and punctuation differences are ignored when comparing text.
    _NON_WORD_RE = regex.compile(r'[^\w\d\p{L}]')
    # Characters deleted from sentences before they are printed for debugging.
    _PRINT_STRIP_TABLE = str.maketrans('', '', ' \u2003\u2002\u2009\n\r')

    def test_vector_search2(self, trunks_service):
        """Back-fill page numbers for trunks whose file_path contains '急诊医学(第2版'."""
        self._backfill_page_no(trunks_service, '急诊医学(第2版')

    def test_vector_search(self, trunks_service):
        """Back-fill page numbers for trunks whose file_path contains 'trunk2'."""
        self._backfill_page_no(trunks_service, 'trunk2')

    def _backfill_page_no(self, trunks_service, file_path: str, limit: int = 100) -> None:
        """Page through 'trunk' records for *file_path* and store a matched ``page_no``.

        Records are fetched *limit* at a time with
        ``paginated_search_by_type_and_filepath`` until an empty page comes back.
        Records whose page cannot be determined are reported and left unchanged.

        Raises:
            AssertionError: if the search returns a record that is not of type
                'trunk' or whose file_path does not contain *file_path*.
        """
        page = 1
        while True:
            results = trunks_service.paginated_search_by_type_and_filepath(
                {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
            records = results['data']
            if not records:
                break
            for record in records:
                print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
                # The query already filters on type and file_path, so a mismatch
                # means the search itself is wrong: stop rather than silently
                # skipping the remainder of the page and carrying on.
                assert record['type'] == 'trunk' and file_path in record['file_path'], (
                    f"出现异常数据: {record['id']}")
                page_no = self.get_page_no(record['content'], trunks_service, file_path)
                if page_no is None:
                    print(f"{record['id']}未找到page_no")
                    continue
                trunks_service.update_trunk(record['id'], {'page_no': page_no})
            page += 1

    def get_page_no(self, text: str, trunks_service, file_path: str,
                    min_matches: int = 2) -> int | None:
        """Return the ``page_no`` of the first page record that contains *text*.

        Candidate pages come from ``trunks_service.search_by_vector`` (up to 1000
        hits of type 'page' for the same *file_path*) and are checked in the order
        returned. *text* is split with :meth:`split_text`; a page matches when at
        least *min_matches* of those sentences occur in its content, after
        non-word characters have been stripped from both sides.

        Returns:
            The matching record's ``page_no``, or None when no page matches.
        """
        results = trunks_service.search_by_vector(
            text, 1000, type='page', file_path=file_path, conversation_id="1111111aaaa")
        # Normalise the sentences once instead of once per candidate page, and
        # drop those that normalise to '': the empty string is a substring of
        # every content and would otherwise count as a match on any page.
        normalized_sentences = [
            normalized
            for normalized in (self._NON_WORD_RE.sub('', s) for s in self.split_text(text))
            if normalized
        ]
        for r in results:
            content = self._NON_WORD_RE.sub('', r["content"])
            match_count = sum(1 for s in normalized_sentences if s in content)
            if match_count >= min_matches:
                return r["page_no"]
        return None

    def split_text(self, text):
        """Split *text* into sentences and merge short fragments.

        The text is split after every delimiter below (the delimiter stays on
        the preceding piece); pieces are stripped and empty ones dropped.
        Consecutive pieces are then joined with a single space until the
        accumulated buffer is at least 10 characters long; a shorter trailing
        buffer is kept as its own sentence. Both the input and each resulting
        sentence (with whitespace removed) are printed for debugging.
        """
        print(text)
        # Sentence terminators in both half-width and full-width form, since
        # Chinese text normally uses the full-width variants.
        delimiters = ['!', '?', ';', '！', '？', '；', '。', '\n']
        sentences = [text]
        for delimiter in delimiters:
            pieces = []
            for sentence in sentences:
                parts = sentence.split(delimiter)
                # Re-attach the delimiter to every piece except the last.
                pieces.extend(part + delimiter for part in parts[:-1])
                pieces.append(parts[-1])
            sentences = [s.strip() for s in pieces if s.strip()]

        # Merge short pieces until the buffer reaches 10 characters.
        merged_sentences = []
        buffer = ""
        for sentence in sentences:
            buffer = f"{buffer} {sentence}" if buffer else sentence
            if len(buffer) >= 10:
                merged_sentences.append(buffer)
                buffer = ""
        if buffer:
            merged_sentences.append(buffer)

        for i, sentence in enumerate(merged_sentences, start=1):
            print(f"句子{i}: {sentence.translate(self._PRINT_STRIP_TABLE)}")

        return merged_sentences
class TestExceptionCases:
    """Error paths of ``TrunksService.create_trunk``."""

    def test_duplicate_id(self, trunks_service, test_trunk_data):
        """Creating a second trunk with an already-used id raises IntegrityError."""
        # Create the first trunk outside pytest.raises: if this call itself
        # raised IntegrityError, the test would otherwise pass for the wrong reason.
        trunk1 = trunks_service.create_trunk(test_trunk_data)
        # Copy instead of mutating the fixture dict in place.
        duplicate = {**test_trunk_data, "id": trunk1.id}
        with pytest.raises(IntegrityError):
            trunks_service.create_trunk(duplicate)

    def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
        """A 100-value embedding is expected to be rejected with ValueError.

        NOTE(review): presumes 100 differs from the embedding size the service
        requires — confirm against TrunksService.
        """
        invalid_data = {**test_trunk_data, "embedding": [0.1] * 100}
        with pytest.raises(ValueError):
            trunks_service.create_trunk(invalid_data)
@pytest.fixture
def trunk_factory():
    """Provide a factory for trunk payload dicts.

    ``trunk_factory.create(**overrides)`` returns a new dict holding the default
    payload below, with any keyword overrides applied on top of it.
    """

    class TrunkFactory:
        @staticmethod
        def create(**overrides):
            payload = dict(
                content="工厂内容",
                file_path="factory_path.pdf",
                type="default",
            )
            payload.update(overrides)
            return payload

    factory = TrunkFactory()
    return factory
class TestBatchCreateFromDirectory:
    """Import every ``*_split_*.txt`` file under a local directory as a trunk."""

    def test_batch_create_from_directory(self, trunks_service):
        """Create one trunk per split text file and verify each stored record.

        ``file_path`` is stored relative to the grandparent of the source
        directory, with '/' separators.
        """
        # NOTE(review): machine-specific absolute Windows path; the test is
        # skipped (rather than passing vacuously) when it is not available.
        base_path = Path(r'E:\project\vscode\《急诊医学(第2版)》')
        if not base_path.is_dir():
            pytest.skip(f"source directory not found: {base_path}")

        # Sorted so records are created in a deterministic order.
        txt_paths = sorted(base_path.glob('**/*_split_*.txt'))
        assert txt_paths, f"no *_split_*.txt files found under {base_path}"

        created_ids = []
        for txt_path in txt_paths:
            relative_path = txt_path.relative_to(base_path.parent.parent)
            # Read the file fully before calling the service so the file
            # handle is not held open during the database write.
            trunk = trunks_service.create_trunk({
                "content": txt_path.read_text(encoding='utf-8'),
                "file_path": relative_path.as_posix(),
            })
            created_ids.append(trunk.id)

        # Verify the database records.
        for trunk_id in created_ids:
            db_trunk = trunks_service.get_trunk_by_id(trunk_id)
            assert db_trunk is not None
            assert ".txt" in db_trunk.file_path
            assert "_split_" in db_trunk.file_path
            assert len(db_trunk.content) > 0
|