SGTY vor 1 Monat
Ursprung
Commit
b4cf6a3497

+ 2 - 1
requirements.txt

@@ -9,4 +9,5 @@ urllib3==2.3.0
 uvicorn==0.34.0
 psycopg2-binary==2.9.10
 python-dotenv==1.0.0
-hui-tools[all]==0.5.8
+hui-tools[all]==0.5.8
+cachetools==6.1.0

+ 1 - 1
src/knowledge/service/kg_edge_service.py

@@ -9,7 +9,7 @@ from cachetools import TTLCache
 logger = logging.getLogger(__name__)
 
 class KGEdgeService:
-    _cache = TTLCache(maxsize=100000, ttl=60*60*24*7)
+    _cache = TTLCache(maxsize=10000, ttl=60*60*24)
    def __init__(self, db: Session):
        # Store the injected SQLAlchemy session; all queries go through it.
        self.db = db
 

+ 26 - 10
src/knowledge/service/kg_node_service.py

@@ -13,12 +13,10 @@ logger = logging.getLogger(__name__)
 DISTANCE_THRESHOLD = 0.65
 DISTANCE_THRESHOLD2 = 0.3
 class KGNodeService:
-    _cache = TTLCache(maxsize=100000, ttl=60*60*24*7)
+    _cache = TTLCache(maxsize=10000, ttl=60*60*24)
     def __init__(self, db: Session):
         self.db = db
-
-    _cache = {}
-
+        
     def search_title_index(self, index: str, title: str, top_k: int = 3):
         cache_key = f"{index}:{title}:{top_k}"
         if cache_key in self._cache:
@@ -67,6 +65,13 @@ class KGNodeService:
             page_no = 1
         if limit < 1:
             limit = 10
+            
+        cache_key = f"paginated_search:{keyword}:{category}:{page_no}:{distance}:{limit}:{str(search_params.get('knowledge_ids', ''))}:{load_props}"
+        logger.debug(f"Cache key: {cache_key}")
+        if cache_key in self._cache:
+            cached_value = self._cache[cache_key]
+            print(cached_value)
+            return cached_value
 
         embedding = Vectorizer.get_instance().get_embedding(keyword)
         offset = (page_no - 1) * limit
@@ -103,8 +108,9 @@ class KGNodeService:
             query = query.filter(KGNode.embedding.l2_distance(embedding) < distance)
             results = query.order_by('distance').offset(offset).limit(limit).all()
             #将results相同distance的category=疾病的放在前面
-            results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
-            return {
+            #results = sorted(results, key=lambda x: (x.distance, not x.category == '疾病'))
+
+            finalResults = {
                 'records': [{
                     'id': r.id,
                     'name': r.name,
@@ -121,6 +127,8 @@ class KGNodeService:
                 }
             
             }
+            self._cache[cache_key] = finalResults
+            return finalResults
         except Exception as e:
             logger.error(f"分页查询失败: {str(e)}")
             raise e
@@ -147,20 +155,28 @@ class KGNodeService:
             raise ValueError("Database integrity error")
 
     def get_node(self, node_id: int):
-   
         if node_id is None:
             raise ValueError("Node ID is required")
-     
-        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()     
+
+        cache_key = f"get_node_{node_id}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
 
         if not node:
             raise ValueError("Node not found")
-        return {
+
+        node_data = {
             'id': node.id,
             'name': node.name,
             'category': node.category,
             'version': node.version
         }
+        #node_data深拷贝
+        node_data = node_data.copy()
+        self._cache[cache_key] = node_data
+        return node_data
 
     def update_node(self, node_id: int, update_data: dict):
         node = self.db.query(KGNode).get(node_id)

+ 258 - 0
src/knowledge/service/kg_node_service2.py

@@ -0,0 +1,258 @@
+from sqlalchemy.orm import Session
+from typing import Optional
+from model.kg_node import KGNode
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+
+from utils import vectorizer
+from utils.vectorizer import Vectorizer
+from sqlalchemy import func
+from service.kg_prop_service import KGPropService
+from service.kg_edge_service import KGEdgeService
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+logger = logging.getLogger(__name__)
+DISTANCE_THRESHOLD = 0.65
class KGNodeService:
    """Service layer for knowledge-graph node CRUD and vector similarity search.

    Read results are memoised in a class-level TTL cache shared by every
    instance (entries expire after 30 days or when write operations evict
    them explicitly).
    """

    # Shared across instances; bounded so it cannot grow without limit.
    _cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)

    def __init__(self, db: Session):
        # Injected SQLAlchemy session; all queries go through it.
        self.db = db

    @staticmethod
    def _node_to_dict(node) -> dict:
        """Project an ORM node onto the plain-dict shape returned by read APIs."""
        return {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }

    def search_title_index(self, index: str, keywrod: str, category: str,
                           top_k: int = 3, distance: float = 0.3) -> list:
        """Vector-search active nodes of *category* whose name is close to the query.

        NOTE: the parameter is spelled ``keywrod`` (sic) — kept for backward
        compatibility with existing keyword-argument callers.

        Args:
            index: logical index name; used only as part of the cache key.
            keywrod: query text to embed and match against node names.
            category: restrict matches to this node category.
            top_k: maximum number of hits returned.
            distance: maximum accepted L2 distance.

        Returns:
            List of dicts with ``id``/``title``/``text``/``score``, best
            match first.  (The original ``Optional[int]`` annotation was
            wrong: this method always returns a list.)
        """
        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        query_embedding = Vectorizer.get_embedding(keywrod)
        # Use the injected session.  The previous `db = next(get_db())`
        # pulled a fresh session whose generator was never exhausted, so the
        # session was never closed (connection leak) — and it silently
        # bypassed the session this service was constructed with.
        rows = (
            self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(query_embedding).label('distance')
            )
            .filter(KGNode.status == 0)
            .filter(KGNode.category == category)
            # TODO: evaluate cosine distance for better performance/recall.
            .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
            .order_by('distance').limit(top_k).all()
        )
        results = [
            {
                "id": node.id,
                "title": node.name,
                "text": node.category,
                # l2_distance >= 0, so score <= 2.0; larger means closer.
                "score": 2.0 - node.distance
            }
            for node in rows
        ]

        self._cache[cache_key] = results
        return results

    def paginated_search(self, search_params: dict) -> dict:
        """Paginated vector similarity search over active nodes.

        Args:
            search_params: dict with optional keys ``keyword``, ``category``,
                ``pageNo``, ``distance``, ``limit``, ``knowledge_ids``
                (version whitelist) and ``load_props``.

        Returns:
            ``{'records': [...], 'pagination': {...}}``.
        """
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        edge_service = KGEdgeService(self.db)
        keyword = search_params.get('keyword', '')
        category = search_params.get('category', None)
        page_no = search_params.get('pageNo', 1)
        # Fall back to the default threshold when distance is absent or None.
        distance = search_params.get('distance')
        if distance is None:
            distance = DISTANCE_THRESHOLD
        limit = search_params.get('limit', 10)

        # Clamp obviously invalid paging values.
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit

        def apply_filters(query):
            # Single source of truth for the WHERE clause, shared by the
            # count query and the page query (previously duplicated).
            query = query.filter(
                KGNode.status == 0,
                KGNode.embedding.l2_distance(embedding) <= distance
            )
            if category:
                query = query.filter(KGNode.category == category)
            if search_params.get('knowledge_ids'):
                query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
            return query

        try:
            total_count = apply_filters(self.db.query(func.count(KGNode.id))).scalar()

            page_query = apply_filters(self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(embedding).label('distance')
            ))
            results = page_query.order_by('distance').offset(offset).limit(limit).all()

            return {
                'records': [{
                    'id': r.id,
                    'name': r.name,
                    'category': r.category,
                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                    # Edge loading intentionally disabled for now:
                    # 'edges': edge_service.get_edges_by_nodes(r.id, r.id, False) if load_props else [],
                    'distance': round(r.distance, 3)
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    'totalPages': (total_count + limit - 1) // limit
                }
            }
        except Exception as e:
            logger.error(f"分页查询失败: {str(e)}")
            raise

    def create_node(self, node_data: dict):
        """Insert a new node unless an identical (name, category, version) exists.

        Raises:
            ValueError: if a duplicate exists or the insert violates a
                database constraint.
        """
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()

            if existing:
                raise ValueError("Node already exists")

            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node

        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"创建节点失败: {str(e)}")
            raise ValueError("Database integrity error")

    def get_node(self, node_id: int):
        """Fetch an active (status == 0) node by id as a plain dict, with caching.

        Raises:
            ValueError: if ``node_id`` is None or no active node matches.
        """
        if node_id is None:
            raise ValueError("Node ID is required")

        cache_key = f"get_node_{node_id}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
        if not node:
            raise ValueError("Node not found")

        node_data = self._node_to_dict(node)
        self._cache[cache_key] = node_data
        return node_data

    def _evict_node_cache(self, node):
        """Drop cache entries derived from *node*'s current (pre-write) state.

        Without this, update/delete left stale dicts in the 30-day cache.
        """
        self._cache.pop(f"get_node_{node.id}", None)
        self._cache.pop(f"get_node_by_name_category_{node.name}:{node.category}", None)

    def update_node(self, node_id: int, update_data: dict):
        """Apply *update_data* attributes to the node and commit.

        Raises:
            ValueError: if the node does not exist or the update fails.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")

        # Evict entries built from the pre-update state so readers do not
        # see stale data for the remainder of the cache TTL.
        self._evict_node_cache(node)

        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            return node
        except Exception as e:
            self.db.rollback()
            logger.error(f"更新节点失败: {str(e)}")
            raise ValueError("Update failed")

    def delete_node(self, node_id: int):
        """Hard-delete the node row and commit.

        NOTE(review): reads filter on ``status == 0`` (soft-delete
        convention) but this performs a hard DELETE — confirm intended.

        Raises:
            ValueError: if the node does not exist or the delete fails.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")

        # Evict cached reads of the row being removed.
        self._evict_node_cache(node)

        try:
            self.db.delete(node)
            self.db.commit()
            return None
        except Exception as e:
            self.db.rollback()
            logger.error(f"删除节点失败: {str(e)}")
            raise ValueError("Delete failed")

    def batch_process_er_nodes(self):
        """Backfill embeddings for all nodes that currently have none.

        Rows that get an embedding drop out of the ``embedding IS NULL``
        filter after commit, so the next iteration must re-query from
        offset 0.  The previous version advanced the offset every pass and
        therefore skipped half of the pending rows.
        """
        batch_size = 200
        offset = 0

        while True:
            try:
                # Stable id ordering guards against concurrent-run reshuffling.
                nodes = self.db.query(KGNode).filter(
                    # KGNode.version == 'ER',  # version filter intentionally disabled
                    KGNode.embedding == None
                ).order_by(KGNode.id).offset(offset).limit(batch_size).all()

                if not nodes:
                    break

                updated_any = False
                for node in nodes:
                    if not node.embedding:
                        node.embedding = Vectorizer.get_embedding(node.name)
                        updated_any = True

                if updated_any:
                    # Committed rows leave the filtered set; re-query from
                    # the current offset instead of advancing past unseen rows.
                    self.db.commit()
                else:
                    # Defensive: nothing changed in this batch, so advance to
                    # avoid an infinite loop.
                    offset += batch_size
            except Exception as e:
                self.db.rollback()
                logger.error(f"批量处理ER节点失败: {str(e)}")
                raise ValueError("Batch process failed")

    def get_node_by_name_category(self, name: str, category: str):
        """Fetch an active node by exact (name, category) as a dict, with caching.

        Returns:
            The node dict, or None when no active node matches (misses are
            not cached).

        Raises:
            ValueError: if either argument is falsy.
        """
        if not name or not category:
            raise ValueError("Name and category are required")

        cache_key = f"get_node_by_name_category_{name}:{category}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        node = self.db.query(KGNode).filter(
            KGNode.name == name,
            KGNode.category == category,
            KGNode.status == 0
        ).first()

        if not node:
            return None

        node_data = self._node_to_dict(node)
        self._cache[cache_key] = node_data
        return node_data