SGTY 2 mesi fa
parent
commit
b5b799507c

+ 4 - 4
router/text_search.py

@@ -1,6 +1,6 @@
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
-from typing import List
+from typing import List, Optional
 from service.trunks_service import TrunksService
 from utils.text_splitter import TextSplitter
 from utils.vector_distance import VectorDistance
@@ -13,7 +13,7 @@ router = APIRouter(prefix="/text", tags=["Text Search"])
 
 class TextSearchRequest(BaseModel):
     text: str
-    conversation_id: str
+    conversation_id: Optional[str] = None
 
 @router.post("/search", response_model=StandardResponse)
 async def search_text(request: TextSearchRequest):
@@ -40,14 +40,14 @@ async def search_text(request: TextSearchRequest):
                 sentence_vector = Vectorizer.get_embedding(sentence)
                 
                 for cached_result in cached_results:
-                    content_vector = Vectorizer.get_embedding(cached_result['content'])
+                    content_vector = cached_result['embedding']
                     distance = VectorDistance.calculate_distance(sentence_vector, content_vector)
                     if distance < min_distance:
                         min_distance = distance
                         best_result = {**cached_result, 'distance': distance}
                         
                 
-                if best_result and best_result['distance'] < 1:
+                if best_result and best_result['distance'] < 0.75:
                     search_results = [best_result]
                 else:
                     search_results = []

+ 4 - 2
service/trunks_service.py

@@ -71,7 +71,8 @@ class TrunksService:
                 Trunks.file_path,
                 Trunks.content,
                 Trunks.embedding.l2_distance(embedding).label('distance'),
-                Trunks.title
+                Trunks.title,
+                Trunks.embedding
             )
             if metadata_condition:
                 query = query.filter_by(**metadata_condition)
@@ -83,7 +84,8 @@ class TrunksService:
                 'file_path': r.file_path,
                 'content': r.content,
                 'distance': r.distance,
-                'title': r.title
+                'title': r.title,
+                'embedding': r.embedding.tolist()
             } for r in results]
 
             if conversation_id:

+ 3 - 1
tests/service/test_trunks_service.py

@@ -33,7 +33,9 @@ class TestTrunksServiceCRUD:
 
 class TestSearchOperations:
     def test_vector_search(self, trunks_service, test_trunk_data):
-        results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少",10,conversation_id="1111111")
+        results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少",10,conversation_id="1111111aaaa")
+        print("搜索结果:", results)
+        results = trunks_service.get_cache("1111111aaaa")
         print("搜索结果:", results)
         assert len(results) > 0
 

File diff suppressed because it is too large
+ 19 - 2
utils/vector_distance.py