text_search.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from fastapi import APIRouter, HTTPException
  2. from pydantic import BaseModel
  3. from typing import List
  4. from service.trunks_service import TrunksService
  5. from utils.text_splitter import TextSplitter
  6. from model.response import StandardResponse
  7. import logging
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the endpoints in this module; routes are registered under the
# "/text" prefix and tagged "Text Search".
router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request model accepted by the ``/search`` endpoint."""

    # Raw text that the endpoint splits into sentences and searches.
    text: str
  12. @router.post("/search", response_model=StandardResponse)
  13. async def search_text(request: TextSearchRequest):
  14. try:
  15. # 使用TextSplitter拆分文本
  16. sentences = TextSplitter.split_text(request.text)
  17. if not sentences:
  18. return StandardResponse(success=True, data={"answer": "", "references": []})
  19. # 对每个句子进行向量搜索
  20. trunks_service = TrunksService()
  21. result_sentences = []
  22. all_references = []
  23. reference_index = 1
  24. for sentence in sentences:
  25. search_results = trunks_service.search_by_vector(
  26. text=sentence,
  27. limit=1
  28. )
  29. # 处理搜索结果
  30. for result in search_results:
  31. # 获取distance值,如果大于等于1则跳过
  32. distance = result.get("distance", 1.0)
  33. if distance >= 1:
  34. continue
  35. # 添加引用标记
  36. result_sentence = sentence + f"^[{reference_index}]^"
  37. result_sentences.append(result_sentence)
  38. # 添加到引用列表
  39. reference = {
  40. "index": str(reference_index),
  41. "content": result["content"],
  42. "file_path": result.get("file_path", ""),
  43. "title": result.get("title", ""),
  44. "distance": distance
  45. }
  46. all_references.append(reference)
  47. reference_index += 1
  48. # 组装返回数据
  49. response_data = {
  50. "answer": "\n".join(result_sentences),
  51. "references": all_references
  52. }
  53. return StandardResponse(success=True, data=response_data)
  54. except Exception as e:
  55. logger.error(f"Text search failed: {str(e)}")
  56. raise HTTPException(status_code=500, detail=str(e))
# Second name bound to the same router object defined in this module.
# NOTE(review): presumably the name imported by the application setup — confirm.
text_search_router = router