|
@@ -0,0 +1,114 @@
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+from sqlalchemy import or_
|
|
|
+from typing import Optional
|
|
|
+from model.kg_edges import KGEdge
|
|
|
+from db.session import get_db
|
|
|
+import logging
|
|
|
+from sqlalchemy.exc import IntegrityError
|
|
|
+from cachetools import TTLCache
|
|
|
+from cachetools.keys import hashkey
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+class KGEdgeService:
|
|
|
+ def __init__(self, db: Session):
|
|
|
+ self.db = db
|
|
|
+
|
|
|
+ _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
|
|
|
+ def get_edge(self, edge_id: int):
|
|
|
+ edge = self.db.query(KGEdge).get(edge_id)
|
|
|
+ if not edge:
|
|
|
+ raise ValueError("Edge not found")
|
|
|
+ return edge
|
|
|
+
|
|
|
+ def create_edge(self, edge_data: dict):
|
|
|
+ try:
|
|
|
+ existing = self.db.query(KGEdge).filter(
|
|
|
+ KGEdge.src_id == edge_data['src_id'],
|
|
|
+ KGEdge.dest_id == edge_data['dest_id'],
|
|
|
+ KGEdge.name == edge_data['name'],
|
|
|
+ KGEdge.version == edge_data.get('version')
|
|
|
+ ).first()
|
|
|
+
|
|
|
+ if existing:
|
|
|
+ raise ValueError("Edge already exists")
|
|
|
+
|
|
|
+ new_edge = KGEdge(**edge_data)
|
|
|
+ self.db.add(new_edge)
|
|
|
+ self.db.commit()
|
|
|
+ return new_edge
|
|
|
+
|
|
|
+ except IntegrityError as e:
|
|
|
+ self.db.rollback()
|
|
|
+ logger.error(f"创建边失败: {str(e)}")
|
|
|
+ raise ValueError("Database integrity error")
|
|
|
+
|
|
|
+ def update_edge(self, edge_id: int, update_data: dict):
|
|
|
+ edge = self.db.query(KGEdge).get(edge_id)
|
|
|
+ if not edge:
|
|
|
+ raise ValueError("Edge not found")
|
|
|
+
|
|
|
+ try:
|
|
|
+ for key, value in update_data.items():
|
|
|
+ setattr(edge, key, value)
|
|
|
+ self.db.commit()
|
|
|
+ return edge
|
|
|
+ except Exception as e:
|
|
|
+ self.db.rollback()
|
|
|
+ logger.error(f"更新边失败: {str(e)}")
|
|
|
+ raise ValueError("Update failed")
|
|
|
+
|
|
|
+ def delete_edge(self, edge_id: int):
|
|
|
+ edge = self.db.query(KGEdge).get(edge_id)
|
|
|
+ if not edge:
|
|
|
+ raise ValueError("Edge not found")
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.db.delete(edge)
|
|
|
+ self.db.commit()
|
|
|
+ return None
|
|
|
+ except Exception as e:
|
|
|
+ self.db.rollback()
|
|
|
+ logger.error(f"删除边失败: {str(e)}")
|
|
|
+ raise ValueError("Delete failed")
|
|
|
+
|
|
|
+ def get_edges_by_nodes(self, src_id: Optional[int]= None, dest_id: Optional[int]= None, category: Optional[str] = None):
|
|
|
+ cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
|
|
|
+ if cache_key in self._cache:
|
|
|
+ return self._cache[cache_key]
|
|
|
+
|
|
|
+ if src_id is None and dest_id is None:
|
|
|
+ raise ValueError("至少需要提供一个有效的查询条件")
|
|
|
+ try:
|
|
|
+ filters = []
|
|
|
+ if src_id is not None:
|
|
|
+ filters.append(KGEdge.src_id == src_id)
|
|
|
+ if dest_id is not None:
|
|
|
+ filters.append(KGEdge.dest_id == dest_id)
|
|
|
+ if category is not None:
|
|
|
+ filters.append(KGEdge.category == category)
|
|
|
+ edges = self.db.query(KGEdge).filter(*filters).all()
|
|
|
+
|
|
|
+ from service.kg_node_service import KGNodeService
|
|
|
+ node_service = KGNodeService(self.db)
|
|
|
+ result = []
|
|
|
+ for edge in edges:
|
|
|
+ try:
|
|
|
+ edge_info = {
|
|
|
+ 'id': edge.id,
|
|
|
+ 'src_id': edge.src_id,
|
|
|
+ 'dest_id': edge.dest_id,
|
|
|
+ 'name': edge.name,
|
|
|
+ 'version': edge.version,
|
|
|
+ #'src_node': node_service.get_node(edge.src_id),
|
|
|
+ 'dest_node': node_service.get_node(edge.dest_id)
|
|
|
+ }
|
|
|
+ result.append(edge_info)
|
|
|
+ except ValueError as e:
|
|
|
+ logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
|
|
|
+ continue
|
|
|
+ self._cache[cache_key] = result
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"查询边失败: {str(e)}")
|
|
|
+ raise e
|