test_trunks_service.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. import regex
  2. from pathlib import Path
  3. import pytest
  4. from service.kg_edge_service import KGEdgeService
  5. from service.kg_node_service import KGNodeService
  6. from service.kg_prop_service import KGPropService
  7. from service.trunks_service import TrunksService
  8. from model.trunks_model import Trunks
  9. from sqlalchemy.exc import IntegrityError
  10. from utils import DeepseekUtil
  11. from db.session import get_db
  12. @pytest.fixture(scope="module")
  13. def trunks_service():
  14. return TrunksService()
  15. @pytest.fixture
  16. def test_trunk_data():
  17. return {
  18. "content": """测试""",
  19. "file_path": "test_path.pdf",
  20. "type": "default"
  21. }
  22. class TestTrunksServiceCRUD:
  23. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  24. # 测试创建和查询
  25. created = trunks_service.create_trunk(test_trunk_data)
  26. assert created.id is not None
  27. def test_update_trunk(self, trunks_service, test_trunk_data):
  28. trunk = trunks_service.create_trunk(test_trunk_data)
  29. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  30. assert updated.content == "更新内容"
  31. def test_delete_trunk(self, trunks_service, test_trunk_data):
  32. trunk = trunks_service.create_trunk(test_trunk_data)
  33. assert trunks_service.delete_trunk(trunk.id)
  34. assert trunks_service.get_trunk_by_id(trunk.id) is None
  35. class TestSearchOperations:
  36. def test_vector_search2(self, trunks_service):
  37. page = 1
  38. limit = 100
  39. file_path = '急诊医学(第2版'
  40. while True:
  41. results = trunks_service.paginated_search_by_type_and_filepath(
  42. {'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
  43. if not results['data']:
  44. break
  45. for record in results['data']:
  46. print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
  47. if record['type'] != 'trunk' or file_path not in record['file_path']:
  48. print('出现异常数据')
  49. break
  50. page_no = self.get_page_no(record['content'], trunks_service, file_path)
  51. if page_no is None:
  52. print(f"{record['id']}找到page_no: {page_no}")
  53. continue
  54. trunks_service.update_trunk(record['id'], {'page_no': page_no})
  55. page += 1
  56. def test_vector_search(self, trunks_service):
  57. page = 1
  58. limit = 100
  59. file_path='trunk2'
  60. while True:
  61. results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk', 'file_path': file_path})
  62. if not results['data']:
  63. break
  64. for record in results['data']:
  65. print(f"{record['id']}{record['type']}{record['title']}{record['file_path']}")
  66. if record['type'] != 'trunk' or file_path not in record['file_path']:
  67. print('出现异常数据')
  68. break
  69. page_no = self.get_page_no(record['content'],trunks_service,file_path)
  70. if page_no is None:
  71. print(f"{record['id']}找到page_no: {page_no}")
  72. continue
  73. trunks_service.update_trunk(record['id'], {'page_no': page_no})
  74. page += 1
  75. def test_trunk_search(self, trunks_service):
  76. page = 1
  77. limit = 100
  78. while True:
  79. results = trunks_service.paginated_search_by_type_and_filepath({'pageNo': page, 'limit': limit, 'type': 'trunk'})
  80. if not results['data']:
  81. break
  82. for record in results['data']:
  83. prompt = '''
  84. ##角色任务
  85. 你是一个医学专家,且精通中文文本的实体和关系标注(自然语言处理),用户将提供两个医学文本:“前置医学文本”和“任务医学文本”,你的任务是从用户输入的“任务医学文本”中抽取所有的“关系型”知识要点,构建符合「头部实体---关系---尾部实体」的三元组集合。
  86. ##请严格依次执行以下步骤及其细节要求:
  87. 第一步,通读“前置医学文本”和“任务医学文本”,对“任务医学文本”形成一个整体的理解,搞明白其围绕的主题是什么?主要讲的是什么?主要说了哪几块或几点内容?。此处“前置医学文本”是“任务医学文本”的上文语义背景,其有助于你对“任务医学文本”语义的理解,且在你提取三元组时,其可能作为参考或辅助信息。
  88. 第二步【实体识别】:
  89. 规则1:可参考但不限于如下医学实体类型:疾病、药品、药品剂型、症状、体格检查项目、体格检查结果、手术和操作、实验室检查套餐、实验室检查子项目、辅助检查项目、辅助检查子项目、辅助检查描述、辅助检查结果、输血类型、麻醉、科室、性别、人群、食物、其他过敏原、医疗器械及物品、给药途径、部位、护理、量表、单位、ICD10疾病类别、药品化学物质类别、药品治疗学类别、药品药理学类别、药品解剖学类别、症状类别、手术和操作类别、ICD10疾病类别根节点、科室疾病类别根节点、药品化学物质类别根节点、药品治疗学类别根节点、药品药理学类别根节点、药品解剖学类别根节点、症状类别根节点、手术和操作类别根节点、实验室检查类别根节点、辅助检查类别根节点、年龄、疾病系统分类、性质、中医疾病、中医证候、诱因、政策法规、否定词、疾病集合、药品通用名集合、药品剂型集合、症状集合、体格检查项目集合、体格检查结果集合、手术和操作集合、实验室检查套餐集合、辅助检查项目集合、辅助检查子项目集合、辅助检查描述集合、辅助检查结果集合、麻醉集合、科室集合、食物集合、其他过敏原集合、医疗器械及物品集合、部位集合、中医疾病集合、中医证候集合、诱因集合、给药途径集合、物理治疗、经典病例、历史病例检查、检查结果、手术操作、其他治疗、人群、人体结构或部位、医疗器械、食物、病理机制等。
  90. 规则2:复合实体要拆分:如“老年人和免疫功能低下者”不能作为一个实体,需拆分为:"老年人"、"免疫功能低下者"两个实体。
  91. 规则3:文字中省略了,但语义中暗含的文字内容要进行语义补全:如文本:“胸、腹部检查”中,这里的“胸”其实是指“胸部检查”,按语义补全后应该有两个实体:胸部检查、腹部检查。同理,文本:“注意有无心内膜炎、心肌炎、心包炎体征”中,应该有三个实体:内膜炎体征、心肌炎体征、心包炎体征;文本:“有无肝脏和脾脏肿大”中,应该有两个实体:肝脏肿大、脾脏肿大。
  92. 第三步【关系构建】:
  93. 规则4:可参考的“基础关系类型”如下(但不仅仅限于这些):属于(是)、包括(包含)、导致(的结果)、是由…导致(的结果)、的原因是、是…的原因、的病因是、是…的病因、基于(基础是)、是…的基础、推荐、被推荐于、区别于、相似于、关联、疾病常关联、疾病可推荐等等。
  94. 规则5:根据上下文整体语义分析并确定出两个实体之间的“基础关系类型”后,可能还不足以描述清楚两个实体间的“详细关系”或“精准的关系”,所以你要尽量构建字数更多的详细的“关系”,需要依据原文中真实的语义,在“基础关系”中增加“修饰词”、“限定条件词”等语义描述,避免关系构建太过粗糙,避免原文关系语义信息的衰减或丢失。
  95. 规则6:语义中属于该三元组关系的修饰词、限定条件词等,需要融入到该“关系”中(如:可能、少数、30%、多数、显著、轻微、手术后、治疗无效时、满足XX条件时等),如果归属于该“关系”,则需要组合到该“关系”中使得关系更丰满(如:“可能导致”、“少数由…导致”、“手术后导致”、“满足XX条件时,推荐”等等)。
  96. 规则7:“尾部实体”不能为复合实体或并排结构,必须合理的拆分为多个三元组。
  97. 规则8:强制质量检测:“头部实体---关系类型---尾部实体”中,三者必须能组合成语法通顺的一句话,不能有语病,且关系指向不能模糊或错误。如三元组1:“A---的原因是---B”中,组合成一句话“A的原因是B”没有语病,且关系明确:B是原因,A是B原因导致的结果;如三元组2:“A---是…的原因---B”中,组合成一句话“A是B的原因”没有语病,且关系明确:A是原因,B是A原因导致的结果。注意:不是所有的关系都是镜像对称的,很多关系是单向关系,请根据语义关系的方向选择正确的关系类型,否则语义会完全相反。
  98. ##输出格式要求:
  99. 对于每个三元组关系,提取以下信息:头部实体的名称(source_entity)、尾部实体的名称(target_entity)、关系(relationship_type)、头部实体的类型(source_entity_type)、尾部实体的类型(target_entity_type),最后这些信息整理成如下格式的一个“字符串”:“头部实体的名称---关系---尾部实体的名称---头部实体的类型---尾部实体的类型”,不同的三元组“字符串”之间用换行符号连接起来。
  100. ##质量红线
  101. 1.禁止出现四元组或嵌套结构
  102. 2.禁止合并多个“差异点”或“并列实体”到单个三元组
  103. 3.每个三元组必须保留原文核心逻辑
  104. ##示例:
  105. “前置医学文本”为:
  106. 第一节,自发性气胸
  107. “任务医学文本”为:
  108. "三、诊断要点
  109. 自发性气胸通过胸部 X 线片确立诊断,条件允许时,应选择直立位拍片。
  110. 1. 既往胸部 X 线检查无明显病变或有 COPD、肺结核、哮喘等肺部基础病变。
  111. 2. 突发一侧胸痛伴不同程度的胸闷、呼吸困难。患侧胸廓饱满、呼吸运动减弱,叩诊呈鼓音,肝、 肺浊音界消失,听诊呼吸音减弱,甚至消失。
  112. 3. 发病时胸部 X 线影像学检查是诊断气胸最为准确和可靠的方法。
  113. 典型自发性气胸诊断不难。继发性气胸病人可因原有基础疾病而影响诊断,因此,对临床不能用 其他原因解释或经急诊处理呼吸困难无改善者,需考虑自发性气胸的可能。因病情危重不能立即行 胸部 X 线检查时,可在胸腔积气体征最明显处进行诊断性穿刺。"
  114. 输出为:
  115. 自发性气胸---通过...确立诊断---胸部X线片---疾病---检查
  116. 自发性气胸---条件允许时推荐选择---直立位拍片---疾病---检查操作
  117. 自发性气胸---诊断要点需结合---既往胸部X线检查无明显病变---疾病---检查结果
  118. 自发性气胸---诊断要点需结合---或有肺部基础病变---疾病---检查结果
  119. 肺部基础病变---的原因包括---COPD---检查结果---疾病
  120. 肺部基础病变---的原因包括---肺结核---检查结果---疾病
  121. 肺部基础病变---的原因包括---哮喘---检查结果---疾病
  122. 自发性气胸---诊断要点需结合---突发一侧胸痛伴不同程度的胸闷
  123. 自发性气胸---诊断要点需结合---突发一侧胸痛伴不同程度的呼吸困难
  124. 突发一侧胸痛---伴发---胸闷---症状---症状
  125. 突发一侧胸痛---伴发---呼吸困难---症状---症状
  126. 自发性气胸---患侧表现---患侧胸廓饱满---疾病---体征结果
  127. 自发性气胸---患侧表现---呼吸运动减弱---疾病---体征结果
  128. 自发性气胸---患侧表现---叩诊呈鼓音---疾病---体征结果
  129. 自发性气胸---体征结果---肝脏浊音界消失---疾病---体征结果
  130. 自发性气胸---体征结果---肺浊音界消失---疾病---体征结果
  131. 自发性气胸---患侧表现---肺浊音界消失---疾病---体征结果
  132. 自发性气胸---患侧表现---听诊呼吸音减弱---疾病---体征结果
  133. 自发性气胸---患侧表现---听诊呼吸音消失---疾病---体征结果
  134. 胸部X线影像学检查---是诊断...最准确可靠方法---气胸---检查---疾病
  135. 继发性气胸病人---可因...被影响诊断---原有基础疾病---人群---疾病
  136. 临床不能用其他原因解释的呼吸困难者---需考虑可能为---自发性气胸---人群---疾病
  137. 急诊处理呼吸困难无改善者---需考虑可能为---自发性气胸---人群---疾病
  138. 病情危重不能立即行胸部X线检查者---推荐进行---诊断性穿刺---人群---检查操作
  139. 诊断性穿刺---实施部位---胸腔积气体征最明显处---检查操作---人体部位
  140. 用户输入的“前置医学文本”为:
  141. '''
  142. prompt+=record['meta_header']+'\n任务文本为:\n'
  143. prompt = prompt + record['content']
  144. llm_result = DeepseekUtil.chat(prompt)
  145. if not llm_result or not isinstance(llm_result, str):
  146. print(f"LLM返回结果无效: {type(llm_result)}")
  147. continue
  148. for line in llm_result.strip().split('\n'):
  149. try:
  150. if not line.strip():
  151. continue
  152. # 验证行格式
  153. if line.count('---') != 4:
  154. print(f"无效的三元组格式: {line}")
  155. continue
  156. # 解析结果
  157. parts = line.strip().split('---')
  158. if len(parts) != 5:
  159. print(f"解析失败,部分数量不符: {len(parts)} parts in {line}")
  160. continue
  161. start_node_name, relation_name, end_node_name, start_category, end_category = parts
  162. # 创建或获取起始节点和结束节点
  163. start_node_id = self._create_or_get_node(start_node_name, start_category)
  164. end_node_id = self._create_or_get_node(end_node_name, end_category)
  165. edgeService = KGEdgeService(next(get_db()))
  166. #edges = edgeService.get_edges_by_nodes(start_node_id, end_node_id, relation_name)
  167. #if len(edges) == 0:
  168. edge_data = {}
  169. edge_data['src_id'] = start_node_id
  170. edge_data['dest_id'] = end_node_id
  171. edge_data['name'] = relation_name
  172. edge_data['category'] = relation_name
  173. edge_data['version'] = 'unk'
  174. edge = edgeService.create_edge(edge_data)
  175. propService = KGPropService(next(get_db()))
  176. props_by_ref_id = propService.get_prop_by_ref_id(edge.id, 'trunk_ids')
  177. if props_by_ref_id:
  178. if record['id'] not in props_by_ref_id['trunk_ids']:
  179. props_by_ref_id['trunk_ids'].append(record['id'])
  180. propService.update_prop(edge.id, {'trunk_ids': props_by_ref_id['trunk_ids']})
  181. continue
  182. prop_data = {}
  183. prop_data['ref_id'] = edge.id
  184. prop_data['category'] = 2
  185. prop_data['type'] = 2
  186. prop_data['prop_name'] = 'trunk_ids'
  187. prop_data['prop_value'] = [record['id']]
  188. prop_data['prop_title'] = '切片id列表'
  189. propService.create_prop(prop_data)
  190. except Exception as e:
  191. print(f"处理行时发生异常: {str(e)}")
  192. continue
  193. page += 1
  194. def _create_or_get_node(self, node_name: str, category: str) -> int:
  195. node_service = KGNodeService(next(get_db()))
  196. node = node_service.get_node_by_name_category(node_name, category)
  197. if node:
  198. return node['id']
  199. node_data = {}
  200. node_data['name'] = node_name
  201. node_data['category'] = category
  202. node_data['version'] = 'unk'
  203. node_data['status'] = 0
  204. return node_service.create_node(node_data).id
  205. def get_page_no(self, text: str, trunks_service,file_path:str) -> int:
  206. results = trunks_service.search_by_vector(text,1000,type='page',file_path=file_path,conversation_id="1111111aaaa")
  207. sentences = self.split_text(text)
  208. count = 0
  209. for r in results:
  210. #将r["content"]的所有空白字符去掉
  211. content = regex.sub(r'[^\w\d\p{L}]', '', r["content"])
  212. count+=1
  213. match_count = 0
  214. length = len(sentences)/2
  215. for sentence in sentences:
  216. sentence = regex.sub(r'[^\w\d\p{L}]', '', sentence)
  217. if sentence in content:
  218. match_count += 1
  219. if match_count >= 2:
  220. return r["page_no"]
  221. def test_match_trunk(self,trunks_service) -> int:
  222. must_matchs = ['心肌梗死']
  223. keywords = [ '概述']
  224. text = '''- 主要病因:
  225. 1. 冠状动脉粥样硬化(占90%以上)
  226. 2. 冠状动脉栓塞(如房颤血栓脱落)
  227. 3. 冠状动脉痉挛(可卡因滥用等)
  228. - 危险因素:
  229. 1. 吸烟(RR=2.87)
  230. 2. 高血压(RR=2.50)
  231. 3. LDL-C≥190mg/dL(RR=4.48)
  232. - 遗传因素:
  233. 家族性高胆固醇血症(OMIM#143890)'''
  234. text = regex.sub(r'[^\w\d\p{L}]', '', text)
  235. results = trunks_service.search_by_vector(text,1000,distance=0.72,type='trunk')
  236. print(f"原结果: {results[0]["meta_header"]}")
  237. print(results[0]["content"])
  238. max_match_count = 0
  239. best_match = None
  240. for r in results:
  241. if all(must_match in r["content"] or must_match in r["meta_header"] for must_match in must_matchs):
  242. match_count = sum(keyword in r["content"] for keyword in keywords)
  243. if match_count > max_match_count:
  244. max_match_count = match_count
  245. best_match = r
  246. elif best_match is None and max_match_count == 0:
  247. best_match = r
  248. if best_match:
  249. print(f"最佳匹配: {best_match["title"]}")
  250. print(best_match["content"])
  251. return best_match
  252. def split_text(self, text):
  253. """将文本分割成句子"""
  254. print(text)
  255. # 使用常见的标点符号作为分隔符
  256. delimiters = ['!', '?', '。', '!', '?', '\n', ';', '。', ';']
  257. sentences = [text]
  258. for delimiter in delimiters:
  259. new_sentences = []
  260. for sentence in sentences:
  261. parts = sentence.split(delimiter)
  262. new_sentences.extend([part + delimiter if i < len(parts) - 1 else part for i, part in enumerate(parts)])
  263. sentences = [s.strip() for s in new_sentences if s.strip()]
  264. # 合并短句子
  265. merged_sentences = []
  266. buffer = ""
  267. for sentence in sentences:
  268. buffer += " " + sentence if buffer else sentence
  269. if len(buffer) >= 10:
  270. merged_sentences.append(buffer)
  271. buffer = ""
  272. if buffer:
  273. merged_sentences.append(buffer)
  274. # 打印最终句子
  275. for i, sentence in enumerate(merged_sentences):
  276. print(f"句子{i+1}: {sentence.replace(" ","").replace("\u2003", "").replace("\u2002", "").replace("\u2009", "").replace("\n", "").replace("\r", "")}")
  277. return merged_sentences
  278. class TestExceptionCases:
  279. def test_duplicate_id(self, trunks_service, test_trunk_data):
  280. with pytest.raises(IntegrityError):
  281. trunk1 = trunks_service.create_trunk(test_trunk_data)
  282. test_trunk_data["id"] = trunk1.id
  283. trunks_service.create_trunk(test_trunk_data)
  284. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  285. with pytest.raises(ValueError):
  286. invalid_data = test_trunk_data.copy()
  287. invalid_data["embedding"] = [0.1]*100
  288. trunks_service.create_trunk(invalid_data)
  289. @pytest.fixture
  290. def trunk_factory():
  291. class TrunkFactory:
  292. @staticmethod
  293. def create(**overrides):
  294. defaults = {
  295. "content": "工厂内容",
  296. "file_path": "factory_path.pdf",
  297. "type": "default"
  298. }
  299. return {**defaults, **overrides}
  300. return TrunkFactory()
  301. class TestBatchCreateFromDirectory:
  302. def test_batch_create_from_directory(self, trunks_service):
  303. # 使用现有目录路径
  304. base_path = Path(r'E:\project\vscode\《急诊医学(第2版)》')
  305. # 遍历目录并创建trunk
  306. created_ids = []
  307. for txt_path in base_path.glob('**/*_split_*.txt'):
  308. relative_path = txt_path.relative_to(base_path.parent.parent)
  309. with open(txt_path, 'r', encoding='utf-8') as f:
  310. trunk_data = {
  311. "content": f.read(),
  312. "file_path": str(relative_path).replace('\\', '/')
  313. }
  314. trunk = trunks_service.create_trunk(trunk_data)
  315. created_ids.append(trunk.id)
  316. # 验证数据库记录
  317. for trunk_id in created_ids:
  318. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  319. assert db_trunk is not None
  320. assert ".txt" in db_trunk.file_path
  321. assert "_split_" in db_trunk.file_path
  322. assert len(db_trunk.content) > 0