kg_edge_service.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from sqlalchemy.orm import Session
  2. from sqlalchemy import or_
  3. from typing import Optional
  4. from model.kg_edges import KGEdge
  5. from db.session import get_db
  6. import logging
  7. from sqlalchemy.exc import IntegrityError
  8. import cachetools
# Module-level logger named after this module; KGEdgeService uses it to report
# failed writes/queries (error) and skipped edges (warning).
logger = logging.getLogger(__name__)
  10. class KGEdgeService:
  11. def __init__(self, db: Session):
  12. self.db = db
  13. self._edges_cache = cachetools.TTLCache(maxsize=10000, ttl=60*60*24*30)
  14. def get_edge(self, edge_id: int):
  15. edge = self.db.query(KGEdge).get(edge_id)
  16. if not edge:
  17. raise ValueError("Edge not found")
  18. return edge
  19. def create_edge(self, edge_data: dict):
  20. try:
  21. existing = self.db.query(KGEdge).filter(
  22. KGEdge.src_id == edge_data['src_id'],
  23. KGEdge.dest_id == edge_data['dest_id'],
  24. KGEdge.name == edge_data['name'],
  25. KGEdge.version == edge_data.get('version')
  26. ).first()
  27. if existing:
  28. raise ValueError("Edge already exists")
  29. new_edge = KGEdge(**edge_data)
  30. self.db.add(new_edge)
  31. self.db.commit()
  32. return new_edge
  33. except IntegrityError as e:
  34. self.db.rollback()
  35. logger.error(f"创建边失败: {str(e)}")
  36. raise ValueError("Database integrity error")
  37. def update_edge(self, edge_id: int, update_data: dict):
  38. edge = self.db.query(KGEdge).get(edge_id)
  39. if not edge:
  40. raise ValueError("Edge not found")
  41. try:
  42. for key, value in update_data.items():
  43. setattr(edge, key, value)
  44. self.db.commit()
  45. return edge
  46. except Exception as e:
  47. self.db.rollback()
  48. logger.error(f"更新边失败: {str(e)}")
  49. raise ValueError("Update failed")
  50. def delete_edge(self, edge_id: int):
  51. edge = self.db.query(KGEdge).get(edge_id)
  52. if not edge:
  53. raise ValueError("Edge not found")
  54. try:
  55. self.db.delete(edge)
  56. self.db.commit()
  57. return None
  58. except Exception as e:
  59. self.db.rollback()
  60. logger.error(f"删除边失败: {str(e)}")
  61. raise ValueError("Delete failed")
  62. @cachetools.cachedmethod(lambda self: self._edges_cache, key=lambda self, src_id=None, dest_id=None, category=None: (src_id, dest_id, category))
  63. def get_edges_by_nodes(self, src_id: Optional[int]= None, dest_id: Optional[int]= None, category: Optional[str] = None):
  64. if src_id is None and dest_id is None:
  65. raise ValueError("至少需要提供一个有效的查询条件")
  66. try:
  67. filters = []
  68. if src_id is not None:
  69. filters.append(KGEdge.src_id == src_id)
  70. if dest_id is not None:
  71. filters.append(KGEdge.dest_id == dest_id)
  72. if category is not None:
  73. filters.append(KGEdge.category == category)
  74. edges = self.db.query(KGEdge).filter(*filters).all()
  75. from service.kg_node_service import KGNodeService
  76. node_service = KGNodeService(self.db)
  77. result = []
  78. for edge in edges:
  79. try:
  80. edge_info = {
  81. 'id': edge.id,
  82. 'src_id': edge.src_id,
  83. 'dest_id': edge.dest_id,
  84. 'name': edge.name,
  85. 'version': edge.version,
  86. #'src_node': node_service.get_node(edge.src_id),
  87. 'dest_node': node_service.get_node(edge.dest_id)
  88. }
  89. result.append(edge_info)
  90. except ValueError as e:
  91. logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
  92. continue
  93. return result
  94. except Exception as e:
  95. logger.error(f"查询边失败: {str(e)}")
  96. raise e