# text_search.py
  1. from fastapi import APIRouter, HTTPException
  2. from pydantic import BaseModel
  3. from typing import List, Optional
  4. from service.trunks_service import TrunksService
  5. from utils.text_splitter import TextSplitter
  6. from utils.vector_distance import VectorDistance
  7. from model.response import StandardResponse
  8. from utils.vectorizer import Vectorizer
  9. DISTANCE_THRESHOLD = 0.65
  10. import logging
  11. logger = logging.getLogger(__name__)
  12. router = APIRouter(prefix="/text", tags=["Text Search"])
  13. class TextSearchRequest(BaseModel):
  14. text: str
  15. conversation_id: Optional[str] = None
  16. need_convert: Optional[bool] = False
  17. class TextCompareRequest(BaseModel):
  18. sentence: str
  19. text: str
  20. @router.post("/search", response_model=StandardResponse)
  21. async def search_text(request: TextSearchRequest):
  22. try:
  23. # 使用TextSplitter拆分文本
  24. sentences = TextSplitter.split_text(request.text)
  25. if not sentences:
  26. return StandardResponse(success=True, data={"answer": "", "references": []})
  27. # 初始化服务和结果列表
  28. trunks_service = TrunksService()
  29. result_sentences = []
  30. all_references = []
  31. reference_index = 1
  32. # 根据conversation_id获取缓存结果
  33. cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
  34. for sentence in sentences:
  35. # if request.need_convert:
  36. sentence = sentence.replace("\n", "<br>")
  37. if len(sentence) < 10:
  38. result_sentences.append(sentence)
  39. continue
  40. if cached_results:
  41. # 如果有缓存结果,计算向量距离
  42. min_distance = float('inf')
  43. best_result = None
  44. sentence_vector = Vectorizer.get_embedding(sentence)
  45. for cached_result in cached_results:
  46. content_vector = cached_result['embedding']
  47. distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
  48. if distance < min_distance:
  49. min_distance = distance
  50. best_result = {**cached_result, 'distance': distance}
  51. if best_result and best_result['distance'] < DISTANCE_THRESHOLD:
  52. search_results = [best_result]
  53. else:
  54. search_results = []
  55. else:
  56. # 如果没有缓存结果,进行向量搜索
  57. search_results = trunks_service.search_by_vector(
  58. text=sentence,
  59. limit=1
  60. )
  61. # 处理搜索结果
  62. for result in search_results:
  63. distance = result.get("distance", DISTANCE_THRESHOLD)
  64. if distance >= DISTANCE_THRESHOLD:
  65. result_sentences.append(sentence)
  66. continue
  67. # 检查是否已存在相同引用
  68. existing_ref = next((ref for ref in all_references if ref["id"] == result["id"]), None)
  69. current_index = reference_index
  70. if existing_ref:
  71. current_index = int(existing_ref["index"])
  72. else:
  73. # 添加到引用列表
  74. reference = {
  75. "index": str(reference_index),
  76. "id": result["id"],
  77. "content": result["content"],
  78. "file_path": result.get("file_path", ""),
  79. "title": result.get("title", ""),
  80. "distance": distance,
  81. "referrence": result.get("referrence", "")
  82. }
  83. all_references.append(reference)
  84. reference_index += 1
  85. # 添加引用标记
  86. if sentence.endswith('<br>'):
  87. # 如果有多个<br>,在所有<br>前添加^[current_index]^
  88. result_sentence = sentence.replace('<br>', f'^[{current_index}]^<br>')
  89. else:
  90. # 直接在句子末尾添加^[current_index]^
  91. result_sentence = f'{sentence}^[{current_index}]^'
  92. result_sentences.append(result_sentence)
  93. # 组装返回数据
  94. response_data = {
  95. "answer": result_sentences,
  96. "references": all_references
  97. }
  98. return StandardResponse(success=True, data=response_data)
  99. except Exception as e:
  100. logger.error(f"Text search failed: {str(e)}")
  101. raise HTTPException(status_code=500, detail=str(e))
  102. @router.post("/match", response_model=StandardResponse)
  103. async def match_text(request: TextCompareRequest):
  104. try:
  105. sentences = TextSplitter.split_text(request.text)
  106. sentence_vector = Vectorizer.get_embedding(request.sentence)
  107. min_distance = float('inf')
  108. best_sentence = ""
  109. result_sentences = []
  110. for temp in sentences:
  111. result_sentences.append(temp)
  112. if len(temp) < 10:
  113. continue
  114. temp_vector = Vectorizer.get_embedding(temp)
  115. distance = VectorDistance.calculate_distance(sentence_vector, temp_vector)
  116. if distance < min_distance and distance < DISTANCE_THRESHOLD:
  117. min_distance = distance
  118. best_sentence = temp
  119. for i in range(len(result_sentences)):
  120. result_sentences[i] = {"sentence": result_sentences[i], "matched": False}
  121. if result_sentences[i]["sentence"] == best_sentence:
  122. result_sentences[i]["matched"] = True
  123. return StandardResponse(success=True, records=result_sentences)
  124. except Exception as e:
  125. logger.error(f"Text comparison failed: {str(e)}")
  126. raise HTTPException(status_code=500, detail=str(e))
  127. text_search_router = router