text_search.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from model.response import StandardResponse
import logging

# Module-level logger, named after this module per standard convention.
logger = logging.getLogger(__name__)
# All endpoints defined here are mounted under the /text prefix.
router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request body for POST /text/search."""

    # Raw input text; split into sentences and vector-searched one by one.
    text: str
  12. @router.post("/search", response_model=StandardResponse)
  13. async def search_text(request: TextSearchRequest):
  14. try:
  15. # 使用TextSplitter拆分文本
  16. sentences = TextSplitter.split_text(request.text)
  17. if not sentences:
  18. return StandardResponse(success=True, data={"answer": "", "references": []})
  19. # 对每个句子进行向量搜索
  20. trunks_service = TrunksService()
  21. result_sentences = []
  22. all_references = []
  23. reference_index = 1
  24. for sentence in sentences:
  25. search_results = trunks_service.search_by_vector(
  26. text=sentence,
  27. limit=1
  28. )
  29. # 处理搜索结果
  30. for result in search_results:
  31. # 添加引用标记
  32. result_sentence = sentence + f"^[{reference_index}]^"
  33. result_sentences.append(result_sentence)
  34. # 添加到引用列表
  35. reference = {
  36. "index": str(reference_index),
  37. "content": result["content"],
  38. "file_path": result.get("file_path", ""),
  39. "title": result.get("title", ""),
  40. "distance": result.get("distance", "")
  41. }
  42. all_references.append(reference)
  43. reference_index += 1
  44. # 组装返回数据
  45. response_data = {
  46. "answer": "\n".join(result_sentences),
  47. "references": all_references
  48. }
  49. return StandardResponse(success=True, data=response_data)
  50. except Exception as e:
  51. logger.error(f"Text search failed: {str(e)}")
  52. raise HTTPException(status_code=500, detail=str(e))
  53. text_search_router = router