text_search.py

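"""FastAPI routes for text search: sentence-level reference lookup, sentence matching,
"mr"-type vector search and comparison, and knowledge-graph node property evidence search."""
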
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from sqlalchemy.orm import Session
from cachetools import TTLCache
import logging
import os
import time

from db.session import get_db
from model.response import StandardResponse
from service.trunks_service import TrunksService
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer
# from utils.find_text_in_pdf import find_text_in_pdf

# Search results with a vector distance at or above this threshold are treated as non-matches.
DISTANCE_THRESHOLD = 0.73

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search for")

    @validator('text')
    def validate_text(cls, v):
        # Keep all printable characters, newlines and Chinese characters
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters.
        # Handle backslashes first so later replacements are not double-escaped.
        v = v.replace('\\', '\\\\')
        # Handle quotes and other special characters
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Handle control characters
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Handle Unicode escapes
        # v = v.replace('\u', '\\u')
        return v


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


class NodePropsSearchRequest(BaseModel):
    node_id: int
    props_ids: List[int]


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
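    """Split the request text into sentences and, for each sentence of at least 10
    characters, look up the closest stored trunk (from the conversation cache or a
    vector search). Matched sentences get a ^[n]^ reference marker; the deduplicated
    reference list is returned alongside the annotated sentences."""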
    try:
        # If request.text looks like a JSON object, convert it to plain text first
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text into sentences with TextSplitter
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialise the service and result containers
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation_id, if any
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # if request.need_convert:
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results available, pick the closest one by vector distance
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cached results: fall back to a vector search
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            # Process the search results
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Check whether this reference already exists
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                current_index = reference_index

                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add a new entry to the reference list
                    reference = {
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "referrence": search_result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the reference marker
                if sentence.endswith('<br>'):
                    # Insert ^[current_index]^ before every <br>
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Otherwise append ^[current_index]^ to the end of the sentence
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
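    """Split request.text into sentences and mark the single sentence closest to
    request.sentence (within DISTANCE_THRESHOLD) as matched."""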
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []

        for temp in sentences:
            result_sentences.append(temp)
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp

        for i in range(len(result_sentences)):
            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
            if result_sentences[i]["sentence"] == best_sentence:
                result_sentences[i]["matched"] = True

        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
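    """Run a vector search with type="mr" over the stored trunks and return the
    records whose distance falls below DISTANCE_THRESHOLD."""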
    try:
        # Initialise the service
        trunks_service = TrunksService()

        # Vectorise the text and search for similar content
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Process the search results
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue

            # Add to the result records
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)

        # Assemble the response payload
        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
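    """Split both texts into sentences, embed the ones at least 10 characters long,
    and greedily mark origin/similar sentence pairs whose vector distance is below
    DISTANCE_THRESHOLD as matched."""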
    start_time = time.time()
    try:
        # Split both texts into sentences
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match endpoint, text splitting took: {(end_time - start_time) * 1000:.2f}ms")

        # Initialise result lists
        origin_results = []

        # Flag sentences that are long enough to be vectorised
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # Initialise similar_results with matched set to False
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Fetch embeddings in batch
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]

        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))

        end_time = time.time()
        logger.info(f"mr_match endpoint, embedding computation took: {(end_time - start_time) * 1000:.2f}ms")

        # Walk through the origin sentences
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue

            origin_vector = origin_vectors[origin_sent]
            matched = False

            # Compare against similar sentences that have not been matched yet
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                if not similar_vector:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break

            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match endpoint took: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match endpoint took: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))


# eb_search results are cached per node for one hour (up to 1000 entries)
cache = TTLCache(maxsize=1000, ttl=3600)


@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
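    """For a knowledge-graph node and a list of property ids, find supporting trunks
    for each property value (whole value first, then sentence by sentence), attach
    references and per-sentence flags, collect the referenced files under /books/,
    and cache the result per node."""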
    try:
        start_time = time.time()

        # Check the cache first
        cache_key = f"xunzheng_{request.node_id}"
        cached_result = cache.get(cache_key)
        if cached_result:
            logger.info(f"Returning cached result for node_id: {request.node_id}")
            return StandardResponse(success=True, data=cached_result)

        # Initialise services
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # Look up the node by node_id
        node = node_service.get_node(request.node_id)
        if not node:
            raise ValueError(f"Node not found: {request.node_id}")
        node_name = node.get('name', '')

        # Initialise the result payload
        result = {
            "id": request.node_id,
            "name": node_name,
            "category": node.get('category', ''),
            "props": [],
            "files": [],
            "distance": 0
        }

        # Look up each requested property
        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"Property not found: {prop_id}")
                continue

            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # First search with the full prop_value
            search_text = f"{node_name}:{prop_title}:{prop_value}"
            full_search_results = trunks_service.search_by_vector(
                text=search_text,
                limit=1,
                type='trunk'
            )

            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            # Add the property to the result
            result["props"].append(prop_result)

            # Process the sentences in the property value
            result_sentences = []
            all_references = []
            reference_index = 1

            # If the whole-value search matched within the threshold, use that result directly
            if full_search_results and full_search_results[0].get("distance", DISTANCE_THRESHOLD) < DISTANCE_THRESHOLD:
                search_result = full_search_results[0]
                reference = {
                    "index": str(reference_index),
                    "id": search_result["id"],
                    "content": search_result["content"],
                    "file_path": search_result.get("file_path", ""),
                    "title": search_result.get("title", ""),
                    "distance": search_result.get("distance", DISTANCE_THRESHOLD),
                    "page_no": search_result.get("page_no", ""),
                    "referrence": search_result.get("referrence", "")
                }
                all_references.append(reference)
                result_sentences.append({
                    "sentence": prop_value,
                    "flag": str(reference_index)
                })
            else:
                # Otherwise fall back to a sentence-by-sentence search
                sentences = TextSplitter.split_text(prop_value)
                i = 0
                while i < len(sentences):
                    original_sentence = sentences[i]
                    sentence = original_sentence

                    # If the sentence is shorter than 10 characters and not the last one,
                    # merge it with the next sentence for the search
                    if len(sentence) < 10 and i + 1 < len(sentences):
                        next_sentence = sentences[i + 1]
                        combined_sentence = sentence + " " + next_sentence
                        # Add the original short sentence with an empty flag
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        # Search with the merged sentence
                        search_text = f"{node_name}:{prop_title}:{combined_sentence}"
                        i += 2  # Skip the next sentence as well, since it was merged into this search
                    elif len(sentence) < 10:
                        # Last sentence and shorter than 10 characters: add it with an empty flag
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        i += 1
                        continue
                    else:
                        # The sentence is long enough, use it as-is
                        search_text = f"{node_name}:{prop_title}:{sentence}"
                        i += 1

                    # Run the vector search
                    search_results = trunks_service.search_by_vector(
                        text=search_text,
                        limit=1,
                        type='trunk'
                    )

                    # Process the search results
                    if not search_results:
                        # No search results: keep the original sentence with an empty flag
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        continue

                    for search_result in search_results:
                        distance = search_result.get("distance", DISTANCE_THRESHOLD)
                        if distance >= DISTANCE_THRESHOLD:
                            # Distance too large: keep the original sentence with an empty flag
                            result_sentences.append({
                                "sentence": sentence,
                                "flag": ""
                            })
                            continue

                        # Check whether this reference already exists
                        existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                        current_index = reference_index

                        if existing_ref:
                            current_index = int(existing_ref["index"])
                        else:
                            # Add a new entry to the reference list
                            reference = {
                                "index": str(reference_index),
                                "id": search_result["id"],
                                "content": search_result["content"],
                                "file_path": search_result.get("file_path", ""),
                                "title": search_result.get("title", ""),
                                "distance": distance,
                                "page_no": search_result.get("page_no", ""),
                                "referrence": search_result.get("referrence", "")
                            }
                            all_references.append(reference)
                            reference_index += 1

                        # Record the sentence with its reference marker as a separate flag field
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": str(current_index)
                        })

            # Attach the references to the property result
            if all_references:
                prop_result["references"] = all_references

            # Attach the processed sentences
            if result_sentences:
                prop_result["answer"] = result_sentences

        # Collect file information from all references
        all_files = set()
        file_index_map = {}
        file_index = 1

        # First pass: collect file info from every reference
        for prop_result in result["props"]:
            if "references" in prop_result:
                for ref in prop_result["references"]:
                    referrence = ref.get("referrence", "")
                    if referrence and "/books/" in referrence:
                        # Extract the file name after /books/
                        file_name = referrence.split("/books/")[-1]
                        if file_name:
                            # Determine the file type from the extension
                            file_type = ""
                            if file_name.lower().endswith(".pdf"):
                                file_type = "pdf"
                            elif file_name.lower().endswith(".doc") or file_name.lower().endswith(".docx"):
                                file_type = "doc"
                            elif file_name.lower().endswith(".xls") or file_name.lower().endswith(".xlsx"):
                                file_type = "excel"
                            elif file_name.lower().endswith(".ppt") or file_name.lower().endswith(".pptx"):
                                file_type = "ppt"
                            else:
                                file_type = "other"

                            if file_name not in file_index_map:
                                file_index_map[file_name] = file_index
                                file_index += 1
                            all_files.add((file_name, file_type))

        # Second pass: rewrite reference indexes
        for prop_result in result["props"]:
            if "references" in prop_result:
                for ref in prop_result["references"]:
                    referrence = ref.get("referrence", "")
                    if referrence and "/books/" in referrence:
                        file_name = referrence.split("/books/")[-1]
                        if file_name in file_index_map:
                            # Rewrite the reference index as "file_index-original_index"
                            ref["index"] = f"{file_index_map[file_name]}-{ref['index']}"

            # Update the flags in answer to match the rewritten reference indexes
            if "answer" in prop_result:
                for sentence in prop_result["answer"]:
                    if sentence["flag"]:
                        for ref in prop_result["references"]:
                            if ref["index"].endswith(f"-{sentence['flag']}"):
                                sentence["flag"] = ref["index"]
                                break

        # Add the file info to the result
        result["files"] = [{
            "file_name": file_name,
            "file_type": file_type,
            "index": str(file_index_map[file_name])
        } for file_name, file_type in all_files]

        end_time = time.time()
        logger.info(f"node_props_search endpoint took: {(end_time - start_time) * 1000:.2f}ms")

        # Cache the result
        cache[cache_key] = result
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


text_search_router = router