# kg_edge_service.py

  1. from sqlalchemy.orm import Session
  2. from sqlalchemy import or_
  3. from typing import Optional
  4. from model.kg_edges import KGEdge
  5. from db.session import get_db
  6. import logging
  7. from sqlalchemy.exc import IntegrityError
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  9. class KGEdgeService:
  10. def __init__(self, db: Session):
  11. self.db = db
  12. def get_edge(self, edge_id: int):
  13. edge = self.db.query(KGEdge).get(edge_id)
  14. if not edge:
  15. raise ValueError("Edge not found")
  16. return edge
  17. def create_edge(self, edge_data: dict):
  18. try:
  19. existing = self.db.query(KGEdge).filter(
  20. KGEdge.src_id == edge_data['src_id'],
  21. KGEdge.dest_id == edge_data['dest_id'],
  22. KGEdge.name == edge_data['name'],
  23. KGEdge.version == edge_data.get('version')
  24. ).first()
  25. if existing:
  26. raise ValueError("Edge already exists")
  27. new_edge = KGEdge(**edge_data)
  28. self.db.add(new_edge)
  29. self.db.commit()
  30. return new_edge
  31. except IntegrityError as e:
  32. self.db.rollback()
  33. logger.error(f"创建边失败: {str(e)}")
  34. raise ValueError("Database integrity error")
  35. def update_edge(self, edge_id: int, update_data: dict):
  36. edge = self.db.query(KGEdge).get(edge_id)
  37. if not edge:
  38. raise ValueError("Edge not found")
  39. try:
  40. for key, value in update_data.items():
  41. setattr(edge, key, value)
  42. self.db.commit()
  43. return edge
  44. except Exception as e:
  45. self.db.rollback()
  46. logger.error(f"更新边失败: {str(e)}")
  47. raise ValueError("Update failed")
  48. def delete_edge(self, edge_id: int):
  49. edge = self.db.query(KGEdge).get(edge_id)
  50. if not edge:
  51. raise ValueError("Edge not found")
  52. try:
  53. self.db.delete(edge)
  54. self.db.commit()
  55. return None
  56. except Exception as e:
  57. self.db.rollback()
  58. logger.error(f"删除边失败: {str(e)}")
  59. raise ValueError("Delete failed")
  60. def get_edges_by_nodes(self, src_id: Optional[int], dest_id: Optional[int], and_logic: bool = True):
  61. if src_id is None and dest_id is None:
  62. raise ValueError("至少需要提供一个有效的查询条件")
  63. try:
  64. filters = []
  65. if src_id is not None:
  66. filters.append(KGEdge.src_id == src_id)
  67. if dest_id is not None:
  68. filters.append(KGEdge.dest_id == dest_id)
  69. if and_logic:
  70. edges = self.db.query(KGEdge).filter(*filters).all()
  71. else:
  72. edges = self.db.query(KGEdge).filter(or_(*filters)).all()
  73. from service.kg_node_service import KGNodeService
  74. node_service = KGNodeService(self.db)
  75. return [{
  76. 'id': edge.id,
  77. 'src_id': edge.src_id,
  78. 'dest_id': edge.dest_id,
  79. 'name': edge.name,
  80. 'version': edge.version,
  81. 'src_node': node_service.get_node(edge.src_id),
  82. 'dest_node': node_service.get_node(edge.dest_id)
  83. } for edge in edges]
  84. except Exception as e:
  85. logger.error(f"查询边失败: {str(e)}")
  86. raise e