"""FastAPI router exposing sentence-level text search and matching endpoints.

POST /text/search splits the incoming text into sentences, looks each
sentence up against a vector store (``TrunksService``), and annotates matched
sentences with numbered reference markers of the form ``^[n]^``.
POST /text/match finds the sentence in a text closest to a query sentence.

NOTE(review): this module was reconstructed from a whitespace-mangled source
in which ``\\n`` escape sequences inside string literals had been rendered as
real line breaks; the reconstructed literals are marked below — confirm
against version control.
"""

import logging
from typing import List, Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from model.response import StandardResponse
from service.trunks_service import TrunksService
from utils.text_splitter import TextSplitter
from utils.vector_distance import VectorDistance
from utils.vectorizer import Vectorizer

# Embedding distance at or above this value is treated as "no match".
DISTANCE_THRESHOLD = 0.65

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/text", tags=["Text Search"])


class TextSearchRequest(BaseModel):
    """Request body for POST /text/search."""

    text: str
    # Optional conversation id used to look up previously cached results.
    conversation_id: Optional[str] = None
    # NOTE(review): declared but currently unused by the endpoint — the
    # newline conversion below runs unconditionally (the guard is commented
    # out). Confirm the intended semantics before wiring it up.
    need_convert: Optional[bool] = False


class TextCompareRequest(BaseModel):
    """Request body for POST /text/match."""

    sentence: str
    text: str


@router.post("/search", response_model=StandardResponse)
async def search_text(request: TextSearchRequest):
    """Search each sentence of ``request.text`` against the trunk store.

    Sentences shorter than 10 characters are passed through untouched.
    For each matched sentence a ``^[n]^`` marker is appended (or inserted
    before every newline when the sentence ends with one), and the matching
    trunk is collected into ``references``, de-duplicated by trunk id so a
    trunk cited twice keeps its first index.

    When ``conversation_id`` is supplied and has cached results, those are
    re-ranked locally by embedding distance instead of issuing a new vector
    search.

    Returns:
        StandardResponse with ``data = {"answer": [...], "references": [...]}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        sentences = TextSplitter.split_text(request.text)
        if not sentences:
            return StandardResponse(success=True, data={"answer": "", "references": []})

        trunks_service = TrunksService()
        result_sentences = []
        all_references = []
        reference_index = 1

        # Previously cached search results for this conversation, if any.
        cached_results = (
            trunks_service.get_cached_result(request.conversation_id)
            if request.conversation_id
            else []
        )

        for sentence in sentences:
            # if request.need_convert:
            # NOTE(review): reconstructed from the mangled source — converts
            # literal "\n" two-character sequences into real newlines, and
            # runs unconditionally despite the ``need_convert`` flag; confirm.
            sentence = sentence.replace("\\n", "\n")
            if len(sentence) < 10:
                # Too short to search meaningfully; keep verbatim.
                result_sentences.append(sentence)
                continue

            if cached_results:
                # Re-rank cached trunks by distance to this sentence and keep
                # only the closest one if it clears the threshold.
                min_distance = float("inf")
                best_result = None
                sentence_vector = Vectorizer.get_embedding(sentence)
                for cached_result in cached_results:
                    content_vector = cached_result["embedding"]
                    distance = VectorDistance.calculate_distance(
                        sentence_vector, content_vector
                    )
                    if distance < min_distance:
                        min_distance = distance
                        best_result = {**cached_result, "distance": distance}
                if best_result and best_result["distance"] < DISTANCE_THRESHOLD:
                    search_results = [best_result]
                else:
                    search_results = []
            else:
                # No cache: fall back to a direct top-1 vector search.
                search_results = trunks_service.search_by_vector(
                    text=sentence,
                    limit=1,
                )

            # Process search results (at most one per sentence).
            for result in search_results:
                # A missing distance defaults to the threshold, i.e. "no match".
                distance = result.get("distance", DISTANCE_THRESHOLD)
                if distance >= DISTANCE_THRESHOLD:
                    result_sentences.append(sentence)
                    continue

                # Reuse the index of an already-collected reference to the
                # same trunk; otherwise register a new one.
                existing_ref = next(
                    (ref for ref in all_references if ref["id"] == result["id"]),
                    None,
                )
                if existing_ref:
                    current_index = int(existing_ref["index"])
                else:
                    current_index = reference_index
                    all_references.append(
                        {
                            "index": str(reference_index),
                            "id": result["id"],
                            "content": result["content"],
                            "file_path": result.get("file_path", ""),
                            "title": result.get("title", ""),
                            "distance": distance,
                            # NOTE: "referrence" (sic) is the upstream field
                            # name — do not "fix" the spelling.
                            "referrence": result.get("referrence", ""),
                        }
                    )
                    reference_index += 1

                # Attach the reference marker: before every newline when the
                # sentence ends with one, otherwise appended at the end.
                if sentence.endswith("\n"):
                    result_sentence = sentence.replace("\n", f"^[{current_index}]^\n")
                else:
                    result_sentence = f"{sentence}^[{current_index}]^"
                result_sentences.append(result_sentence)

        return StandardResponse(
            success=True,
            data={"answer": result_sentences, "references": all_references},
        )
    except Exception as e:
        # logger.exception keeps the message but also records the traceback.
        logger.exception(f"Text search failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/match", response_model=StandardResponse)
async def match_text(request: TextCompareRequest):
    """Find the sentence in ``request.text`` closest to ``request.sentence``.

    Splits ``request.text`` into sentences, embeds each one at least 10
    characters long, and marks the single closest sentence whose distance is
    below DISTANCE_THRESHOLD.

    Returns:
        StandardResponse with ``records`` = list of
        ``{"sentence": str, "matched": bool}`` in original order.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        sentences = TextSplitter.split_text(request.text)
        sentence_vector = Vectorizer.get_embedding(request.sentence)

        min_distance = float("inf")
        best_sentence = ""
        for candidate in sentences:
            # Sentences shorter than 10 chars are listed but never matched.
            if len(candidate) < 10:
                continue
            candidate_vector = Vectorizer.get_embedding(candidate)
            distance = VectorDistance.calculate_distance(
                sentence_vector, candidate_vector
            )
            if distance < min_distance and distance < DISTANCE_THRESHOLD:
                min_distance = distance
                best_sentence = candidate

        # Guard on a non-empty best_sentence so that, when nothing cleared the
        # threshold, an empty split sentence is not falsely flagged as matched.
        result_sentences = [
            {"sentence": s, "matched": bool(best_sentence) and s == best_sentence}
            for s in sentences
        ]
        return StandardResponse(success=True, records=result_sentences)
    except Exception as e:
        logger.exception(f"Text comparison failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


# Backwards-compatible export name used by the application's router registry.
text_search_router = router