123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- from sqlalchemy.orm import Session
- from sqlalchemy import or_
- from typing import Optional
- from model.kg_edges import KGEdge
- from db.session import get_db
- import logging
- from sqlalchemy.exc import IntegrityError
- logger = logging.getLogger(__name__)
class KGEdgeService:
    """CRUD and lookup operations for ``KGEdge`` rows over a SQLAlchemy session.

    Lookup misses, duplicate inserts and failed writes are all reported to the
    caller as ``ValueError``; the underlying exception, when there is one, is
    attached as ``__cause__``.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used for every query and commit."""
        self.db = db

    def get_edge(self, edge_id: int):
        """Return the edge whose primary key is ``edge_id``.

        Raises:
            ValueError: If no edge with that primary key exists.
        """
        # NOTE: Query.get() is the legacy spelling; from SQLAlchemy 1.4 on the
        # replacement is Session.get(KGEdge, edge_id). Kept as-is because the
        # project's SQLAlchemy version is not visible from this file.
        edge = self.db.query(KGEdge).get(edge_id)
        if edge is None:
            raise ValueError("Edge not found")
        return edge

    def create_edge(self, edge_data: dict):
        """Insert a new edge built from ``edge_data`` and return it.

        ``edge_data`` must contain ``src_id``, ``dest_id`` and ``name``;
        ``version`` is optional. Every entry is forwarded to the ``KGEdge``
        constructor as a keyword argument.

        Raises:
            KeyError: If ``src_id``, ``dest_id`` or ``name`` is missing.
            ValueError: If an edge with the same source, destination, name
                and version already exists, or if the database reports an
                integrity violation (the session is rolled back first).
        """
        # The duplicate check stays inside the try block: a query may trigger
        # an autoflush, which can itself surface an IntegrityError.
        try:
            existing = (
                self.db.query(KGEdge)
                .filter(
                    KGEdge.src_id == edge_data['src_id'],
                    KGEdge.dest_id == edge_data['dest_id'],
                    KGEdge.name == edge_data['name'],
                    KGEdge.version == edge_data.get('version'),
                )
                .first()
            )
            if existing:
                raise ValueError("Edge already exists")

            new_edge = KGEdge(**edge_data)
            self.db.add(new_edge)
            self.db.commit()
            return new_edge
        except IntegrityError as e:
            self.db.rollback()
            logger.error("创建边失败: %s", e)
            raise ValueError("Database integrity error") from e

    def update_edge(self, edge_id: int, update_data: dict):
        """Assign every ``update_data`` item onto the edge, commit, and return it.

        Raises:
            ValueError: If the edge does not exist, or if assigning or
                committing fails (the session is rolled back first).
        """
        edge = self.get_edge(edge_id)
        try:
            # NOTE(review): keys are not validated against the model's
            # attributes here - confirm callers only pass known field names.
            for key, value in update_data.items():
                setattr(edge, key, value)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.error("更新边失败: %s", e)
            raise ValueError("Update failed") from e
        return edge

    def delete_edge(self, edge_id: int):
        """Delete the edge with primary key ``edge_id``.

        Returns:
            None.

        Raises:
            ValueError: If the edge does not exist, or if the delete or
                commit fails (the session is rolled back first).
        """
        edge = self.get_edge(edge_id)
        try:
            self.db.delete(edge)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.error("删除边失败: %s", e)
            raise ValueError("Delete failed") from e

    def get_edges_by_nodes(self, src_id: Optional[int], dest_id: Optional[int], and_logic: bool = True):
        """Return edges matching the given endpoints, with their nodes attached.

        Args:
            src_id: Source node id to match, or ``None`` to leave unconstrained.
            dest_id: Destination node id to match, or ``None`` to leave
                unconstrained.
            and_logic: When true an edge must satisfy every supplied
                condition; when false satisfying any one of them is enough.

        Returns:
            A list of dicts with keys ``id``, ``src_id``, ``dest_id``,
            ``name``, ``version``, ``src_node`` and ``dest_node``; the two
            node entries hold whatever ``KGNodeService.get_node`` returns.

        Raises:
            ValueError: If both ``src_id`` and ``dest_id`` are ``None``.
            Exception: Any error raised while querying or resolving nodes is
                logged and re-raised unchanged.
        """
        if src_id is None and dest_id is None:
            raise ValueError("至少需要提供一个有效的查询条件")

        try:
            filters = []
            if src_id is not None:
                filters.append(KGEdge.src_id == src_id)
            if dest_id is not None:
                filters.append(KGEdge.dest_id == dest_id)

            query = self.db.query(KGEdge)
            # Criteria passed positionally to filter() are AND-ed together;
            # wrapping them in or_() turns the match into a disjunction.
            if and_logic:
                edges = query.filter(*filters).all()
            else:
                edges = query.filter(or_(*filters)).all()

            # Imported here rather than at module level - presumably to avoid
            # a circular import between the two service modules.
            from service.kg_node_service import KGNodeService
            node_service = KGNodeService(self.db)

            # Resolve each distinct node id once per call: edges that share an
            # endpoint would otherwise repeat the same lookup.
            node_cache = {}

            def resolve_node(node_id):
                if node_id not in node_cache:
                    node_cache[node_id] = node_service.get_node(node_id)
                return node_cache[node_id]

            return [
                {
                    'id': edge.id,
                    'src_id': edge.src_id,
                    'dest_id': edge.dest_id,
                    'name': edge.name,
                    'version': edge.version,
                    'src_node': resolve_node(edge.src_id),
                    'dest_node': resolve_node(edge.dest_id),
                }
                for edge in edges
            ]
        except Exception as e:
            logger.error("查询边失败: %s", e)
            raise
|