|
@@ -62,12 +62,7 @@ class TrunksService:
|
|
|
db.close()
|
|
|
|
|
|
def search_by_vector(self, text: str, limit: int = 10, metadata_condition: Optional[dict] = None, type: Optional[str] = None, conversation_id: Optional[str] = None) -> List[dict]:
|
|
|
- if conversation_id:
|
|
|
- from cache import get_cache, set_cache
|
|
|
- cached_result = get_cache(conversation_id)
|
|
|
- if cached_result:
|
|
|
- return cached_result
|
|
|
-
|
|
|
+
|
|
|
embedding = Vectorizer.get_embedding(text)
|
|
|
db = SessionLocal()
|
|
|
try:
|
|
@@ -148,6 +143,22 @@ class TrunksService:
|
|
|
finally:
|
|
|
db.close()
|
|
|
|
|
|
+ def get_cached_result(self, conversation_id: str) -> List[dict]:
|
|
|
+ """
|
|
|
+ 根据conversation_id获取缓存结果
|
|
|
+ :param conversation_id: 会话ID
|
|
|
+ :return: 结果列表
|
|
|
+ """
|
|
|
+ if not conversation_id:
|
|
|
+ return []
|
|
|
+ from cache import get_cache, set_cache
|
|
|
+ cached_result = get_cache(conversation_id)
|
|
|
+ if cached_result:
|
|
|
+ return cached_result
|
|
|
+ return []
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
def paginated_search(self, search_params: dict) -> dict:
|
|
|
"""
|
|
|
分页查询方法
|