# text_search.py
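"""Text search routes.

Vector-similarity endpoints exposed under the /text prefix:

- POST /text/search       split text into sentences and attach ^[n]^ citation markers
- POST /text/match        flag the sentence in a text closest to a query sentence
- POST /text/mr_search    top-10 vector search over "mr" chunks
- POST /text/mr_match     sentence-level alignment of two texts
- POST /text/eb_search    knowledge-graph node/property evidence search
- POST /text/clear_cache  empty the in-process TTL cache
"""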

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
import logging
import os
import time

from sqlalchemy.orm import Session
from cachetools import TTLCache

from service.trunks_service import TrunksService
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer
from model.response import StandardResponse
from db.session import get_db
# from utils.find_text_in_pdf import find_text_in_pdf

DISTANCE_THRESHOLD = 0.73

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])

# Create a global cache instance (up to 1000 entries, 1 hour TTL each)
cache = TTLCache(maxsize=1000, ttl=3600)
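
# Note: a vector hit counts as a match only when its distance is strictly
# below DISTANCE_THRESHOLD; hits at or above it are treated as no-match.
# The TTL cache above is shared by all requests; /eb_search stores results
# under keys of the form "xunzheng_<node_id>" and /clear_cache empties it.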

class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False

class TextCompareRequest(BaseModel):
    sentence: str
    text: str

class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search for")

    @validator('text')
    def validate_text(cls, v):
        # Keep all printable characters (including Chinese), plus newlines and carriage returns
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters.
        # Handle backslashes first so later escapes are not double-escaped.
        v = v.replace('\\', '\\\\')
        # Escape quotes and other special characters
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Escape control characters
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Handle Unicode escapes
        # v = v.replace('\u', '\\u')
        return v

class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str

class NodePropsSearchRequest(BaseModel):
    node_id: int
    props_ids: List[int]

@router.post("/clear_cache", response_model=StandardResponse)
async def clear_cache():
    try:
        # Clear the global cache
        cache.clear()
        return StandardResponse(success=True, data={"message": "Cache cleared"})
    except Exception as e:
        logger.error(f"Failed to clear cache: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
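
# /search pipeline: optionally flatten JSON input to text, split it into
# sentences, then look each sentence up either against the conversation's
# cached chunks or via a fresh vector search. Matched sentences get an inline
# ^[n]^ citation marker; repeated hits on the same chunk reuse one reference index.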

@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # If request.text looks like JSON, convert it to plain text with JsonToTextConverter
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text with TextSplitter
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize the service and result lists
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation, if any
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # if request.need_convert:
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results, compute vector distances locally
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # Without cached results, fall back to a vector search
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            if not search_results:
                # No match: keep the sentence without a citation marker
                result_sentences.append(sentence)
                continue

            # Process the search results
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Check whether this reference is already recorded
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add a new reference; extract the file name from `referrence`
                    file_name = ""
                    referrence = search_result.get("referrence", "")
                    if referrence and "/books/" in referrence:
                        file_name = referrence.split("/books/")[-1]
                        # Strip the file extension
                        file_name = os.path.splitext(file_name)[0]
                    reference = {
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "file_name": file_name,
                        "referrence": referrence
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the citation marker
                if sentence.endswith('<br>'):
                    # Insert ^[current_index]^ before every <br>
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Otherwise append ^[current_index]^ at the end of the sentence
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
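
# /match: find the single sentence in `text` closest to `sentence` (below the
# distance threshold) and flag it; every other sentence is returned unmatched.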

@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []
        for temp in sentences:
            result_sentences.append(temp)
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp
        records = [
            {"sentence": sentence, "matched": sentence == best_sentence}
            for sentence in result_sentences
        ]
        # Wrap the records in `data` for consistency with the other endpoints
        return StandardResponse(success=True, data={"records": records})
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
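
# /mr_search: plain top-10 vector search over "mr" chunks, filtered by the
# distance threshold. Response shape: {"records": [{"content", "file_path",
# "title", "distance"}, ...]}.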

@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    try:
        # Initialize the service
        trunks_service = TrunksService()
        # Vectorize the text and search for similar content
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )
        # Keep only results under the distance threshold
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)
        # Assemble the response payload
        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
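
# /mr_match: greedy one-to-one alignment between the sentences of `origin`
# and `similar`. Each similar sentence can be claimed by at most one origin
# sentence (the first match below the threshold wins), so the pairing
# depends on sentence order.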

@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    start_time = time.time()
    try:
        # Split both texts into sentences
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match text splitting took: {(end_time - start_time) * 1000:.2f}ms")

        # Initialize the result list
        origin_results = []
        # Tag sentences that are long enough to vectorize (>= 10 chars)
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]
        # Initialize similar_results with matched=False
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Fetch embeddings in batch
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))
        end_time = time.time()
        logger.info(f"mr_match embedding took: {(end_time - start_time) * 1000:.2f}ms")

        # Process the origin sentences
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue
            origin_vector = origin_vectors[origin_sent]
            matched = False
            # Scan similar sentences, skipping ones already claimed
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                # Compare against None rather than truthiness, in case embeddings are numpy arrays
                if similar_vector is None:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match took: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match took: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))
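
# Helpers for /eb_search: cache lookup, node/property loading, reference
# construction, file-type detection, and sentence-level fallback search.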

def _check_cache(node_id: int) -> Optional[dict]:
    """Return the cached result for this node, if present."""
    cache_key = f"xunzheng_{node_id}"
    cached_result = cache.get(cache_key)
    if cached_result:
        logger.info(f"Cache hit for node_id: {node_id}")
        return cached_result
    return None

def _get_node_info(node_service: KGNodeService, node_id: int) -> dict:
    """Fetch the node and build the skeleton of the response object."""
    node = node_service.get_node(node_id)
    if not node:
        raise ValueError(f"Node not found: {node_id}")
    return {
        "id": node_id,
        "name": node.get('name', ''),
        "category": node.get('category', ''),
        "props": [],
        "files": [],
        "distance": 0
    }

def _process_search_result(search_result: dict, reference_index: int) -> tuple[dict, str]:
    """Build a reference dict from a search result; return (reference, file_name)."""
    file_name = ""
    referrence = search_result.get("referrence", "")
    if referrence and "/books/" in referrence:
        file_name = referrence.split("/books/")[-1]
        file_name = os.path.splitext(file_name)[0]
    reference = {
        "index": str(reference_index),
        "id": search_result["id"],
        "content": search_result["content"],
        "file_path": search_result.get("file_path", ""),
        "title": search_result.get("title", ""),
        "distance": search_result.get("distance", DISTANCE_THRESHOLD),
        "page_no": search_result.get("page_no", ""),
        "file_name": file_name,
        "referrence": referrence
    }
    return reference, file_name

def _get_file_type(file_name: str) -> str:
    """Determine the file type from the file name's extension."""
    file_name_lower = file_name.lower()
    if file_name_lower.endswith(".pdf"):
        return "pdf"
    elif file_name_lower.endswith((".doc", ".docx")):
        return "doc"
    elif file_name_lower.endswith((".xls", ".xlsx")):
        return "excel"
    elif file_name_lower.endswith((".ppt", ".pptx")):
        return "ppt"
    return "other"
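
# Sentence-level fallback search: each sentence is searched as
# "<node_name>:<prop_title>:<sentence>". Sentences shorter than 10 characters
# are kept without a flag; when one is followed by another sentence, the pair
# is searched together and any citation flag attaches to the following sentence.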

def _process_sentence_search(node_name: str, prop_title: str, sentences: list, trunks_service: TrunksService) -> tuple[list, list]:
    """Search sentence by sentence; return (result_sentences, references)."""
    result_sentences = []
    all_references = []
    reference_index = 1
    i = 0
    while i < len(sentences):
        sentence = sentences[i]
        if len(sentence) < 10 and i + 1 < len(sentences):
            # Too short to search alone: record it unflagged, then search it
            # together with the next sentence, which receives the flag
            next_sentence = sentences[i + 1]
            result_sentences.append({"sentence": sentence, "flag": ""})
            search_text = f"{node_name}:{prop_title}:{sentence} {next_sentence}"
            sentence = next_sentence
            i += 2
        elif len(sentence) < 10:
            result_sentences.append({"sentence": sentence, "flag": ""})
            i += 1
            continue
        else:
            search_text = f"{node_name}:{prop_title}:{sentence}"
            i += 1
        search_results = trunks_service.search_by_vector(text=search_text, limit=1, type='trunk')
        if not search_results:
            result_sentences.append({"sentence": sentence, "flag": ""})
            continue
        for search_result in search_results:
            if search_result.get("distance", DISTANCE_THRESHOLD) >= DISTANCE_THRESHOLD:
                result_sentences.append({"sentence": sentence, "flag": ""})
                continue
            existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
            current_index = int(existing_ref["index"]) if existing_ref else reference_index
            if not existing_ref:
                reference, _ = _process_search_result(search_result, reference_index)
                all_references.append(reference)
                reference_index += 1
            result_sentences.append({"sentence": sentence, "flag": str(current_index)})
    return result_sentences, all_references
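
# /eb_search: for each requested property, try a single vector search over the
# full "<node_name>:<prop_title>:<prop_value>" string; if nothing clears the
# threshold, fall back to sentence-level search. Reference indices are then
# rewritten to "<file_index>-<reference_index>" and the answer flags updated
# to match. Results are cached per node under "xunzheng_<node_id>".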

@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    try:
        start_time = time.time()
        # Check the cache first
        cached_result = _check_cache(request.node_id)
        if cached_result:
            return StandardResponse(success=True, data=cached_result)

        # Initialize the services
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # Fetch the node info
        result = _get_node_info(node_service, request.node_id)
        node_name = result["name"]

        # Look up each property in props_ids
        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"Property not found: {prop_id}")
                continue
            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # Build the property result object
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            result["props"].append(prop_result)

            # Search with the full prop_value first
            search_text = f"{node_name}:{prop_title}:{prop_value}"
            full_search_results = trunks_service.search_by_vector(
                text=search_text,
                limit=1,
                type='trunk'
            )
            # Process the search results
            if full_search_results and full_search_results[0].get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD:
                search_result = full_search_results[0]
                reference, _ = _process_search_result(search_result, 1)
                prop_result["references"] = [reference]
                prop_result["answer"] = [{
                    "sentence": prop_value,
                    "flag": "1"
                }]
            else:
                # No match for the whole value: fall back to sentence-level search
                sentences = TextSplitter.split_text(prop_value)
                result_sentences, references = _process_sentence_search(
                    node_name, prop_title, sentences, trunks_service
                )
                if references:
                    prop_result["references"] = references
                if result_sentences:
                    prop_result["answer"] = result_sentences

        # Collect file info from the references
        all_files = set()
        file_index_map = {}
        file_index = 1
        for prop_result in result["props"]:
            if "references" not in prop_result:
                continue
            for ref in prop_result["references"]:
                referrence = ref.get("referrence", "")
                if not (referrence and "/books/" in referrence):
                    continue
                file_name = referrence.split("/books/")[-1]
                if not file_name:
                    continue
                file_type = _get_file_type(file_name)
                if file_name not in file_index_map:
                    file_index_map[file_name] = file_index
                    file_index += 1
                all_files.add((file_name, file_type))

        # Rewrite reference indices as "<file_index>-<reference_index>"
        for prop_result in result["props"]:
            if "references" not in prop_result:
                continue
            for ref in prop_result["references"]:
                referrence = ref.get("referrence", "")
                if referrence and "/books/" in referrence:
                    file_name = referrence.split("/books/")[-1]
                    if file_name in file_index_map:
                        ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"
            # Update the answer flags to the rewritten indices
            if "answer" in prop_result:
                for sentence in prop_result["answer"]:
                    if sentence["flag"]:
                        for ref in prop_result["references"]:
                            if ref["index"].endswith(f"-{sentence['flag']}"):
                                sentence["flag"] = ref["index"]
                                break

        # Attach the file list, sorted by file index
        result["files"] = sorted([{
            "file_name": file_name,
            "file_type": file_type,
            "index": str(file_index_map[file_name])
        } for file_name, file_type in all_files], key=lambda x: int(x["index"]))

        end_time = time.time()
        logger.info(f"node_props_search took: {(end_time - start_time) * 1000:.2f}ms")

        # Cache the result
        cache_key = f"xunzheng_{request.node_id}"
        cache[cache_key] = result
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

text_search_router = router
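
# Example wiring (a minimal sketch; the application module and import path
# are assumptions, not part of this file):
#
#   from fastapi import FastAPI
#   from router.text_search import text_search_router
#
#   app = FastAPI()
#   app.include_router(text_search_router)  # serves /text/search, /text/eb_search, ...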