from sqlalchemy.orm import Session
from sqlalchemy import or_
from typing import Optional
from model.kg_edges import KGEdge
from db.session import get_db
import logging
from sqlalchemy.exc import IntegrityError, SQLAlchemyError

logger = logging.getLogger(__name__)


class KGEdgeService:
    """CRUD and lookup operations for ``KGEdge`` rows on a SQLAlchemy session.

    Mutating methods commit on success and roll the session back on database
    failure, so the session remains usable for subsequent calls.
    """

    def __init__(self, db: Session):
        """Bind the service to an existing SQLAlchemy session.

        Args:
            db: Session used for every query and commit. This service never
                creates or closes it.
        """
        self.db = db

    def get_edge(self, edge_id: int):
        """Return the edge whose primary key is ``edge_id``.

        Raises:
            ValueError: If no such edge exists.
        """
        # NOTE(review): Query.get() is a legacy API in SQLAlchemy 1.4+/2.0;
        # switch to self.db.get(KGEdge, edge_id) once the minimum supported
        # SQLAlchemy version is confirmed to be >= 1.4.
        edge = self.db.query(KGEdge).get(edge_id)
        if edge is None:
            raise ValueError("Edge not found")
        return edge

    def create_edge(self, edge_data: dict):
        """Insert a new edge unless an identical one already exists.

        An edge is a duplicate when ``src_id``, ``dest_id``, ``name`` and
        ``version`` all match an existing row. A missing ``version`` key is
        looked up as ``None`` (SQLAlchemy renders ``== None`` as ``IS NULL``).

        Args:
            edge_data: Column values passed to ``KGEdge(**edge_data)``; must
                contain ``src_id``, ``dest_id`` and ``name``.

        Returns:
            The newly committed ``KGEdge`` instance.

        Raises:
            ValueError: If a duplicate edge exists, or the insert violates a
                database integrity constraint.
            KeyError: If ``src_id``, ``dest_id`` or ``name`` is missing.
            SQLAlchemyError: For any other database failure (re-raised after
                the session has been rolled back).
        """
        try:
            existing = self.db.query(KGEdge).filter(
                KGEdge.src_id == edge_data['src_id'],
                KGEdge.dest_id == edge_data['dest_id'],
                KGEdge.name == edge_data['name'],
                KGEdge.version == edge_data.get('version'),
            ).first()
            if existing:
                raise ValueError("Edge already exists")
            new_edge = KGEdge(**edge_data)
            self.db.add(new_edge)
            self.db.commit()
            return new_edge
        except IntegrityError as e:
            self.db.rollback()
            logger.error("创建边失败: %s", e)
            raise ValueError("Database integrity error") from e
        except SQLAlchemyError as e:
            # Any other DB failure during the query/flush/commit leaves the
            # session in a failed transaction; roll back so later calls on the
            # same session still work, then surface the original error.
            self.db.rollback()
            logger.error("创建边失败: %s", e)
            raise

    def update_edge(self, edge_id: int, update_data: dict):
        """Apply ``update_data`` to an existing edge and commit.

        Args:
            edge_id: Primary key of the edge to update.
            update_data: Mapping of attribute name to new value; each pair is
                applied with ``setattr`` on the loaded edge.

        Returns:
            The updated ``KGEdge`` instance.

        Raises:
            ValueError: ``"Edge not found"`` if the edge does not exist, or
                ``"Update failed"`` if applying/committing the change fails
                (the session is rolled back first).
        """
        edge = self.get_edge(edge_id)
        try:
            # NOTE(review): keys are not validated against KGEdge's mapped
            # columns; a misspelled key is presumably stored as a plain Python
            # attribute and never persisted. Consider rejecting unknown keys.
            for key, value in update_data.items():
                setattr(edge, key, value)
            self.db.commit()
            return edge
        except Exception as e:
            self.db.rollback()
            logger.error("更新边失败: %s", e)
            raise ValueError("Update failed") from e

    def delete_edge(self, edge_id: int):
        """Delete the edge with primary key ``edge_id`` and commit.

        Returns:
            None.

        Raises:
            ValueError: ``"Edge not found"`` if the edge does not exist, or
                ``"Delete failed"`` if deleting/committing fails (the session
                is rolled back first).
        """
        edge = self.get_edge(edge_id)
        try:
            self.db.delete(edge)
            self.db.commit()
            return None
        except Exception as e:
            self.db.rollback()
            logger.error("删除边失败: %s", e)
            raise ValueError("Delete failed") from e

    def get_edges_by_nodes(
        self,
        src_id: Optional[int],
        dest_id: Optional[int],
        and_logic: bool = True,
    ):
        """Find edges by source and/or destination node id.

        Args:
            src_id: Source node id to match, or ``None`` to ignore it.
            dest_id: Destination node id to match, or ``None`` to ignore it.
            and_logic: When True, an edge must match every provided id; when
                False, matching any provided id is enough.

        Returns:
            A list of dicts with keys ``id``, ``src_id``, ``dest_id``,
            ``name``, ``version``, ``src_node`` and ``dest_node``; the node
            entries are whatever ``KGNodeService.get_node`` returns.

        Raises:
            ValueError: If both ``src_id`` and ``dest_id`` are ``None``.
            Exception: Any error from the query or node lookup is logged and
                re-raised unchanged (original traceback preserved).
        """
        if src_id is None and dest_id is None:
            raise ValueError("至少需要提供一个有效的查询条件")
        try:
            filters = []
            if src_id is not None:
                filters.append(KGEdge.src_id == src_id)
            if dest_id is not None:
                filters.append(KGEdge.dest_id == dest_id)
            # filter(*a, *b) ANDs its arguments; or_() collapses them into one
            # disjunction.
            criteria = filters if and_logic else [or_(*filters)]
            edges = self.db.query(KGEdge).filter(*criteria).all()

            # Function-level import, presumably to avoid a circular import
            # between the edge and node service modules — confirm.
            from service.kg_node_service import KGNodeService
            node_service = KGNodeService(self.db)

            # Memoize per call so each distinct node id is fetched once instead
            # of twice per edge. First-lookup order (src, then dest, edge by
            # edge) matches the original, so any lookup error surfaces at the
            # same point. NOTE(review): edges sharing a node now share the
            # same returned node object.
            node_cache = {}

            def resolve_node(node_id):
                if node_id not in node_cache:
                    node_cache[node_id] = node_service.get_node(node_id)
                return node_cache[node_id]

            return [{
                'id': edge.id,
                'src_id': edge.src_id,
                'dest_id': edge.dest_id,
                'name': edge.name,
                'version': edge.version,
                'src_node': resolve_node(edge.src_id),
                'dest_node': resolve_node(edge.dest_id),
            } for edge in edges]
        except Exception as e:
            logger.error("查询边失败: %s", e)
            # Bare raise keeps the original traceback intact.
            raise