|
@@ -1,9 +1,11 @@
|
|
|
|
+import copy
|
|
|
|
+
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm import Session
|
|
from ..model.kg_node import KGNode
|
|
from ..model.kg_node import KGNode
|
|
from ..db.session import get_db
|
|
from ..db.session import get_db
|
|
import logging
|
|
import logging
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.exc import IntegrityError
|
|
-
|
|
|
|
|
|
+from cachetools import TTLCache
|
|
from ..utils.vectorizer import Vectorizer
|
|
from ..utils.vectorizer import Vectorizer
|
|
from sqlalchemy import func
|
|
from sqlalchemy import func
|
|
from ..service.kg_prop_service import KGPropService
|
|
from ..service.kg_prop_service import KGPropService
|
|
@@ -13,11 +15,10 @@ logger = logging.getLogger(__name__)
|
|
DISTANCE_THRESHOLD = 0.65
|
|
DISTANCE_THRESHOLD = 0.65
|
|
DISTANCE_THRESHOLD2 = 0.3
|
|
DISTANCE_THRESHOLD2 = 0.3
|
|
class KGNodeService:
|
|
class KGNodeService:
|
|
|
|
+ _cache = TTLCache(maxsize=10000, ttl=60*60*24)
|
|
def __init__(self, db: Session):
|
|
def __init__(self, db: Session):
|
|
self.db = db
|
|
self.db = db
|
|
-
|
|
|
|
- _cache = {}
|
|
|
|
-
|
|
|
|
|
|
+
|
|
def search_title_index(self, index: str, title: str, top_k: int = 3):
|
|
def search_title_index(self, index: str, title: str, top_k: int = 3):
|
|
cache_key = f"{index}:{title}:{top_k}"
|
|
cache_key = f"{index}:{title}:{top_k}"
|
|
if cache_key in self._cache:
|
|
if cache_key in self._cache:
|
|
@@ -66,6 +67,12 @@ class KGNodeService:
|
|
page_no = 1
|
|
page_no = 1
|
|
if limit < 1:
|
|
if limit < 1:
|
|
limit = 10
|
|
limit = 10
|
|
|
|
+
|
|
|
|
+ cache_key = f"paginated_search:{keyword}:{category}:{page_no}:{distance}:{limit}:{str(search_params.get('knowledge_ids', ''))}:{load_props}"
|
|
|
|
+ logger.debug(f"Cache key: {cache_key}")
|
|
|
|
+ if cache_key in self._cache:
|
|
|
|
+ cached_value = self._cache[cache_key]
|
|
|
|
+ return copy.deepcopy(cached_value)
|
|
|
|
|
|
embedding = Vectorizer.get_instance().get_embedding(keyword)
|
|
embedding = Vectorizer.get_instance().get_embedding(keyword)
|
|
offset = (page_no - 1) * limit
|
|
offset = (page_no - 1) * limit
|
|
@@ -102,8 +109,9 @@ class KGNodeService:
|
|
query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
|
|
query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
|
|
results = query.order_by('distance').offset(offset).limit(limit).all()
|
|
results = query.order_by('distance').offset(offset).limit(limit).all()
|
|
#将results相同distance的category=疾病的放在前面
|
|
#将results相同distance的category=疾病的放在前面
|
|
- results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
|
|
|
|
- return {
|
|
|
|
|
|
+ #results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
|
|
|
|
+
|
|
|
|
+ finalResults = {
|
|
'records': [{
|
|
'records': [{
|
|
'id': r.id,
|
|
'id': r.id,
|
|
'name': r.name,
|
|
'name': r.name,
|
|
@@ -120,6 +128,8 @@ class KGNodeService:
|
|
}
|
|
}
|
|
|
|
|
|
}
|
|
}
|
|
|
|
+ self._cache[cache_key] = copy.deepcopy(finalResults)
|
|
|
|
+ return finalResults
|
|
except Exception as e:
|
|
except Exception as e:
|
|
logger.error(f"分页查询失败: {str(e)}")
|
|
logger.error(f"分页查询失败: {str(e)}")
|
|
raise e
|
|
raise e
|
|
@@ -146,20 +156,28 @@ class KGNodeService:
|
|
raise ValueError("Database integrity error")
|
|
raise ValueError("Database integrity error")
|
|
|
|
|
|
def get_node(self, node_id: int):
|
|
def get_node(self, node_id: int):
|
|
-
|
|
|
|
if node_id is None:
|
|
if node_id is None:
|
|
raise ValueError("Node ID is required")
|
|
raise ValueError("Node ID is required")
|
|
-
|
|
|
|
- node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
|
|
|
|
|
|
+
|
|
|
|
+ cache_key = f"get_node_{node_id}"
|
|
|
|
+ if cache_key in self._cache:
|
|
|
|
+ return copy.deepcopy(self._cache[cache_key])
|
|
|
|
+
|
|
|
|
+ node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
|
|
|
|
|
|
if not node:
|
|
if not node:
|
|
raise ValueError("Node not found")
|
|
raise ValueError("Node not found")
|
|
- return {
|
|
|
|
|
|
+
|
|
|
|
+ node_data = {
|
|
'id': node.id,
|
|
'id': node.id,
|
|
'name': node.name,
|
|
'name': node.name,
|
|
'category': node.category,
|
|
'category': node.category,
|
|
'version': node.version
|
|
'version': node.version
|
|
}
|
|
}
|
|
|
|
+ #node_data深拷贝
|
|
|
|
+ node_data = node_data.copy()
|
|
|
|
+ self._cache[cache_key] = copy.deepcopy(node_data)
|
|
|
|
+ return node_data
|
|
|
|
|
|
def update_node(self, node_id: int, update_data: dict):
|
|
def update_node(self, node_id: int, update_data: dict):
|
|
node = self.db.query(KGNode).get(node_id)
|
|
node = self.db.query(KGNode).get(node_id)
|