# text_search.py
  1. from fastapi import APIRouter, HTTPException
  2. from pydantic import BaseModel, Field, validator
  3. from typing import List, Optional
  4. from service.trunks_service import TrunksService
  5. from utils.text_splitter import TextSplitter
  6. from utils.vector_distance import VectorDistance
  7. from model.response import StandardResponse
  8. from utils.vectorizer import Vectorizer
  9. DISTANCE_THRESHOLD = 0.8
  10. import logging
  11. import time
  12. logger = logging.getLogger(__name__)
  13. router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request body for POST /text/search."""
    # Raw text that will be split into sentences and searched.
    text: str
    # Optional conversation id used to look up cached search results.
    conversation_id: Optional[str] = None
    # NOTE(review): the check on this flag in search_text is commented out,
    # so the newline -> <br> conversion runs unconditionally — confirm intent.
    need_convert: Optional[bool] = False
class TextCompareRequest(BaseModel):
    """Request body for POST /text/match."""
    # The single sentence to locate within `text`.
    sentence: str
    # The larger text whose sentences are compared against `sentence`.
    text: str
  21. class TextMatchRequest(BaseModel):
  22. text: str = Field(..., min_length=1, max_length=10000, description="需要搜索的文本内容")
  23. @validator('text')
  24. def validate_text(cls, v):
  25. # 保留所有可打印字符、换行符和中文字符
  26. v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
  27. # 转义JSON特殊字符
  28. # 先处理反斜杠,避免后续转义时出现问题
  29. v = v.replace('\\', '\\\\')
  30. # 处理引号和其他特殊字符
  31. v = v.replace('"', '\\"')
  32. v = v.replace('/', '\\/')
  33. # 处理控制字符
  34. v = v.replace('\n', '\\n')
  35. v = v.replace('\r', '\\r')
  36. v = v.replace('\t', '\\t')
  37. v = v.replace('\b', '\\b')
  38. v = v.replace('\f', '\\f')
  39. # 处理Unicode转义
  40. # v = v.replace('\u', '\\u')
  41. return v
class TextCompareMultiRequest(BaseModel):
    """Request body for POST /text/mr_match."""
    # Original text; its sentences are matched against `similar`.
    origin: str
    # Candidate text whose sentences may be claimed by `origin` sentences.
    similar: str
@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Split request.text into sentences, look each one up in the trunk
    vector index (or the conversation's cached results when available),
    and return the sentences annotated with ^[n]^ citation markers plus
    the deduplicated list of references.

    Raises:
        HTTPException: 500 on any internal failure.
    """
    try:
        # Split the input text into sentences with TextSplitter.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})
        # Initialize the service and the result accumulators.
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1
        # Fetch cached results for this conversation, if an id was supplied.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
        for sentence in sentences:
            # if request.need_convert:
            # NOTE(review): the need_convert check is commented out above, so
            # the newline -> <br> conversion always runs — confirm intent.
            sentence = sentence.replace("\n", "<br>")
            # Very short sentences are passed through without a lookup.
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue
            if cached_results:
                # Cached results exist: pick the nearest cached entry by
                # vector distance instead of hitting the index.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    # NOTE(review): with no close-enough cached entry the loop
                    # below never runs, so the sentence is dropped from the
                    # answer entirely — confirm it should not be passed through
                    # unannotated instead.
                    search_results = []
            else:
                # No cache: query the vector index directly, best match only.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                    type='trunk'
                )
            # Process the (at most one) search hit for this sentence.
            for result in search_results:
                distance = result.get("distance", DISTANCE_THRESHOLD)
                # At or beyond the threshold the hit counts as a miss:
                # keep the sentence without a citation marker.
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue
                # Reuse the index of an already-recorded reference with the same id.
                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # First citation of this source: append it to the reference list.
                    reference = {
                        "index": str(reference_index),
                        "id": result["id"],
                        "content": result["content"],
                        "file_path": result.get("file_path", ""),
                        "title": result.get("title", ""),
                        "distance": distance,
                        # NOTE(review): "referrence" spelling is part of the
                        # payload contract — do not "fix" without the consumer.
                        "referrence": result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1
                # Insert the ^[n]^ citation marker.
                if sentence.endswith('<br>'):
                    # Sentence ends with <br>: place the marker before every <br>.
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Otherwise append the marker at the end of the sentence.
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)
        # Assemble the response payload.
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
  128. @router.post("/match", response_model=StandardResponse)
  129. async def match_text(request: TextCompareRequest):
  130. try:
  131. sentences = TextSplitter.split_text(request.text)
  132. sentence_vector = Vectorizer.get_embedding(request.sentence)
  133. min_distance = float('inf')
  134. best_sentence = ""
  135. result_sentences = []
  136. for temp in sentences:
  137. result_sentences.append(temp)
  138. if len(temp) < 10:
  139. continue
  140. temp_vector = Vectorizer.get_embedding(temp)
  141. distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
  142. if distance < min_distance and distance < DISTANCE_THRESHOLD:
  143. min_distance = distance
  144. best_sentence = temp
  145. for i in range(len(result_sentences)):
  146. result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
  147. if result_sentences[i]["sentence"] == best_sentence:
  148. result_sentences[i]["matched"] = True
  149. return StandardResponse(success=True, records=result_sentences)
  150. except Exception as e:
  151. logger.error(f"Text comparison failed: {str(e)}")
  152. raise HTTPException(status_code=500, detail=str(e))
  153. @router.post("/mr_search", response_model=StandardResponse)
  154. async def mr_search_text_content(request: TextMatchRequest):
  155. try:
  156. # 初始化服务
  157. trunks_service = TrunksService()
  158. # 获取文本向量并搜索相似内容
  159. search_results = trunks_service.search_by_vector(
  160. text=request.text,
  161. limit=10,
  162. type="mr"
  163. )
  164. # 处理搜索结果
  165. records = []
  166. for result in search_results:
  167. distance = result.get("distance", DISTANCE_THRESHOLD)
  168. if distance >= DISTANCE_THRESHOLD:
  169. continue
  170. # 添加到引用列表
  171. record = {
  172. "content": result["content"],
  173. "file_path": result.get("file_path", ""),
  174. "title": result.get("title", ""),
  175. "distance": distance,
  176. }
  177. records.append(record)
  178. # 组装返回数据
  179. response_data = {
  180. "records": records
  181. }
  182. return StandardResponse(success=True, data=response_data)
  183. except Exception as e:
  184. logger.error(f"Mr search failed: {str(e)}")
  185. raise HTTPException(status_code=500, detail=str(e))
@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    """Greedily match sentences of request.origin against sentences of
    request.similar by embedding distance, marking each sentence on both
    sides as matched/unmatched. Each similar sentence is claimed by at
    most one origin sentence (first come, first served in origin order).

    Raises:
        HTTPException: 500 on any internal failure.
    """
    start_time = time.time()
    try:
        # Split both texts into sentences.
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)
        end_time = time.time()
        logger.info(f"mr_match接口处理文本耗时: {(end_time - start_time) * 1000:.2f}ms")
        # Result accumulator for the origin side.
        origin_results = []
        # Tag each sentence with whether it is long enough (>= 10 chars) to embed.
        valid_origin_sentences = [(sent, len(sent) >= 10) for sent in origin_sentences]
        valid_similar_sentences = [(sent, len(sent) >= 10) for sent in similar_sentences]
        # Initialize similar-side results with matched=False.
        similar_results = [{"sentence": sent, "matched": False} for sent, _ in valid_similar_sentences]
        # Pre-compute embeddings for every long-enough sentence, keyed by
        # sentence text (duplicate sentences share one entry).
        origin_vectors = {}
        similar_vectors = {}
        origin_batch = [sent for sent, is_valid in valid_origin_sentences if is_valid]
        similar_batch = [sent for sent, is_valid in valid_similar_sentences if is_valid]
        if origin_batch:
            origin_embeddings = [Vectorizer.get_embedding(sent) for sent in origin_batch]
            origin_vectors = dict(zip(origin_batch, origin_embeddings))
        if similar_batch:
            similar_embeddings = [Vectorizer.get_embedding(sent) for sent in similar_batch]
            similar_vectors = dict(zip(similar_batch, similar_embeddings))
        end_time = time.time()
        logger.info(f"mr_match接口处理向量耗时: {(end_time - start_time) * 1000:.2f}ms")
        # Walk origin sentences, claiming the first close-enough,
        # still-unmatched similar sentence for each.
        for origin_sent, is_valid in valid_origin_sentences:
            if not is_valid:
                # Too short to embed: reported as unmatched.
                origin_results.append({"sentence": origin_sent, "matched": False})
                continue
            origin_vector = origin_vectors[origin_sent]
            matched = False
            for i, similar_result in enumerate(similar_results):
                if similar_result["matched"]:
                    continue
                similar_sent = similar_result["sentence"]
                if len(similar_sent) < 10:
                    continue
                similar_vector = similar_vectors.get(similar_sent)
                # NOTE(review): truthiness test assumes embeddings are plain
                # sequences; a numpy array here would raise on bool() —
                # confirm Vectorizer.get_embedding's return type.
                if not similar_vector:
                    continue
                distance = VectorDistance.calculate_distance(origin_vector, similar_vector)
                if distance < DISTANCE_THRESHOLD:
                    matched = True
                    # Consume this similar sentence so later origin
                    # sentences cannot claim it again.
                    similar_results[i]["matched"] = True
                    break
            origin_results.append({"sentence": origin_sent, "matched": matched})
        response_data = {
            "origin": origin_results,
            "similar": similar_results
        }
        end_time = time.time()
        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        end_time = time.time()
        logger.error(f"Text comparison failed: {str(e)}")
        logger.info(f"mr_match接口耗时: {(end_time - start_time) * 1000:.2f}ms")
        raise HTTPException(status_code=500, detail=str(e))
# Public alias for app wiring; presumably imported elsewhere as
# `text_search_router` — keep both names exported.
text_search_router = router