# test_trunks_service.py
  1. import regex
  2. from pathlib import Path
  3. import pytest
  4. from service.trunks_service import TrunksService
  5. from model.trunks_model import Trunks
  6. from sqlalchemy.exc import IntegrityError
  7. @pytest.fixture(scope="module")
  8. def trunks_service():
  9. return TrunksService()
  10. @pytest.fixture
  11. def test_trunk_data():
  12. return {
  13. "content": """测试""",
  14. "file_path": "test_path.pdf",
  15. "type": "default"
  16. }
  17. class TestTrunksServiceCRUD:
  18. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  19. # 测试创建和查询
  20. created = trunks_service.create_trunk(test_trunk_data)
  21. assert created.id is not None
  22. def test_update_trunk(self, trunks_service, test_trunk_data):
  23. trunk = trunks_service.create_trunk(test_trunk_data)
  24. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  25. assert updated.content == "更新内容"
  26. def test_delete_trunk(self, trunks_service, test_trunk_data):
  27. trunk = trunks_service.create_trunk(test_trunk_data)
  28. assert trunks_service.delete_trunk(trunk.id)
  29. assert trunks_service.get_trunk_by_id(trunk.id) is None
  30. class TestSearchOperations:
  31. def test_vector_search2(self, trunks_service):
  32. page = 1
  33. limit = 100
  34. file_path = '急诊医学(第2版'
  35. while True:
  36. results = trunks_service.paginated_search_by_type_and_filepath(
  37. {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
  38. if not results['data']:
  39. break
  40. for record in results['data']:
  41. print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
  42. if record['type'] != 'trunk' or file_path not in record['file_path']:
  43. print('出现异常数据')
  44. break
  45. page_no = self.get_page_no(record['content'], trunks_service, file_path)
  46. if page_no is None:
  47. print(f"{record['id']}找到page_no: {page_no}")
  48. continue
  49. trunks_service.update_trunk(record['id'], {'page_no': page_no})
  50. page += 1
  51. def test_vector_search(self, trunks_service):
  52. page = 1
  53. limit = 100
  54. file_path='trunk2'
  55. while True:
  56. results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
  57. if not results['data']:
  58. break
  59. for record in results['data']:
  60. print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
  61. if record['type'] != 'trunk' or file_path not in record['file_path']:
  62. print('出现异常数据')
  63. break
  64. page_no = self.get_page_no(record['content'],trunks_service,file_path)
  65. if page_no is None:
  66. print(f"{record['id']}找到page_no: {page_no}")
  67. continue
  68. trunks_service.update_trunk(record['id'], {'page_no': page_no})
  69. page += 1
  70. def get_page_no(self, text: str, trunks_service,file_path:str) -> int:
  71. results = trunks_service.search_by_vector(text,1000,type='page',file_path=file_path,conversation_id="1111111aaaa")
  72. sentences = self.split_text(text)
  73. count = 0
  74. for r in results:
  75. #将r["content"]的所有空白字符去掉
  76. content = regex.sub(r'[^\w\d\p{L}]', '', r["content"])
  77. count+=1
  78. match_count = 0
  79. length = len(sentences)/2
  80. for sentence in sentences:
  81. sentence = regex.sub(r'[^\w\d\p{L}]', '', sentence)
  82. if sentence in content:
  83. match_count += 1
  84. if match_count >= 2:
  85. return r["page_no"]
  86. def split_text(self, text):
  87. """将文本分割成句子"""
  88. print(text)
  89. # 使用常见的标点符号作为分隔符
  90. delimiters = ['!', '?', '。', '!', '?', '\n', ';', '。', ';']
  91. sentences = [text]
  92. for delimiter in delimiters:
  93. new_sentences = []
  94. for sentence in sentences:
  95. parts = sentence.split(delimiter)
  96. new_sentences.extend([part + delimiter if i < len(parts) - 1 else part for i, part in enumerate(parts)])
  97. sentences = [s.strip() for s in new_sentences if s.strip()]
  98. # 合并短句子
  99. merged_sentences = []
  100. buffer = ""
  101. for sentence in sentences:
  102. buffer += " " + sentence if buffer else sentence
  103. if len(buffer) >= 10:
  104. merged_sentences.append(buffer)
  105. buffer = ""
  106. if buffer:
  107. merged_sentences.append(buffer)
  108. # 打印最终句子
  109. for i, sentence in enumerate(merged_sentences):
  110. print(f"句子{i+1}: {sentence.replace(" ","").replace("\u2003", "").replace("\u2002", "").replace("\u2009", "").replace("\n", "").replace("\r", "")}")
  111. return merged_sentences
  112. class TestExceptionCases:
  113. def test_duplicate_id(self, trunks_service, test_trunk_data):
  114. with pytest.raises(IntegrityError):
  115. trunk1 = trunks_service.create_trunk(test_trunk_data)
  116. test_trunk_data["id"] = trunk1.id
  117. trunks_service.create_trunk(test_trunk_data)
  118. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  119. with pytest.raises(ValueError):
  120. invalid_data = test_trunk_data.copy()
  121. invalid_data["embedding"] = [0.1]*100
  122. trunks_service.create_trunk(invalid_data)
  123. @pytest.fixture
  124. def trunk_factory():
  125. class TrunkFactory:
  126. @staticmethod
  127. def create(**overrides):
  128. defaults = {
  129. "content": "工厂内容",
  130. "file_path": "factory_path.pdf",
  131. "type": "default"
  132. }
  133. return {**defaults, **overrides}
  134. return TrunkFactory()
  135. class TestBatchCreateFromDirectory:
  136. def test_batch_create_from_directory(self, trunks_service):
  137. # 使用现有目录路径
  138. base_path = Path(r'E:\project\vscode\《急诊医学(第2版)》')
  139. # 遍历目录并创建trunk
  140. created_ids = []
  141. for txt_path in base_path.glob('**/*_split_*.txt'):
  142. relative_path = txt_path.relative_to(base_path.parent.parent)
  143. with open(txt_path, 'r', encoding='utf-8') as f:
  144. trunk_data = {
  145. "content": f.read(),
  146. "file_path": str(relative_path).replace('\\', '/')
  147. }
  148. trunk = trunks_service.create_trunk(trunk_data)
  149. created_ids.append(trunk.id)
  150. # 验证数据库记录
  151. for trunk_id in created_ids:
  152. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  153. assert db_trunk is not None
  154. assert ".txt" in db_trunk.file_path
  155. assert "_split_" in db_trunk.file_path
  156. assert len(db_trunk.content) > 0