# text_search.py — FastAPI router for sentence-level text search with cached-result matching.
  1. from fastapi import APIRouter, HTTPException
  2. from pydantic import BaseModel
  3. from typing import List
  4. from service.trunks_service import TrunksService
  5. from utils.text_splitter import TextSplitter
  6. from utils.vector_distance import VectorDistance
  7. from model.response import StandardResponse
  8. from utils.vectorizer import Vectorizer
  9. import logging
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
# Router for this feature: every endpoint below is mounted under /text and
# grouped under the "Text Search" tag in the generated API docs.
router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request payload for POST /text/search."""
    # Raw input text; it is split into sentences before searching.
    text: str
    # Identifier used to look up previously cached search results.
    # NOTE(review): treated as optional downstream (falsy check) even though
    # the field is required here — confirm whether "" is a valid value.
    conversation_id: str
  15. @router.post("/search", response_model=StandardResponse)
  16. async def search_text(request: TextSearchRequest):
  17. try:
  18. # 使用TextSplitter拆分文本
  19. sentences = TextSplitter.split_text(request.text)
  20. if not sentences:
  21. return StandardResponse(success=True, data={"answer": "", "references": []})
  22. # 初始化服务和结果列表
  23. trunks_service = TrunksService()
  24. result_sentences = []
  25. all_references = []
  26. reference_index = 1
  27. # 根据conversation_id获取缓存结果
  28. cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
  29. for sentence in sentences:
  30. if cached_results:
  31. # 如果有缓存结果,计算向量距离
  32. min_distance = float('inf')
  33. best_result = None
  34. sentence_vector = Vectorizer.get_embedding(sentence)
  35. for cached_result in cached_results:
  36. content_vector = Vectorizer.get_embedding(cached_result['content'])
  37. distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
  38. if distance < min_distance:
  39. min_distance = distance
  40. best_result = {**cached_result, 'distance': distance}
  41. if best_result and best_result['distance'] < 1:
  42. search_results = [best_result]
  43. else:
  44. search_results = []
  45. else:
  46. # 如果没有缓存结果,进行向量搜索
  47. search_results = trunks_service.search_by_vector(
  48. text=sentence,
  49. limit=1
  50. )
  51. # 处理搜索结果
  52. for result in search_results:
  53. # 获取distance值,如果大于等于1则跳过
  54. distance = result.get("distance", 1.0)
  55. if distance >= 1:
  56. continue
  57. # 添加引用标记
  58. result_sentence = sentence + f"^[{reference_index}]^"
  59. result_sentences.append(result_sentence)
  60. # 添加到引用列表
  61. reference = {
  62. "index": str(reference_index),
  63. "content": result["content"],
  64. "file_path": result.get("file_path", ""),
  65. "title": result.get("title", ""),
  66. "distance": distance
  67. }
  68. all_references.append(reference)
  69. reference_index += 1
  70. # 组装返回数据
  71. response_data = {
  72. "answer": "\n".join(result_sentences),
  73. "references": all_references
  74. }
  75. return StandardResponse(success=True, data=response_data)
  76. except Exception as e:
  77. logger.error(f"Text search failed: {str(e)}")
  78. raise HTTPException(status_code=500, detail=str(e))
# Alias for the router — presumably the name imported by the application
# setup when mounting this router; confirm against the caller.
text_search_router = router