text_search.py 35 KB
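"""FastAPI router for text search and matching against the knowledge base.

Exposes endpoints for sentence-level vector search with citation markers
(/text/search), sentence matching (/text/match, /text/mr_match), medical-record
search (/text/mr_search), evidence lookup for knowledge-graph node properties
(/text/eb_search), and similar-text lookup (/text/find_similar_texts).
"""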

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from service.trunks_service import TrunksService
from utils.sentence_util import SentenceUtil
from utils.vector_distance import VectorDistance
from model.response import StandardResponse
from utils.vectorizer import Vectorizer
# from utils.find_text_in_pdf import find_text_in_pdf
import os
import logging
import time
from db.session import get_db
from sqlalchemy.orm import Session
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from service.kg_edge_service import KGEdgeService
from cachetools import TTLCache
# TextSimilarityFinder is used for text similarity matching
from utils.text_similarity import TextSimilarityFinder

# Vector-distance cutoff: results at or above this distance are treated as non-matches
DISTANCE_THRESHOLD = 0.73

logger = logging.getLogger(__name__)

router = APIRouter(tags=["Text Search"])

# Global cache instance
cache = TTLCache(maxsize=1000, ttl=3600)

class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search")

    @validator('text')
    def validate_text(cls, v):
        # Keep all printable characters (including CJK) plus newlines and carriage returns
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters.
        # Handle backslashes first so later escapes are not double-escaped.
        v = v.replace('\\', '\\\\')
        # Handle quotes and other special characters
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Handle control characters
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Handle Unicode escapes
        # v = v.replace('\u', '\\u')
        return v


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


class NodePropsSearchRequest(BaseModel):
    node_id: int
    props_ids: List[int]
    symptoms: Optional[List[str]] = None

@router.post("/kgrt_api/text/clear_cache", response_model=StandardResponse)
async def clear_cache():
    try:
        # Clear the global cache
        cache.clear()
        return StandardResponse(success=True, data={"message": "缓存已清除"})
    except Exception as e:
        logger.error(f"Failed to clear cache: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/kgrt_api/text/search", response_model=StandardResponse)
@router.post("/knowledge/text/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # If request.text looks like JSON, convert it to plain text with JsonToTextConverter
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text into sentences
        sentences = SentenceUtil.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize the service and the result lists
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation_id, if any
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # if request.need_convert:
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results, compute vector distances against the cache
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # Without cached results, fall back to vector search
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            # Process the search results
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Check whether the same reference already exists
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add a new reference; extract the file name from the referrence field
                    file_name = ""
                    referrence = search_result.get("referrence", "")
                    if referrence and "/books/" in referrence:
                        file_name = referrence.split("/books/")[-1]
                        # Strip the file extension
                        file_name = os.path.splitext(file_name)[0]
                    reference = {
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "file_name": file_name,
                        "referrence": referrence
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Add the citation marker
                if sentence.endswith('<br>'):
                    # Insert ^[current_index]^ before every <br>
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Append ^[current_index]^ at the end of the sentence
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

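# Note on the search_text response shape (derived from response_data above):
# "answer" is the list of sentences with ^[n]^ citation markers appended, and each
# entry in "references" carries index, id, content, file_path, title, distance,
# file_name and referrence, so the client can render footnote-style citations.
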
@router.post("/kgrt_api/text/match", response_model=StandardResponse)
@router.post("/knowledge/text/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    try:
        sentences = SentenceUtil.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []
        for temp in sentences:
            result_sentences.append(temp)
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp
        for i in range(len(result_sentences)):
            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
            if result_sentences[i]["sentence"] == best_sentence:
                result_sentences[i]["matched"] = True
        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/kgrt_api/text/mr_search", response_model=StandardResponse)
@router.post("/knowledge/text/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    try:
        # Initialize the service
        trunks_service = TrunksService()

        # Vectorize the text and search for similar content
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Process the search results
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            # Add to the record list
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)

        # Assemble the response payload
        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/kgrt_api/text/mr_match", response_model=StandardResponse)
@router.post("/knowledge/text/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    start_time = time.time()
    try:
        # Split both texts into sentences
        origin_sentences = SentenceUtil.split_text(request.origin)
        similar_sentences = SentenceUtil.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match text splitting took {(end_time - start_time) * 1000:.2f}ms")

        # Result list for the origin text
        origin_results = []

        # Mark which sentences are long enough to be vectorized
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # Initialize similar_results with matched=False for every sentence
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Batch-compute embeddings
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))
        end_time = time.time()
        logger.info(f"mr_match vectorization took {(end_time - start_time) * 1000:.2f}ms")

        # Match each origin sentence against the not-yet-matched similar sentences
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue
            origin_vector = origin_vectors[origin_sent]
            matched = False
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                if not similar_vector:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match total time: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match total time: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))

def _check_cache(node_id: int) -> Optional[dict]:
    """Return the cached result for the node, if present."""
    cache_key = f"xunzheng_{node_id}"
    cached_result = cache.get(cache_key)
    if cached_result:
        logger.info(f"Returning cached result for node_id: {node_id}")
        return cached_result
    return None

def _get_node_info(node_service: KGNodeService, node_id: int) -> dict:
    """Fetch and validate the node information."""
    node = node_service.get_node(node_id)
    if not node:
        raise ValueError(f"节点不存在: {node_id}")
    return {
        "id": node_id,
        "name": node.get('name', ''),
        "category": node.get('category', ''),
        "props": [],
        "files": [],
        "distance": 0
    }

def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
    """Build the reference entry for a search result and return it together with the file name."""
    file_name = ""
    referrence = search_result.get("referrence", "")
    if referrence and "/books/" in referrence:
        file_name = referrence.split("/books/")[-1]
        file_name = os.path.splitext(file_name)[0]
    reference = {
        "index": str(reference_index),
        "id": search_result["id"],
        "content": search_result["content"],
        "file_path": search_result.get("file_path", ""),
        "title": search_result.get("title", ""),
        "distance": search_result.get("distance", DISTANCE_THRESHOLD),
        "page_no": search_result.get("page_no", ""),
        "file_name": file_name,
        "referrence": referrence
    }
    return reference, file_name

def _get_file_type(file_name: str) -> str:
    """Determine the file type from the file name extension."""
    file_name_lower = file_name.lower()
    if file_name_lower.endswith(".pdf"):
        return "pdf"
    elif file_name_lower.endswith((".doc", ".docx")):
        return "doc"
    elif file_name_lower.endswith((".xls", ".xlsx")):
        return "excel"
    elif file_name_lower.endswith((".ppt", ".pptx")):
        return "ppt"
    return "other"

def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[list, list]:
    """Sentence-level search using the node name and property title as keywords."""
    keywords = [node_name, prop_title] if node_name and prop_title else None
    return _process_sentence_search_keywords(sentences, trunks_service, keywords=keywords)

def _process_sentence_search_keywords(sentences: list, trunks_service: TrunksService, keywords: Optional[List[str]] = None) -> tuple[list, list]:
    """Run a sentence-level search and return the result sentences and the reference list."""
    result_sentences = []
    all_references = []
    reference_index = 1
    i = 0
    while i < len(sentences):
        sentence = sentences[i]
        search_text = sentence
        if keywords:
            search_text = f"{keywords}:{sentence}"
        # if len(sentence) < 10 and i + 1 < len(sentences):
        #     next_sentence = sentences[i + 1]
        #     # result_sentences.append({"sentence": sentence, "flag": ""})
        #     search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
        #     i += 1
        # elif len(sentence) < 10:
        #     result_sentences.append({"sentence": sentence, "flag": ""})
        #     i += 1
        #     continue
        # else:
        i += 1

        # Use vector search to retrieve candidate trunks
        search_results = trunks_service.search_by_vector(
            text=search_text,
            limit=500,
            type='trunk',
            distance=0.7
        )

        # Build the corpus for similarity matching
        trunk_texts = []
        trunk_ids = []
        # Cache each trunk's details keyed by id
        trunk_details = {}
        for trunk in search_results:
            trunk_texts.append(trunk.get('content'))
            trunk_ids.append(trunk.get('id'))
            trunk_details[trunk.get('id')] = {
                'id': trunk.get('id'),
                'content': trunk.get('content'),
                'file_path': trunk.get('file_path'),
                'title': trunk.get('title'),
                'referrence': trunk.get('referrence'),
                'page_no': trunk.get('page_no')
            }
        if len(trunk_texts) == 0:
            continue

        # Initialize TextSimilarityFinder and load the corpus
        similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
        similarity_finder.load_corpus(trunk_texts, trunk_ids)
        # Find the most similar trunk for this sentence
        similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
        if not similar_results:
            result_sentences.append({"sentence": sentence, "flag": ""})
            continue

        # Get the trunk_id of the most similar text
        trunk_id = similar_results[0]['path']
        # Look up the trunk details from the cache
        trunk_info = trunk_details.get(trunk_id)
        if trunk_info:
            search_result = {
                **trunk_info,
                'distance': similar_results[0]['similarity']  # the similarity score is stored as the distance
            }
            # Check whether the similarity reaches the threshold
            if search_result['distance'] >= DISTANCE_THRESHOLD:
                result_sentences.append({"sentence": sentence, "flag": ""})
                continue
            # Check whether the same reference already exists
            existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
            current_index = int(existing_ref["index"]) if existing_ref else reference_index
            if not existing_ref:
                reference, _ = _process_search_result(search_result, reference_index)
                all_references.append(reference)
                reference_index += 1
            result_sentences.append({"sentence": sentence, "flag": str(current_index)})
    return result_sentences, all_references

def _mark_symptoms(text: str, symptom_list: List[str]) -> str:
    """Wrap symptom mentions in the text with red-italic highlight markup."""
    if not symptom_list:
        return text
    marked_sentence = text
    # Track which character positions have already been marked
    marked_positions = [False] * len(marked_sentence)
    # Symptoms that have already been processed
    processed_symptoms = []
    for symptom in symptom_list:
        # Skip this symptom if it, or a symptom that contains it, has already been processed
        if any(symptom in processed_sym or processed_sym in symptom for processed_sym in processed_symptoms):
            continue
        # Find every occurrence of the symptom
        start_pos = 0
        while True:
            pos = marked_sentence.find(symptom, start_pos)
            if pos == -1:
                break
            # Only mark this occurrence if no position in its range is already marked
            if not any(marked_positions[pos:pos + len(symptom)]):
                # Mark every position in the range
                for i in range(pos, pos + len(symptom)):
                    marked_positions[i] = True
                # Wrap the occurrence in the highlight tag
                marked_sentence = marked_sentence[:pos] + f'<i style="color:red;">{symptom}</i>' + marked_sentence[pos + len(symptom):]
                # Record the successfully marked symptom
                if symptom not in processed_symptoms:
                    processed_symptoms.append(symptom)
                # Extend the position array to account for the inserted tags
                new_positions = [False] * (len('<i style="color:red;">') + len('</i>'))
                marked_positions = marked_positions[:pos] + new_positions + marked_positions[pos:]
            # Continue searching after the inserted markup
            start_pos = pos + len('<i style="color:red;">') + len(symptom) + len('</i>')
    return marked_sentence

@router.post("/kgrt_api/text/eb_search", response_model=StandardResponse)
@router.post("/knowledge/text/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    try:
        start_time = time.time()
        # Check the cache first
        cached_result = _check_cache(request.node_id)
        if cached_result:
            # If a symptom list was provided, apply symptom highlighting to the cached result
            if request.symptoms:
                symptom_list = []
                try:
                    # Initialize services
                    node_service = KGNodeService(db)
                    edge_service = KGEdgeService(db)
                    for symptom in request.symptoms:
                        # Add the original symptom
                        symptom_list.append(symptom)
                        try:
                            # Look up the symptom node
                            symptom_node = node_service.get_node_by_name_category(symptom, '症状')
                            # Collect symptom synonyms (both the 1.0 and 2.0 vocabularies)
                            for category in ['症状同义词', '症状同义词2.0']:
                                edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category=category)
                                if edges:
                                    # Add the synonyms
                                    for edge in edges:
                                        if edge['dest_node'] and edge['dest_node'].get('name'):
                                            symptom_list.append(edge['dest_node']['name'])
                        except ValueError:
                            # If the symptom node is not found, keep only the original symptom
                            continue
                    # Sort by length, longest first, so longer symptoms are matched before their substrings
                    symptom_list.sort(key=len, reverse=True)
                    # Apply symptom highlighting to the cached answers
                    for prop in cached_result.get('props', []):
                        if prop.get('prop_title') == '临床表现' and 'answer' in prop:
                            for answer in prop['answer']:
                                answer['sentence'] = _mark_symptoms(answer['sentence'], symptom_list)
                except Exception as e:
                    logger.error(f"Failed to apply symptom highlighting: {str(e)}")
            return StandardResponse(success=True, data=cached_result)

        # Initialize services
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)
        edge_service = KGEdgeService(db)

        # Fetch and validate the node information
        result = _get_node_info(node_service, request.node_id)
        node_name = result["name"]

        # Build the symptom list: each requested symptom plus its synonyms
        symptom_list = []
        if request.symptoms:
            for symptom in request.symptoms:
                try:
                    # Add the original symptom
                    symptom_list.append(symptom)
                    # Look up the symptom node
                    symptom_node = node_service.get_node_by_name_category(symptom, '症状')
                    # Collect symptom synonyms (both the 1.0 and 2.0 vocabularies)
                    for category in ['症状同义词', '症状同义词2.0']:
                        edges = edge_service.get_edges_by_nodes(src_id=symptom_node['id'], category=category)
                        if edges:
                            # Add the synonyms
                            for edge in edges:
                                if edge['dest_node'] and edge['dest_node'].get('name'):
                                    symptom_list.append(edge['dest_node']['name'])
                except ValueError:
                    # If the symptom node is not found, keep only the original symptom
                    continue
            # Sort by length, longest first, so longer symptoms are matched before their substrings
            symptom_list.sort(key=len, reverse=True)

        # Iterate over props_ids and resolve each property
        for prop_id in request.props_ids:
            prop = prop_service.get_prop_by_id(prop_id)
            if not prop:
                logger.warning(f"Property not found: {prop_id}")
                continue
            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # Build the property result object
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            result["props"].append(prop_result)

            # If prop_value is '无', skip the search
            if prop_value == '无':
                prop_result["answer"] = [{
                    "sentence": prop_value,
                    "flag": ""
                }]
                continue

            # First search with the full prop_value
            search_text = f"{node_name}:{prop_title}:{prop_value}"
            # Use vector search to retrieve candidate trunks
            search_results = trunks_service.search_by_vector(
                text=search_text,
                limit=500,
                type='trunk',
                distance=0.7
            )

            # Build the corpus for similarity matching
            trunk_texts = []
            trunk_ids = []
            # Cache each trunk's details keyed by id
            trunk_details = {}
            for trunk in search_results:
                trunk_texts.append(trunk.get('content'))
                trunk_ids.append(trunk.get('id'))
                trunk_details[trunk.get('id')] = {
                    'id': trunk.get('id'),
                    'content': trunk.get('content'),
                    'file_path': trunk.get('file_path'),
                    'title': trunk.get('title'),
                    'referrence': trunk.get('referrence'),
                    'page_no': trunk.get('page_no')
                }
            if len(trunk_texts) == 0:
                continue

            # Initialize TextSimilarityFinder and load the corpus
            similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
            similarity_finder.load_corpus(trunk_texts, trunk_ids)
            similar_results = similarity_finder.find_most_similar(search_text, top_n=1)

            # Process the result; 0.3 is the similarity threshold for a whole-value match
            if similar_results and similar_results[0]['similarity'] >= 0.3:
                # Get the trunk_id of the most similar text
                trunk_id = similar_results[0]['path']
                # Look up the trunk details from the cache
                trunk_info = trunk_details.get(trunk_id)
                if trunk_info:
                    search_result = {
                        **trunk_info,
                        'distance': similar_results[0]['similarity']  # the similarity score is stored as the distance
                    }
                    reference, _ = _process_search_result(search_result, 1)
                    prop_result["references"] = [reference]
                    prop_result["answer"] = [{
                        "sentence": prop_value,
                        "flag": "1"
                    }]
                else:
                    # No trunk details for the match: fall back to sentence-level search
                    sentences = SentenceUtil.split_text(prop_value, 10)
                    result_sentences, references = _process_sentence_search(
                        node_name, prop_title, sentences, trunks_service
                    )
                    if references:
                        prop_result["references"] = references
                    if result_sentences:
                        prop_result["answer"] = result_sentences
            else:
                # The whole-value search found no match: fall back to sentence-level search
                sentences = SentenceUtil.split_text(prop_value, 10)
                result_sentences, references = _process_sentence_search(
                    node_name, prop_title, sentences, trunks_service
                )
                if references:
                    prop_result["references"] = references
                if result_sentences:
                    prop_result["answer"] = result_sentences

        # Collect file information from the references
        all_files = set()
        file_index_map = {}
        file_index = 1
        for prop_result in result["props"]:
            if "references" not in prop_result:
                continue
            for ref in prop_result["references"]:
                referrence = ref.get("referrence", "")
                if not (referrence and "/books/" in referrence):
                    continue
                file_name = referrence.split("/books/")[-1]
                if not file_name:
                    continue
                file_type = _get_file_type(file_name)
                if file_name not in file_index_map:
                    file_index_map[file_name] = file_index
                    file_index += 1
                all_files.add((file_name, file_type))

        # Rewrite reference indexes as "<file index>-<reference index>"
        for prop_result in result["props"]:
            if "references" not in prop_result:
                continue
            for ref in prop_result["references"]:
                referrence = ref.get("referrence", "")
                if referrence and "/books/" in referrence:
                    file_name = referrence.split("/books/")[-1]
                    if file_name in file_index_map:
                        ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
            # Update the flag of each answer sentence to the rewritten reference index
            if "answer" in prop_result:
                for sentence in prop_result["answer"]:
                    if sentence["flag"]:
                        for ref in prop_result["references"]:
                            if ref["index"].endswith(f"-{sentence['flag']}"):
                                sentence["flag"] = ref["index"]
                                break

        # Add file information to the result, ordered by file index
        result["files"] = sorted([{
            "file_name": file_name,
            "file_type": file_type,
            "index": str(file_index_map[file_name])
        } for file_name, file_type in all_files], key=lambda x: int(x["index"]))

        # Cache the assembled result
        cache_key = f"xunzheng_{request.node_id}"
        cache[cache_key] = result

        # Apply symptom highlighting
        if request.symptoms:
            for prop in result.get('props', []):
                if prop.get('prop_title') == '临床表现' and 'answer' in prop:
                    for answer in prop['answer']:
                        answer['sentence'] = _mark_symptoms(answer['sentence'], symptom_list)

        end_time = time.time()
        logger.info(f"node_props_search took {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

class FindSimilarTexts(BaseModel):
    keywords: Optional[List[str]] = None
    search_text: str


@router.post("/knowledge/text/find_similar_texts", response_model=StandardResponse)
async def find_similar_texts(request: FindSimilarTexts, db: Session = Depends(get_db)):
    trunks_service = TrunksService()
    search_text = request.search_text
    if request.keywords:
        search_text = f"{request.keywords}:{search_text}"
    # Use vector search to retrieve candidate trunks
    search_results = trunks_service.search_by_vector(
        text=search_text,
        limit=500,
        type='trunk',
        distance=0.7
    )
    # Build the corpus for similarity matching
    trunk_texts = []
    trunk_ids = []
    # Cache each trunk's details keyed by id
    trunk_details = {}
    for trunk in search_results:
        trunk_texts.append(trunk.get('content'))
        trunk_ids.append(trunk.get('id'))
        trunk_details[trunk.get('id')] = {
            'id': trunk.get('id'),
            'content': trunk.get('content'),
            'file_path': trunk.get('file_path'),
            'title': trunk.get('title'),
            'referrence': trunk.get('referrence'),
            'page_no': trunk.get('page_no')
        }
    if len(trunk_texts) == 0:
        # No candidates found; return an empty result instead of a bare None
        return StandardResponse(success=True, data={})
    # Initialize TextSimilarityFinder and load the corpus
    similarity_finder = TextSimilarityFinder(method='tfidf', use_jieba=True)
    similarity_finder.load_corpus(trunk_texts, trunk_ids)
    similar_results = similarity_finder.find_most_similar(search_text, top_n=1)
    prop_result = {}
    # Process the result; 0.3 is the similarity threshold for a whole-text match
    if similar_results and similar_results[0]['similarity'] >= 0.3:
        # Get the trunk_id of the most similar text
        trunk_id = similar_results[0]['path']
        # Look up the trunk details from the cache
        trunk_info = trunk_details.get(trunk_id)
        if trunk_info:
            search_result = {
                **trunk_info,
                'distance': similar_results[0]['similarity']  # the similarity score is stored as the distance
            }
            reference, _ = _process_search_result(search_result, 1)
            prop_result["references"] = [reference]
            prop_result["answer"] = [{
                "sentence": request.search_text,
                "flag": "1"
            }]
    else:
        # The whole-text search found no match: fall back to sentence-level search
        sentences = SentenceUtil.split_text(request.search_text, 10)
        result_sentences, references = _process_sentence_search_keywords(
            sentences, trunks_service, keywords=request.keywords
        )
        if references:
            prop_result["references"] = references
        if result_sentences:
            prop_result["answer"] = result_sentences
    return StandardResponse(success=True, data=prop_result)


text_search_router = router
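# Minimal usage sketch (assumption: an application entry point such as main.py exists;
# the import path below is illustrative and not confirmed by this file):
#
#     from fastapi import FastAPI
#     from router.text_search import text_search_router  # hypothetical module path
#
#     app = FastAPI()
#     app.include_router(text_search_router)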