text_search.py

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel, Field, validator
from typing import List, Optional
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from model.response import StandardResponse
from utils.vectorizer import Vectorizer
# from utils.find_text_in_pdf import find_text_in_pdf
import os
import logging
import time
from db.session import get_db
from sqlalchemy.orm import Session
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService

# Maximum vector distance for a search result to be treated as a match
DISTANCE_THRESHOLD = 0.8

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search for")

    @validator('text')
    def validate_text(cls, v):
        # Keep all printable characters, plus newlines and Chinese characters
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters
        # Handle backslashes first so later escapes are not double-escaped
        v = v.replace('\\', '\\\\')
        # Handle quotes and other special characters
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Handle control characters
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Handle Unicode escapes
        # v = v.replace('\u', '\\u')
        return v


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


class NodePropsSearchRequest(BaseModel):
    node_id: int
    props_ids: List[int]
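
# The request models above map to the endpoints below as follows:
#   TextSearchRequest       -> POST /text/search
#   TextCompareRequest      -> POST /text/match
#   TextMatchRequest        -> POST /text/mr_search
#   TextCompareMultiRequest -> POST /text/mr_match
#   NodePropsSearchRequest  -> POST /text/eb_search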


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # If request.text looks like JSON, convert it to plain text with JsonToTextConverter
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text into sentences with TextSplitter
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize services and result containers
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation, if any
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # if request.need_convert:
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results, pick the closest one by vector distance
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cached results: run a vector search
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            # Process search results
            for search_result in search_results:
                distance = search_result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Check whether this reference already exists
                existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add to the reference list
                    reference = {
                        "index": str(reference_index),
                        "id": search_result["id"],
                        "content": search_result["content"],
                        "file_path": search_result.get("file_path", ""),
                        "title": search_result.get("title", ""),
                        "distance": distance,
                        "referrence": search_result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the citation marker
                if sentence.endswith('<br>'):
                    # If the sentence contains <br>, put ^[current_index]^ before each <br>
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Otherwise append ^[current_index]^ at the end of the sentence
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
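
# Illustrative call for the /search endpoint above (a sketch; assumes the app is
# served locally on port 8000, and that conversation_id is omitted):
#   curl -X POST http://localhost:8000/text/search \
#        -H "Content-Type: application/json" \
#        -d '{"text": "text to annotate with references"}'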


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    try:
        # Find the sentence in request.text that is closest to request.sentence
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []
        for temp in sentences:
            result_sentences.append(temp)
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp
        # Mark the best-matching sentence; all others stay unmatched
        for i in range(len(result_sentences)):
            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
            if result_sentences[i]["sentence"] == best_sentence:
                result_sentences[i]["matched"] = True
        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    try:
        # Initialize the service
        trunks_service = TrunksService()

        # Vectorize the text and search for similar content
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Process search results, keeping only hits below the distance threshold
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)

        # Assemble the response payload
        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    start_time = time.time()
    try:
        # Split both texts into sentences
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match text splitting took: {(end_time - start_time) * 1000:.2f}ms")

        # Initialize result lists
        origin_results = []

        # Flag short sentences and prepare for vector computation
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # Initialize similar_results with matched=False for every sentence
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Batch-compute embeddings
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))
        end_time = time.time()
        logger.info(f"mr_match embedding took: {(end_time - start_time) * 1000:.2f}ms")

        # Walk through the origin sentences
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue
            origin_vector = origin_vectors[origin_sent]
            matched = False
            # Compare against each not-yet-matched similar sentence
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                if not similar_vector:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match total time: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match total time: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/eb_search", response_model=StandardResponse)
async def node_props_search(request: NodePropsSearchRequest, db: Session = Depends(get_db)):
    try:
        start_time = time.time()
        # Initialize services
        trunks_service = TrunksService()
        node_service = KGNodeService(db)
        prop_service = KGPropService(db)

        # Look up the node by node_id
        node = node_service.get_node(request.node_id)
        if not node:
            raise ValueError(f"Node not found: {request.node_id}")
        node_name = node.get('name', '')

        # Initialize the result payload
        result = {
            "id": request.node_id,
            "name": node_name,
            "category": node.get('category', ''),
            "props": [],
            "files": [],
            "distance": 0
        }

        # Look up each property in props_ids
        for prop_id in request.props_ids:
            prop = prop_service.get_props_by_id(prop_id)
            if not prop:
                logger.warning(f"Property not found: {prop_id}")
                continue
            prop_title = prop.get('prop_title', '')
            prop_value = prop.get('prop_value', '')

            # Split the property value into sentences
            sentences = TextSplitter.split_text(prop_value)
            prop_result = {
                "id": prop_id,
                "category": prop.get('category', 0),
                "prop_name": prop.get('prop_name', ''),
                "prop_value": prop_value,
                "prop_title": prop_title,
                "type": prop.get('type', 1)
            }
            # Add the property to the result
            result["props"].append(prop_result)

            # Process the sentences of the property value
            result_sentences = []
            all_references = []
            reference_index = 1

            # Run a vector search for each sentence
            i = 0
            while i < len(sentences):
                original_sentence = sentences[i]
                sentence = original_sentence
                # If the sentence is shorter than 10 characters and not the last one, merge it with the next sentence
                if len(sentence) < 10 and i + 1 < len(sentences):
                    next_sentence = sentences[i + 1]
                    combined_sentence = sentence + " " + next_sentence
                    # Keep the original short sentence in the results with an empty flag
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    # Search with the merged sentence
                    search_text = f"{node_name}:{prop_title}:{combined_sentence}"
                    i += 1  # Skip the next sentence, it has already been merged
                elif len(sentence) < 10:
                    # Last sentence and shorter than 10 characters: keep it with an empty flag
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    i += 1
                    continue
                else:
                    # Sentence is long enough, use it as-is
                    search_text = f"{node_name}:{prop_title}:{sentence}"
                    i += 1

                # Run the vector search
                search_results = trunks_service.search_by_vector(
                    text=search_text,
                    limit=1,
                    type='trunk'
                )

                # Process search results
                if not search_results:
                    # No results: keep the original sentence with an empty flag
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": ""
                    })
                    continue

                for search_result in search_results:
                    distance = search_result.get("distance", DISTANCE_THRESHOLD)
                    if distance >= DISTANCE_THRESHOLD:
                        # Distance too large: keep the original sentence with an empty flag
                        result_sentences.append({
                            "sentence": sentence,
                            "flag": ""
                        })
                        continue

                    # Check whether this reference already exists
                    existing_ref = next((ref for ref in all_references if ref["id"] == search_result["id"]), None)
                    current_index = reference_index
                    if existing_ref:
                        current_index = int(existing_ref["index"])
                    else:
                        # Add to the reference list
                        reference = {
                            "index": str(reference_index),
                            "id": search_result["id"],
                            "content": search_result["content"],
                            "file_path": search_result.get("file_path", ""),
                            "title": search_result.get("title", ""),
                            "distance": distance,
                            "page_no": search_result.get("page_no", ""),
                            "referrence": search_result.get("referrence", "")
                        }
                        all_references.append(reference)
                        reference_index += 1

                    # Record the sentence with its citation index as a separate flag field
                    result_sentences.append({
                        "sentence": sentence,
                        "flag": str(current_index)
                    })

            # Attach the references to the property result
            if all_references:
                prop_result["references"] = all_references
            # Attach the processed sentences to the property result
            if result_sentences:
                prop_result["answer"] = result_sentences

        # Collect file information from all references
        all_files = set()
        for prop_result in result["props"]:
            if "references" in prop_result:
                for ref in prop_result["references"]:
                    referrence = ref.get("referrence", "")
                    if referrence and "/books/" in referrence:
                        # Extract the file name after /books/
                        file_name = referrence.split("/books/")[-1]
                        if file_name:
                            # Determine the file type from the extension
                            file_type = ""
                            if file_name.lower().endswith(".pdf"):
                                file_type = "pdf"
                            elif file_name.lower().endswith(".doc") or file_name.lower().endswith(".docx"):
                                file_type = "doc"
                            elif file_name.lower().endswith(".xls") or file_name.lower().endswith(".xlsx"):
                                file_type = "excel"
                            elif file_name.lower().endswith(".ppt") or file_name.lower().endswith(".pptx"):
                                file_type = "ppt"
                            else:
                                file_type = "other"
                            all_files.add((file_name, file_type))

        # Add the collected file information to the result
        result["files"] = [{
            "file_name": file_name,
            "file_type": file_type
        } for file_name, file_type in all_files]

        end_time = time.time()
        logger.info(f"node_props_search took: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.error(f"Node props search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


text_search_router = router
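
# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal way to serve this router for local testing. The uvicorn dependency
# and the host/port values are assumptions, not something this module mandates.
if __name__ == "__main__":
    from fastapi import FastAPI
    import uvicorn

    app = FastAPI()
    app.include_router(text_search_router)
    # Start a local development server on the assumed port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)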