Selaa lähdekoodia

引入conversation_id

yuchengwei 2 kuukautta sitten
vanhempi
commit
434f9f673b
1 muutettua tiedostoa jossa 30 lisäystä ja 5 poistoa
  1. 30 5
      router/text_search.py

+ 30 - 5
router/text_search.py

@@ -3,6 +3,7 @@ from pydantic import BaseModel
 from typing import List
 from service.trunks_service import TrunksService
 from utils.text_splitter import TextSplitter
+from utils.vector_distance import VectorDistance
 from model.response import StandardResponse
 import logging
 
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/text", tags=["Text Search"])
 
 class TextSearchRequest(BaseModel):
     text: str
+    conversation_id: str
 
 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
@@ -20,17 +22,40 @@ async def search_text(request: TextSearchRequest):
         if not sentences:
             return StandardResponse(success=True, data={"answer": "", "references": []})
         
-        # 对每个句子进行向量搜索
+        # 初始化服务和结果列表
         trunks_service = TrunksService()
         result_sentences = []
         all_references = []
         reference_index = 1
         
+        # 根据conversation_id获取缓存结果
+        cached_results = trunks_service.get_cached_result(request.conversation_id) if request.conversation_id else []
+        
         for sentence in sentences:
-            search_results = trunks_service.search_by_vector(
-                text=sentence,
-                limit=1
-            )
+            if cached_results:
+                # 如果有缓存结果,计算向量距离
+                min_distance = float('inf')
+                best_result = None
+                sentence_vector = trunks_service.get_vector(sentence)
+                
+                for cached_result in cached_results:
+                    content_vector = trunks_service.get_vector(cached_result['content']); 
+                    distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
+                    if distance < min_distance:
+                        min_distance = distance
+                        best_result = {**cached_result, 'distance': distance}
+                        
+                
+                if best_result and best_result['distance'] < 1:
+                    search_results = [best_result]
+                else:
+                    search_results = []
+            else:
+                # 如果没有缓存结果,进行向量搜索
+                search_results = trunks_service.search_by_vector(
+                    text=sentence,
+                    limit=1
+                )
             
             # 处理搜索结果
             for result in search_results: