|
@@ -3,6 +3,7 @@ from pydantic import BaseModel
|
|
|
from typing import List
|
|
|
from service.trunks_service import TrunksService
|
|
|
from utils.text_splitter import TextSplitter
|
|
|
+from utils.vector_distance import VectorDistance
|
|
|
from model.response import StandardResponse
|
|
|
import logging
|
|
|
|
|
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/text", tags=["Text Search"])
|
|
|
|
|
|
class TextSearchRequest(BaseModel):
|
|
|
text: str
|
|
|
+ conversation_id: str = ""  # optional: when empty, search_text skips the cache lookup and runs a vector search per sentence
|
|
|
|
|
|
@router.post("/search", response_model=StandardResponse)
|
|
|
async def search_text(request: TextSearchRequest):
|
|
@@ -20,17 +22,40 @@ async def search_text(request: TextSearchRequest):
|
|
|
if not sentences:
|
|
|
return StandardResponse(success=True, data={"answer": "", "references": []})
|
|
|
|
|
|
- # 对每个句子进行向量搜索
|
|
|
+ # 初始化服务和结果列表
|
|
|
trunks_service = TrunksService()
|
|
|
result_sentences = []
|
|
|
all_references = []
|
|
|
reference_index = 1
|
|
|
|
|
|
+ # 根据conversation_id获取缓存结果
|
|
|
+ cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
|
|
|
+
|
|
|
for sentence in sentences:
|
|
|
- search_results = trunks_service.search_by_vector(
|
|
|
- text=sentence,
|
|
|
- limit=1
|
|
|
- )
|
|
|
+ if cached_results:
|
|
|
+ # 如果有缓存结果,计算向量距离
|
|
|
+ min_distance = float('inf')
|
|
|
+ best_result = None
|
|
|
+ sentence_vector = trunks_service.get_vector(sentence)
|
|
|
+
|
|
|
+ for cached_result in cached_results:
|
|
|
+ content_vector = trunks_service.get_vector(cached_result['content']);
|
|
|
+ distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
|
|
|
+ if distance < min_distance:
|
|
|
+ min_distance = distance
|
|
|
+ best_result = {**cached_result, 'distance': distance}
|
|
|
+
|
|
|
+
|
|
|
+ if best_result and best_result['distance'] < 1:
|
|
|
+ search_results = [best_result]
|
|
|
+ else:
|
|
|
+ search_results = []
|
|
|
+ else:
|
|
|
+ # 如果没有缓存结果,进行向量搜索
|
|
|
+ search_results = trunks_service.search_by_vector(
|
|
|
+ text=sentence,
|
|
|
+ limit=1
|
|
|
+ )
|
|
|
|
|
|
# 处理搜索结果
|
|
|
for result in search_results:
|