# text_search.py — FastAPI router exposing the POST /text/search endpoint.
  1. from fastapi import APIRouter, HTTPException
  2. from pydantic import BaseModel
  3. from typing import List
  4. from service.trunks_service import TrunksService
  5. from utils.text_splitter import TextSplitter
  6. from model.response import StandardResponse
  7. import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the text-search endpoints: every route registered on it is served
# under the "/text" prefix and carries the "Text Search" tag.
router = APIRouter(prefix="/text", tags=["Text Search"])
class TextSearchRequest(BaseModel):
    """Request body for the POST /text/search endpoint."""

    # Raw input text; the handler splits it with TextSplitter.split_text
    # before searching.
    text: str
    # Passed as `limit` to the vector search issued for each split piece
    # (presumably the maximum number of hits per piece — confirm against
    # TrunksService.search_by_vector). Defaults to 1.
    limit: int = 1
  13. @router.post("/search", response_model=StandardResponse)
  14. async def search_text(request: TextSearchRequest):
  15. try:
  16. # 使用TextSplitter拆分文本
  17. sentences = TextSplitter.split_text(request.text)
  18. if not sentences:
  19. return StandardResponse(success=True, data={"answer": "", "references": []})
  20. # 对每个句子进行向量搜索
  21. trunks_service = TrunksService()
  22. result_sentences = []
  23. all_references = []
  24. reference_index = 1
  25. for sentence in sentences:
  26. search_results = trunks_service.search_by_vector(
  27. text=sentence,
  28. limit=request.limit
  29. )
  30. # 处理搜索结果
  31. for result in search_results:
  32. # 添加引用标记
  33. result_sentence = sentence + f"^[{reference_index}]^"
  34. result_sentences.append(result_sentence)
  35. # 添加到引用列表
  36. reference = {
  37. "index": str(reference_index),
  38. "content": result["content"],
  39. "file_path": result.get("file_path", ""),
  40. "title": result.get("title", ""),
  41. "distance": result.get("distance", "")
  42. }
  43. all_references.append(reference)
  44. reference_index += 1
  45. # 组装返回数据
  46. response_data = {
  47. "answer": "\n".join(result_sentences),
  48. "references": all_references
  49. }
  50. return StandardResponse(success=True, data=response_data)
  51. except Exception as e:
  52. logger.error(f"Text search failed: {str(e)}")
  53. raise HTTPException(status_code=500, detail=str(e))
# Expose the router under a module-specific name
# (presumably so the application can import it without clashing with other
# modules' `router` objects — confirm against the app setup).
text_search_router = router