# text_search.py -- FastAPI router exposing POST /text/search.

  1. from fastapi import APIRouter, HTTPException
  2. from pydantic import BaseModel
  3. from typing import List
  4. from service.trunks_service import TrunksService
  5. from utils.text_splitter import TextSplitter
  6. from utils.vector_distance import VectorDistance
  7. from model.response import StandardResponse
  8. import logging
  9. logger = logging.getLogger(__name__)
  10. router = APIRouter(prefix="/text", tags=["Text Search"])
  11. class TextSearchRequest(BaseModel):
  12. text: str
  13. conversation_id: str
  14. @router.post("/search", response_model=StandardResponse)
  15. async def search_text(request: TextSearchRequest):
  16. try:
  17. # 使用TextSplitter拆分文本
  18. sentences = TextSplitter.split_text(request.text)
  19. if not sentences:
  20. return StandardResponse(success=True, data={"answer": "", "references": []})
  21. # 初始化服务和结果列表
  22. trunks_service = TrunksService()
  23. result_sentences = []
  24. all_references = []
  25. reference_index = 1
  26. # 根据conversation_id获取缓存结果
  27. cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
  28. for sentence in sentences:
  29. if cached_results:
  30. # 如果有缓存结果,计算向量距离
  31. min_distance = float('inf')
  32. best_result = None
  33. sentence_vector = trunks_service.get_vector(sentence)
  34. for cached_result in cached_results:
  35. content_vector = trunks_service.get_vector(cached_result['content']);
  36. distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
  37. if distance < min_distance:
  38. min_distance = distance
  39. best_result = {**cached_result, 'distance': distance}
  40. if best_result and best_result['distance'] < 1:
  41. search_results = [best_result]
  42. else:
  43. search_results = []
  44. else:
  45. # 如果没有缓存结果,进行向量搜索
  46. search_results = trunks_service.search_by_vector(
  47. text=sentence,
  48. limit=1
  49. )
  50. # 处理搜索结果
  51. for result in search_results:
  52. # 获取distance值,如果大于等于1则跳过
  53. distance = result.get("distance", 1.0)
  54. if distance >= 1:
  55. continue
  56. # 添加引用标记
  57. result_sentence = sentence + f"^[{reference_index}]^"
  58. result_sentences.append(result_sentence)
  59. # 添加到引用列表
  60. reference = {
  61. "index": str(reference_index),
  62. "content": result["content"],
  63. "file_path": result.get("file_path", ""),
  64. "title": result.get("title", ""),
  65. "distance": distance
  66. }
  67. all_references.append(reference)
  68. reference_index += 1
  69. # 组装返回数据
  70. response_data = {
  71. "answer": "\n".join(result_sentences),
  72. "references": all_references
  73. }
  74. return StandardResponse(success=True, data=response_data)
  75. except Exception as e:
  76. logger.error(f"Text search failed: {str(e)}")
  77. raise HTTPException(status_code=500, detail=str(e))
  78. text_search_router = router