text_search.py
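"""Text search and sentence comparison endpoints.

Input text is split into sentences with TextSplitter, embedded with
Vectorizer, and matched against stored trunks by vector distance; matches
at or above DISTANCE_THRESHOLD are treated as misses.
"""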

import logging
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, validator

from model.response import StandardResponse
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer

logger = logging.getLogger(__name__)

# Matches with a vector distance at or above this value are discarded.
DISTANCE_THRESHOLD = 0.8

router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    text: str
    conversation_id: Optional[str] = None
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    sentence: str
    text: str


class TextMatchRequest(BaseModel):
    text: str = Field(..., min_length=1, max_length=10000, description="Text content to search")

    @validator('text')
    def validate_text(cls, v):
        # Keep every printable character plus newlines and carriage returns
        # (Chinese text is printable and therefore preserved).
        v = ''.join(char for char in v if char.isprintable() or char in '\n\r')
        # Escape JSON special characters. Backslashes go first, otherwise
        # the escapes added below would themselves be escaped again.
        v = v.replace('\\', '\\\\')
        # Quotes and forward slashes.
        v = v.replace('"', '\\"')
        v = v.replace('/', '\\/')
        # Control characters.
        v = v.replace('\n', '\\n')
        v = v.replace('\r', '\\r')
        v = v.replace('\t', '\\t')
        v = v.replace('\b', '\\b')
        v = v.replace('\f', '\\f')
        # Unicode escape sequences (\uXXXX) are left untouched.
        return v
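
# A minimal sanity check of validate_text (hypothetical input; run manually):
#   TextMatchRequest(text='He said "hi"\npath/to/file').text
#   == 'He said \\"hi\\"\\npath\\/to\\/file'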


class TextCompareMultiRequest(BaseModel):
    origin: str
    similar: str


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # Split the input text into sentences.
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": [], "references": []})

        # Initialize the service and result containers.
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Load cached results for this conversation, if an id was given.
        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []

        for sentence in sentences:
            # need_convert is currently ignored; newlines are always converted.
            sentence = sentence.replace("\n", "<br>")
            if len(sentence) < 10:
                result_sentences.append(sentence)
                continue

            if cached_results:
                # With cached results, pick the closest cached entry by vector distance.
                min_distance = float('inf')
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result['embedding']
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, 'distance': distance}
                if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # Without cached results, fall back to a vector search.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1
                )

            if not search_results:
                # No usable match: keep the sentence without a reference marker.
                result_sentences.append(sentence)
                continue

            for result in search_results:
                distance = result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Reuse the existing index if this result is already referenced.
                existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
                current_index = reference_index
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    # Add a new entry to the reference list.
                    reference = {
                        "index": str(reference_index),
                        "id": result["id"],
                        "content": result["content"],
                        "file_path": result.get("file_path", ""),
                        "title": result.get("title", ""),
                        "distance": distance,
                        # Key spelling follows the field name used by the trunks data.
                        "referrence": result.get("referrence", "")
                    }
                    all_references.append(reference)
                    reference_index += 1

                # Attach the reference marker to the sentence.
                if sentence.endswith('<br>'):
                    # Insert ^[index]^ before every <br>.
                    result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
                else:
                    # Append ^[index]^ at the end of the sentence.
                    result_sentence = f'{sentence}^[{current_index}]^'
                result_sentences.append(result_sentence)

        # Assemble the response payload.
        response_data = {
            "answer": result_sentences,
            "references": all_references
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
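
# Illustrative exchange for POST /text/search (values hypothetical; assumes
# StandardResponse serializes to {"success": ..., "data": ...}):
#   request:  {"text": "Some long sentence about a topic.", "conversation_id": "abc-123"}
#   response: {"success": true,
#              "data": {"answer": ["Some long sentence about a topic.^[1]^"],
#                       "references": [{"index": "1", "id": "...", "distance": 0.35, ...}]}}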


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)
        min_distance = float('inf')
        best_sentence = ""
        result_sentences = []
        for temp in sentences:
            result_sentences.append(temp)
            # Skip short fragments; they produce noisy matches.
            if len(temp) < 10:
                continue
            temp_vector = Vectorizer.get_embedding(temp)
            distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = temp
        # Mark the closest sentence (if any) as matched.
        result_sentences = [
            {"sentence": s, "matched": bool(best_sentence) and s == best_sentence}
            for s in result_sentences
        ]
        return StandardResponse(success=True, data={"records": result_sentences})
    except Exception as e:
        logger.error(f"Text match failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/mr_search", response_model=StandardResponse)
async def mr_search_text_content(request: TextMatchRequest):
    try:
        # Initialize the service.
        trunks_service = TrunksService()
        # Embed the text and search for similar "mr" trunks.
        search_results = trunks_service.search_by_vector(
            text=request.text,
            limit=10,
            type="mr"
        )
        # Keep only results within the distance threshold.
        records = []
        for result in search_results:
            distance = result.get("distance", DISTANCE_THRESHOLD)
            if distance >= DISTANCE_THRESHOLD:
                continue
            records.append({
                "content": result["content"],
                "file_path": result.get("file_path", ""),
                "title": result.get("title", ""),
                "distance": distance,
            })
        # Assemble the response payload.
        return StandardResponse(success=True, data={"records": records})
    except Exception as e:
        logger.error(f"MR search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
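
# Illustrative exchange for POST /text/mr_search (values hypothetical):
#   request:  {"text": "query text to match against stored mr trunks"}
#   response: {"success": true,
#              "data": {"records": [{"content": "...", "title": "...", "distance": 0.42, ...}]}}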


def _match_sentences(sentences, source_vectors, target_vectors):
    """Mark each sentence as matched when any target vector lies within DISTANCE_THRESHOLD."""
    results = []
    for sent in sentences:
        vector = source_vectors.get(sent)
        if vector is None:
            # Short fragments are never embedded, so they never match.
            results.append({"sentence": sent, "matched": False})
            continue
        matched = any(
            VectorDistance.calculate_distance(vector, target) < DISTANCE_THRESHOLD
            for target in target_vectors
        )
        results.append({"sentence": sent, "matched": matched})
    return results


@router.post("/mr_match", response_model=StandardResponse)
async def compare_text(request: TextCompareMultiRequest):
    try:
        # Split both texts into sentences.
        origin_sentences = TextSplitter.split_text(request.origin)
        similar_sentences = TextSplitter.split_text(request.similar)

        # Embed every comparable sentence (>= 10 chars) exactly once up front,
        # rather than once per pairwise comparison.
        origin_vectors = {s: Vectorizer.get_embedding(s) for s in origin_sentences if len(s) >= 10}
        similar_vectors = {s: Vectorizer.get_embedding(s) for s in similar_sentences if len(s) >= 10}

        # Compare each side against the other side's vectors.
        response_data = {
            "origin": _match_sentences(origin_sentences, origin_vectors, list(similar_vectors.values())),
            "similar": _match_sentences(similar_sentences, similar_vectors, list(origin_vectors.values())),
        }
        return StandardResponse(success=True, data=response_data)
    except Exception as e:
        logger.error(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


text_search_router = router
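
# Wiring sketch, assuming this module lives at router/text_search.py and the
# application entry point builds the FastAPI app (module path is an assumption):
#
#   from fastapi import FastAPI
#   from router.text_search import text_search_router
#
#   app = FastAPI()
#   app.include_router(text_search_router)  # mounts /text/search, /text/match, ...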