test_trunks_service.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. import regex
  2. from pathlib import Path
  3. import pytest
  4. from service.trunks_service import TrunksService
  5. from model.trunks_model import Trunks
  6. from sqlalchemy.exc import IntegrityError
  7. @pytest.fixture(scope="module")
  8. def trunks_service():
  9. return TrunksService()
  10. @pytest.fixture
  11. def test_trunk_data():
  12. return {
  13. "content": """测试""",
  14. "file_path": "test_path.pdf",
  15. "type": "default"
  16. }
  17. class TestTrunksServiceCRUD:
  18. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  19. # 测试创建和查询
  20. created = trunks_service.create_trunk(test_trunk_data)
  21. assert created.id is not None
  22. def test_update_trunk(self, trunks_service, test_trunk_data):
  23. trunk = trunks_service.create_trunk(test_trunk_data)
  24. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  25. assert updated.content == "更新内容"
  26. def test_delete_trunk(self, trunks_service, test_trunk_data):
  27. trunk = trunks_service.create_trunk(test_trunk_data)
  28. assert trunks_service.delete_trunk(trunk.id)
  29. assert trunks_service.get_trunk_by_id(trunk.id) is None
  30. class TestSearchOperations:
  31. def test_vector_search2(self, trunks_service):
  32. page = 1
  33. limit = 100
  34. file_path = '急诊医学(第2版'
  35. while True:
  36. results = trunks_service.paginated_search_by_type_and_filepath(
  37. {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
  38. if not results['data']:
  39. break
  40. for record in results['data']:
  41. print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
  42. if record['type'] != 'trunk' or file_path not in record['file_path']:
  43. print('出现异常数据')
  44. break
  45. page_no = self.get_page_no(record['content'], trunks_service, file_path)
  46. if page_no is None:
  47. print(f"{record['id']}找到page_no: {page_no}")
  48. continue
  49. trunks_service.update_trunk(record['id'], {'page_no': page_no})
  50. page += 1
  51. def test_vector_search(self, trunks_service):
  52. page = 1
  53. limit = 100
  54. file_path='trunk2'
  55. while True:
  56. results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
  57. if not results['data']:
  58. break
  59. for record in results['data']:
  60. print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
  61. if record['type'] != 'trunk' or file_path not in record['file_path']:
  62. print('出现异常数据')
  63. break
  64. page_no = self.get_page_no(record['content'],trunks_service,file_path)
  65. if page_no is None:
  66. print(f"{record['id']}找到page_no: {page_no}")
  67. continue
  68. trunks_service.update_trunk(record['id'], {'page_no': page_no})
  69. page += 1
  70. def get_page_no(self, text: str, trunks_service,file_path:str) -> int:
  71. results = trunks_service.search_by_vector(text,1000,type='page',file_path=file_path,conversation_id="1111111aaaa")
  72. sentences = self.split_text(text)
  73. count = 0
  74. for r in results:
  75. #将r["content"]的所有空白字符去掉
  76. content = regex.sub(r'[^\w\d\p{L}]', '', r["content"])
  77. count+=1
  78. match_count = 0
  79. length = len(sentences)/2
  80. for sentence in sentences:
  81. sentence = regex.sub(r'[^\w\d\p{L}]', '', sentence)
  82. if sentence in content:
  83. match_count += 1
  84. if match_count >= 2:
  85. return r["page_no"]
  86. def test_match_trunk(self,trunks_service) -> int:
  87. must_matchs = ['心肌梗死']
  88. keywords = [ '概述']
  89. text = '''- 主要病因:
  90. 1. 冠状动脉粥样硬化(占90%以上)
  91. 2. 冠状动脉栓塞(如房颤血栓脱落)
  92. 3. 冠状动脉痉挛(可卡因滥用等)
  93. - 危险因素:
  94. 1. 吸烟(RR=2.87)
  95. 2. 高血压(RR=2.50)
  96. 3. LDL-C≥190mg/dL(RR=4.48)
  97. - 遗传因素:
  98. 家族性高胆固醇血症(OMIM#143890)'''
  99. text = regex.sub(r'[^\w\d\p{L}]', '', text)
  100. results = trunks_service.search_by_vector(text,1000,distance=0.72,type='trunk')
  101. print(f"原结果: {results[0]["meta_header"]}")
  102. print(results[0]["content"])
  103. max_match_count = 0
  104. best_match = None
  105. for r in results:
  106. if all(must_match in r["content"] or must_match in r["meta_header"] for must_match in must_matchs):
  107. match_count = sum(keyword in r["content"] for keyword in keywords)
  108. if match_count > max_match_count:
  109. max_match_count = match_count
  110. best_match = r
  111. elif best_match is None and max_match_count == 0:
  112. best_match = r
  113. if best_match:
  114. print(f"最佳匹配: {best_match["title"]}")
  115. print(best_match["content"])
  116. return best_match
  117. def split_text(self, text):
  118. """将文本分割成句子"""
  119. print(text)
  120. # 使用常见的标点符号作为分隔符
  121. delimiters = ['!', '?', '。', '!', '?', '\n', ';', '。', ';']
  122. sentences = [text]
  123. for delimiter in delimiters:
  124. new_sentences = []
  125. for sentence in sentences:
  126. parts = sentence.split(delimiter)
  127. new_sentences.extend([part + delimiter if i < len(parts) - 1 else part for i, part in enumerate(parts)])
  128. sentences = [s.strip() for s in new_sentences if s.strip()]
  129. # 合并短句子
  130. merged_sentences = []
  131. buffer = ""
  132. for sentence in sentences:
  133. buffer += " " + sentence if buffer else sentence
  134. if len(buffer) >= 10:
  135. merged_sentences.append(buffer)
  136. buffer = ""
  137. if buffer:
  138. merged_sentences.append(buffer)
  139. # 打印最终句子
  140. for i, sentence in enumerate(merged_sentences):
  141. print(f"句子{i+1}: {sentence.replace(" ","").replace("\u2003", "").replace("\u2002", "").replace("\u2009", "").replace("\n", "").replace("\r", "")}")
  142. return merged_sentences
  143. class TestExceptionCases:
  144. def test_duplicate_id(self, trunks_service, test_trunk_data):
  145. with pytest.raises(IntegrityError):
  146. trunk1 = trunks_service.create_trunk(test_trunk_data)
  147. test_trunk_data["id"] = trunk1.id
  148. trunks_service.create_trunk(test_trunk_data)
  149. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  150. with pytest.raises(ValueError):
  151. invalid_data = test_trunk_data.copy()
  152. invalid_data["embedding"] = [0.1]*100
  153. trunks_service.create_trunk(invalid_data)
  154. @pytest.fixture
  155. def trunk_factory():
  156. class TrunkFactory:
  157. @staticmethod
  158. def create(**overrides):
  159. defaults = {
  160. "content": "工厂内容",
  161. "file_path": "factory_path.pdf",
  162. "type": "default"
  163. }
  164. return {**defaults, **overrides}
  165. return TrunkFactory()
  166. class TestBatchCreateFromDirectory:
  167. def test_batch_create_from_directory(self, trunks_service):
  168. # 使用现有目录路径
  169. base_path = Path(r'E:\project\vscode\《急诊医学(第2版)》')
  170. # 遍历目录并创建trunk
  171. created_ids = []
  172. for txt_path in base_path.glob('**/*_split_*.txt'):
  173. relative_path = txt_path.relative_to(base_path.parent.parent)
  174. with open(txt_path, 'r', encoding='utf-8') as f:
  175. trunk_data = {
  176. "content": f.read(),
  177. "file_path": str(relative_path).replace('\\', '/')
  178. }
  179. trunk = trunks_service.create_trunk(trunk_data)
  180. created_ids.append(trunk.id)
  181. # 验证数据库记录
  182. for trunk_id in created_ids:
  183. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  184. assert db_trunk is not None
  185. assert ".txt" in db_trunk.file_path
  186. assert "_split_" in db_trunk.file_path
  187. assert len(db_trunk.content) > 0