"""CRUD and query service for knowledge-graph edges (``KGEdge`` rows)."""
from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import Optional
from model.kg_edges import KGEdge
from db.session import get_db
import logging
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from cachetools import TTLCache
from cachetools.keys import hashkey

logger = logging.getLogger(__name__)


class KGEdgeService:
    """Create, read, update, delete and query ``KGEdge`` records.

    Results of :meth:`get_edges_by_nodes` are memoised in a class-level
    ``TTLCache`` shared by every instance. Every successful create, update or
    delete clears that cache, so queries never serve edges that were changed
    through this service.
    """

    # Shared by all instances (class attribute): up to 100k entries, 30-day TTL.
    # NOTE(review): cachetools caches are not thread-safe; guard access with a
    # lock if this service is used concurrently from multiple threads.
    _cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used for every query and commit made by this service.
        """
        self.db = db

    def _invalidate_cache(self) -> None:
        """Drop every cached query result after a write to the edge table."""
        self._cache.clear()

    def get_edge(self, edge_id: int) -> KGEdge:
        """Return the edge with primary key ``edge_id``.

        Raises:
            ValueError: if no such edge exists.
        """
        edge = self.db.query(KGEdge).get(edge_id)
        if edge is None:
            raise ValueError("Edge not found")
        return edge

    def create_edge(self, edge_data: dict) -> KGEdge:
        """Insert a new edge and commit it.

        Args:
            edge_data: Column values for ``KGEdge``. Must contain ``src_id``,
                ``dest_id`` and ``name``; ``version`` is optional.

        Returns:
            The newly committed ``KGEdge``.

        Raises:
            ValueError: if an edge with the same (src_id, dest_id, name,
                version) already exists, or the insert violates a DB
                integrity constraint.
            SQLAlchemyError: for any other database failure (after rollback).
        """
        try:
            existing = (
                self.db.query(KGEdge)
                .filter(
                    KGEdge.src_id == edge_data['src_id'],
                    KGEdge.dest_id == edge_data['dest_id'],
                    KGEdge.name == edge_data['name'],
                    KGEdge.version == edge_data.get('version'),
                )
                .first()
            )
            if existing is not None:
                raise ValueError("Edge already exists")
            new_edge = KGEdge(**edge_data)
            self.db.add(new_edge)
            self.db.commit()
        except IntegrityError as e:
            self.db.rollback()
            logger.exception("创建边失败: %s", e)
            raise ValueError("Database integrity error") from e
        except SQLAlchemyError:
            # Any other DB failure leaves the session in a failed state; reset
            # it so the session stays usable, then propagate the original error.
            self.db.rollback()
            raise
        self._invalidate_cache()
        return new_edge

    def update_edge(self, edge_id: int, update_data: dict) -> KGEdge:
        """Apply ``update_data`` as attribute assignments on an edge and commit.

        Raises:
            ValueError: if the edge does not exist, or the update fails (the
                session is rolled back and the cause is chained).
        """
        edge = self.get_edge(edge_id)
        try:
            for key, value in update_data.items():
                setattr(edge, key, value)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.exception("更新边失败: %s", e)
            raise ValueError("Update failed") from e
        self._invalidate_cache()
        return edge

    def delete_edge(self, edge_id: int) -> None:
        """Delete the edge with primary key ``edge_id`` and commit.

        Raises:
            ValueError: if the edge does not exist, or the delete fails (the
                session is rolled back and the cause is chained).
        """
        edge = self.get_edge(edge_id)
        try:
            self.db.delete(edge)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.exception("删除边失败: %s", e)
            raise ValueError("Delete failed") from e
        self._invalidate_cache()
        return None

    def get_edges_by_nodes(self, src_id: Optional[int] = None,
                           dest_id: Optional[int] = None,
                           category: Optional[str] = None) -> list:
        """Return edges matching the given source/destination/category filters.

        Each result is a dict with ``id``, ``src_id``, ``dest_id``, ``name``,
        ``version`` and ``dest_node`` (resolved via ``KGNodeService.get_node``).
        Edges whose destination node cannot be resolved (``get_node`` raises
        ``ValueError``) are skipped with a warning.

        Results are cached in the shared class-level cache. The same list
        object is returned on cache hits, so callers must not mutate it.
        Node changes do not invalidate this cache.

        Raises:
            ValueError: if neither ``src_id`` nor ``dest_id`` is given.
        """
        # Tuple key keeps None distinct from the string "None" (an f-string key
        # would collide on e.g. category=None vs category="None").
        cache_key = hashkey("get_edges_by_nodes", src_id, dest_id, category)
        # Single lookup avoids a KeyError if the entry expires between an
        # `in` check and the subsequent indexing.
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        if src_id is None and dest_id is None:
            raise ValueError("至少需要提供一个有效的查询条件")

        try:
            filters = []
            if src_id is not None:
                filters.append(KGEdge.src_id == src_id)
            if dest_id is not None:
                filters.append(KGEdge.dest_id == dest_id)
            if category is not None:
                filters.append(KGEdge.category == category)
            edges = self.db.query(KGEdge).filter(*filters).all()

            # Imported lazily, presumably to avoid a circular import with the
            # node service — confirm before hoisting to module level.
            from service.kg_node_service import KGNodeService
            node_service = KGNodeService(self.db)

            result = []
            for edge in edges:
                try:
                    dest_node = node_service.get_node(edge.dest_id)
                except ValueError as e:
                    logger.warning("跳过边关系 %s: %s", edge.id, e)
                    continue
                result.append({
                    'id': edge.id,
                    'src_id': edge.src_id,
                    'dest_id': edge.dest_id,
                    'name': edge.name,
                    'version': edge.version,
                    'dest_node': dest_node,
                })
        except Exception as e:
            logger.exception("查询边失败: %s", e)
            raise

        self._cache[cache_key] = result
        return result