import logging
from typing import List

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from model.response import StandardResponse
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    text: str
    conversation_id: str


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    try:
        # Split the input text into sentences with TextSplitter
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        # Initialize the service and result containers
        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Fetch cached results for this conversation_id, if any
        cached_results = (
            trunks_service.get_cached_result(request.conversation_id)
            if request.conversation_id
            else []
        )

        for sentence in sentences:
            if cached_results:
                # Cached results exist: pick the closest one by vector distance
                min_distance = float("inf")
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)

                for cached_result in cached_results:
                    content_vector = Vectorizer.get_embedding(cached_result["content"])
                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, "distance": distance}

                if best_result and best_result["distance"] < 1:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cached results: fall back to a vector search
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                )

            # Process the search results
            for result in search_results:
                # Skip results whose distance is 1 or greater
                distance = result.get("distance", 1.0)
                if distance >= 1:
                    continue

                # Append the reference marker to the sentence
                result_sentence = sentence + f"^[{reference_index}]^"
                result_sentences.append(result_sentence)

                # Record the reference entry
                reference = {
                    "index": str(reference_index),
                    "content": result["content"],
                    "file_path": result.get("file_path", ""),
                    "title": result.get("title", ""),
                    "distance": distance,
                }
                all_references.append(reference)
                reference_index += 1

        # Assemble the response payload
        response_data = {
            "answer": "\n".join(result_sentences),
            "references": all_references,
        }

        return StandardResponse(success=True, data=response_data)

    except Exception as e:
        logger.error(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


text_search_router = router
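

# Minimal usage sketch: mounts text_search_router on a FastAPI app and calls the
# /text/search endpoint with FastAPI's TestClient. It assumes the backing modules
# above (TrunksService, Vectorizer, TextSplitter, StandardResponse) are importable
# and configured in your environment; illustrative only, not part of the router.
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()
    app.include_router(text_search_router)

    client = TestClient(app)
    resp = client.post(
        "/text/search",
        json={"text": "Example sentence to match.", "conversation_id": "demo-1"},
    )
    # Prints the StandardResponse payload, e.g. the joined "answer" and its "references"
    print(resp.json())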