Bladeren bron

代码提交

SGTY 2 maanden geleden
bovenliggende
commit
dc83005a1e
3 gewijzigde bestanden met toevoegingen van 24 en 12 verwijderingen
  1. 2 2
      router/knowledge_dify.py
  2. 20 8
      service/trunks_service.py
  3. 2 2
      tests/service/test_trunks_service.py

+ 2 - 2
router/knowledge_dify.py

@@ -94,7 +94,7 @@ async def dify_retrieval(
         logger.info(f"Starting retrieval for knowledge base {payload.knowledge_id} with query: {payload.query}")
         
         trunks_service = TrunksService()
-        search_results = trunks_service.search_by_vector(payload.query, payload.retrieval_setting.top_k)
+        search_results = trunks_service.search_by_vector(payload.query, payload.retrieval_setting.top_k, conversation_id=conversation_id)
         
         if not search_results:
             logger.warning(f"No results found for query: {request.query}")
@@ -150,7 +150,7 @@ async def dify_chatflow_retrieval(
         query=query,
         retrieval_setting=RetrievalSetting(top_k=top_k, score_threshold=score_threshold)
     )
-    return await dify_retrieval(payload, request, authorization, db)
+    return await dify_retrieval(payload, request, authorization, db, conversation_id=conversation_id)
 
 dify_kb_router = router
 

+ 20 - 8
service/trunks_service.py

@@ -87,7 +87,7 @@ class TrunksService:
             } for r in results]
 
             if conversation_id:
-                set_cache(conversation_id, result_list)
+                self.set_cache(conversation_id, result_list)
 
             return result_list
         finally:
@@ -143,19 +143,31 @@ class TrunksService:
         finally:
             db.close()
 
+    _cache = {}
+
+    def get_cache(self, conversation_id: str) -> List[dict]:
+        """
+        根据conversation_id获取缓存结果
+        :param conversation_id: 会话ID
+        :return: 结果列表
+        """
+        return self._cache.get(conversation_id, [])
+
+    def set_cache(self, conversation_id: str, result: List[dict]) -> None:
+        """
+        设置缓存结果
+        :param conversation_id: 会话ID
+        :param result: 要缓存的结果
+        """
+        self._cache[conversation_id] = result
+
     def get_cached_result(self, conversation_id: str) -> List[dict]:
         """
         根据conversation_id获取缓存结果
         :param conversation_id: 会话ID
         :return: 结果列表
         """
-        if not conversation_id:
-            return []
-        from cache import get_cache, set_cache
-        cached_result = get_cache(conversation_id)
-        if cached_result:
-            return cached_result
-        return []
+        return self.get_cache(conversation_id)
         
         
 

+ 2 - 2
tests/service/test_trunks_service.py

@@ -33,7 +33,7 @@ class TestTrunksServiceCRUD:
 
 class TestSearchOperations:
     def test_vector_search(self, trunks_service, test_trunk_data):
-        results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少",10,"1111111")
+        results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少",10,conversation_id="1111111")
         print("搜索结果:", results)
         assert len(results) > 0
 
@@ -92,4 +92,4 @@ class TestBatchCreateFromDirectory:
             assert db_trunk is not None
             assert ".txt" in db_trunk.file_path
             assert "_split_" in db_trunk.file_path
-            assert len(db_trunk.content) > 0
+            assert len(db_trunk.content) > 0