SGTY 1 månad sedan
förälder
incheckning
696e73d499

+ 14 - 7
src/knowledge/service/kg_edge_service.py

@@ -4,10 +4,12 @@ from typing import Optional
 from ..model.kg_edges import KGEdge
 import logging
 from sqlalchemy.exc import IntegrityError
+from cachetools import TTLCache
 
 logger = logging.getLogger(__name__)
 
 class KGEdgeService:
+    _cache = TTLCache(maxsize=100000, ttl=60*60*24*7)
     def __init__(self, db: Session):
         self.db = db
 
@@ -68,7 +70,12 @@ class KGEdgeService:
             logger.error(f"删除边失败: {str(e)}")
             raise ValueError("Delete failed")
 
-    def get_edges_by_nodes(self, src_id: Optional[int], dest_id: Optional[int], and_logic: bool = True):
+    def get_edges_by_nodes(self, src_id: Optional[int] = None, dest_id: Optional[int] = None,
+                           category: Optional[str] = None):
+        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
         if src_id is None and dest_id is None:
             raise ValueError("至少需要提供一个有效的查询条件")
         try:
@@ -77,11 +84,10 @@ class KGEdgeService:
                 filters.append(KGEdge.src_id == src_id)
             if dest_id is not None:
                 filters.append(KGEdge.dest_id == dest_id)
+            if category is not None:
+                filters.append(KGEdge.category == category)
+            edges = self.db.query(KGEdge).filter(*filters).all()
 
-            if and_logic:
-                edges = self.db.query(KGEdge).filter(*filters).all()
-            else:
-                edges = self.db.query(KGEdge).filter(or_(*filters)).all()
             from ..service.kg_node_service import KGNodeService
             node_service = KGNodeService(self.db)
             result = []
@@ -93,14 +99,15 @@ class KGEdgeService:
                         'dest_id': edge.dest_id,
                         'name': edge.name,
                         'version': edge.version,
-                        'src_node': node_service.get_node(edge.src_id),
+                        # 'src_node': node_service.get_node(edge.src_id),
                         'dest_node': node_service.get_node(edge.dest_id)
                     }
                     result.append(edge_info)
                 except ValueError as e:
                     logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
                     continue
+            self._cache[cache_key] = result
             return result
         except Exception as e:
             logger.error(f"查询边失败: {str(e)}")
-            raise e
+            raise e

+ 114 - 0
src/knowledge/service/kg_edge_service2.py

@@ -0,0 +1,114 @@
+from sqlalchemy.orm import Session
+from sqlalchemy import or_
+from typing import Optional
+from model.kg_edges import KGEdge
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+
+logger = logging.getLogger(__name__)
+
+class KGEdgeService:
+    def __init__(self, db: Session):
+        self.db = db
+
+    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
+    def get_edge(self, edge_id: int):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+        return edge
+
+    def create_edge(self, edge_data: dict):
+        try:
+            existing = self.db.query(KGEdge).filter(
+                KGEdge.src_id == edge_data['src_id'],
+                KGEdge.dest_id == edge_data['dest_id'],
+                KGEdge.name == edge_data['name'],
+                KGEdge.version == edge_data.get('version')
+            ).first()
+
+            if existing:
+                raise ValueError("Edge already exists")
+
+            new_edge = KGEdge(**edge_data)
+            self.db.add(new_edge)
+            self.db.commit()
+            return new_edge
+
+        except IntegrityError as e:
+            self.db.rollback()
+            logger.error(f"创建边失败: {str(e)}")
+            raise ValueError("Database integrity error")
+
+    def update_edge(self, edge_id: int, update_data: dict):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+
+        try:
+            for key, value in update_data.items():
+                setattr(edge, key, value)
+            self.db.commit()
+            return edge
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"更新边失败: {str(e)}")
+            raise ValueError("Update failed")
+
+    def delete_edge(self, edge_id: int):
+        edge = self.db.query(KGEdge).get(edge_id)
+        if not edge:
+            raise ValueError("Edge not found")
+
+        try:
+            self.db.delete(edge)
+            self.db.commit()
+            return None
+        except Exception as e:
+            self.db.rollback()
+            logger.error(f"删除边失败: {str(e)}")
+            raise ValueError("Delete failed")
+
+    def get_edges_by_nodes(self, src_id: Optional[int]= None, dest_id: Optional[int]= None, category: Optional[str] = None):
+        cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
+        if cache_key in self._cache:
+            return self._cache[cache_key]
+
+        if src_id is None and dest_id is None:
+            raise ValueError("至少需要提供一个有效的查询条件")
+        try:
+            filters = []
+            if src_id is not None:
+                filters.append(KGEdge.src_id == src_id)
+            if dest_id is not None:
+                filters.append(KGEdge.dest_id == dest_id)
+            if category is not None:
+                filters.append(KGEdge.category == category)
+            edges = self.db.query(KGEdge).filter(*filters).all()
+
+            from service.kg_node_service import KGNodeService
+            node_service = KGNodeService(self.db)
+            result = []
+            for edge in edges:
+                try:
+                    edge_info = {
+                        'id': edge.id,
+                        'src_id': edge.src_id,
+                        'dest_id': edge.dest_id,
+                        'name': edge.name,
+                        'version': edge.version,
+                        #'src_node': node_service.get_node(edge.src_id),
+                        'dest_node': node_service.get_node(edge.dest_id)
+                    }
+                    result.append(edge_info)
+                except ValueError as e:
+                    logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
+                    continue
+            self._cache[cache_key] = result
+            return result
+        except Exception as e:
+            logger.error(f"查询边失败: {str(e)}")
+            raise e

+ 2 - 1
src/knowledge/service/kg_node_service.py

@@ -3,7 +3,7 @@ from ..model.kg_node import KGNode
 from ..db.session import get_db
 import logging
 from sqlalchemy.exc import IntegrityError
-
+from cachetools import TTLCache
 from ..utils.vectorizer import Vectorizer
 from sqlalchemy import func
 from ..service.kg_prop_service import KGPropService
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
 DISTANCE_THRESHOLD = 0.65
 DISTANCE_THRESHOLD2 = 0.3
 class KGNodeService:
+    _cache = TTLCache(maxsize=100000, ttl=60*60*24*7)
     def __init__(self, db: Session):
         self.db = db