# text_search.py

"""FastAPI routes for sentence-level vector text search and matching."""
import logging
import time
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, validator

from model.response import StandardResponse
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer

DISTANCE_THRESHOLD = 0.8

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search")

    @validator('text')
    def validate_text(cls, v):
        """Strip non-printable characters and escape JSON special characters."""
        # Keep all printable characters plus newlines/carriage returns
        # (this preserves Chinese characters as well).
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters.
        # Handle backslashes first so later escapes are not double-escaped.
        v = v.replace('\\', '\\\\')
        # Handle quotes and other special characters.
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Handle control characters.
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Unicode escapes are intentionally left untouched.
        # v = v.replace('\u', '\\u')
        return v
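
    # Illustrative effect of the validator above (example values, not from
    # the original file): an input like
    #     say "hi"<newline>bye
    # is escaped to the literal string
    #     say \"hi\"\nbye
    # so the value can be embedded verbatim inside a JSON string.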


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Split the input text into sentences and annotate each with vector-search references."""
    try:
        # If request.text looks like a JSON object, convert it to plain text first.
        if request.text.startswith('{') and request.text.endswith('}'):
            from utils.json_to_text import JsonToTextConverter
            converter = JsonToTextConverter()
            request.text = converter.convert(request.text)

        # Split the text into sentences with TextSplitter.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize the service and result containers.
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation, if any.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # need_convert is currently ignored; the replacement below always runs.
            # if request.need_convert:
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results, pick the closest one by vector distance.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # Without cached results, fall back to a vector search.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )

            if not search_results:
                # No candidate at all: keep the sentence without a reference
                # (previously such sentences were silently dropped).
                result_sentences.append(sentence)
                continue

            # Process the search results.
            for result in search_results:
                distance = result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Reuse the index if this reference has already been recorded.
                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add a new entry to the reference list.
                    reference = {
                        "index": str(reference_index),
                        "id": result["id"],
                        "content": result["content"],
                        "file_path": result.get("file_path", ""),
                        "title": result.get("title", ""),
                        "distance": distance,
                        # "referrence" spelling matches the stored field name.
                        "referrence": result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Append the reference marker ^[index]^ to the sentence.
                if sentence.endswith('<br>'):
                    # Insert the marker before every <br>.
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Append the marker at the end of the sentence.
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload.
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
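
# Illustrative request/response for POST /text/search (shapes inferred from the
# handler above; ids and distances are hypothetical):
#
#   {"text": "Some paragraph worth annotating.", "conversation_id": null}
#
#   -> {"success": true,
#       "data": {"answer": ["Some paragraph worth annotating.^[1]^"],
#                "references": [{"index": "1", "id": "trunk-123",
#                                "content": "...", "file_path": "", "title": "",
#                                "distance": 0.42, "referrence": ""}]}}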


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    """Flag the sentence in request.text that best matches request.sentence."""
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []
        for temp in sentences:
            result_sentences.append(temp)
            # Skip very short sentences.
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp
        # Wrap every sentence and mark the best match.
        for i in range(len(result_sentences)):
            result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
            if result_sentences[i]["sentence"] == best_sentence:
                result_sentences[i]["matched"] = True
        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
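
# Illustrative request/response for POST /text/match (values hypothetical;
# actual matching depends on the embedding model):
#
#   {"sentence": "A query sentence.", "text": "First sentence here. A query sentence."}
#
#   -> {"success": true,
#       "records": [{"sentence": "First sentence here.", "matched": false},
#                   {"sentence": "A query sentence.", "matched": true}]}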


@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    """Vector-search "mr" trunks for content similar to the input text."""
    try:
        # Initialize the service.
        trunks_service = TrunksService()

        # Vectorize the text and search for similar content.
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )

        # Keep only results within the distance threshold.
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            record = {
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            }
            records.append(record)

        # Assemble the response payload.
        response_data = {
            "records": records
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Mr search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
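
# Illustrative request/response for POST /text/mr_search (values hypothetical):
#
#   {"text": "text to look up"}
#
#   -> {"success": true,
#       "data": {"records": [{"content": "...", "file_path": "", "title": "",
#                             "distance": 0.35}]}}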


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    """Pairwise-match sentences between two texts by vector distance."""
    start_time = time.time()
    try:
        # Split both texts into sentences.
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match text splitting took {(end_time - start_time) * 1000:.2f}ms")

        # Initialize the result list.
        origin_results = []

        # Tag each sentence with whether it is long enough to vectorize.
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]

        # Initialize similar_results with matched=False for every sentence.
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]

        # Compute embeddings for all valid sentences up front.
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))
        end_time = time.time()
        logger.info(f"mr_match vectorization took {(end_time - start_time) * 1000:.2f}ms")

        # Match each origin sentence against the still-unmatched similar sentences.
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue
            origin_vector = origin_vectors[origin_sent]
            matched = False
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                # Use an explicit None check: embeddings may be array-like,
                # where plain truthiness is ambiguous.
                if similar_vector is None:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    similar_results[i]["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})

        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match took {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match took {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))
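
# Illustrative request/response for POST /text/mr_match (values hypothetical;
# matched flags depend on the embedding model and the 0.8 distance threshold):
#
#   {"origin": "Sentence one is here. Sentence two is here.",
#    "similar": "Sentence one is here. Entirely unrelated content."}
#
#   -> {"success": true,
#       "data": {"origin": [{"sentence": "Sentence one is here.", "matched": true},
#                           {"sentence": "Sentence two is here.", "matched": false}],
#                "similar": [{"sentence": "Sentence one is here.", "matched": true},
#                            {"sentence": "Entirely unrelated content.", "matched": false}]}}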

text_search_router = router