kg_edge_service.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from sqlalchemy.orm import Session
  2. from sqlalchemy import or_
  3. from typing import Optional
  4. from model.kg_edges import KGEdge
  5. from db.session import get_db
  6. import logging
  7. from sqlalchemy.exc import IntegrityError
  8. from cachetools import TTLCache
  9. from cachetools.keys import hashkey
  # Module-level logger named after this module; used by KGEdgeService below.
  logger = logging.getLogger(__name__)
  class KGEdgeService:
      """Create, read, update, delete and look up ``KGEdge`` rows.

      Every operation runs on the SQLAlchemy session passed to the
      constructor. Results of :meth:`get_edges_by_nodes` are memoised in a
      class-level TTL cache shared by all instances; each successful
      create/update/delete clears that cache so readers do not keep seeing
      edge lists that predate the change.
      """

      # Class attribute, so it is shared by every service instance.
      # Entries expire after 30 days (60 * 60 * 24 * 30 seconds).
      # NOTE(review): cached entries hold whatever KGNodeService.get_node
      # returns; if those are session-bound ORM objects they may outlive
      # their session -- confirm against KGNodeService.
      _cache = TTLCache(maxsize=1000, ttl=60 * 60 * 24 * 30)

      def __init__(self, db: Session):
          """Store the session that all queries of this service will use."""
          self.db = db

      def _invalidate_cache(self):
          """Drop every memoised ``get_edges_by_nodes`` result."""
          self._cache.clear()

      def get_edge(self, edge_id: int):
          """Return the edge whose primary key is *edge_id*.

          Raises:
              ValueError: if no edge with that id exists.
          """
          try:
              edge = self.db.query(KGEdge).get(edge_id)
              # The session is committed after the read, as in the
              # original implementation.
              self.db.commit()
          except Exception:
              self.db.rollback()
              raise  # bare raise keeps the original traceback
          if edge is None:
              raise ValueError("Edge not found")
          return edge

      def create_edge(self, edge_data: dict):
          """Insert a new edge built from *edge_data* and return it.

          *edge_data* must contain ``src_id``, ``dest_id`` and ``name``;
          ``version`` is optional. The whole dict is passed to ``KGEdge``
          as keyword arguments.

          Raises:
              ValueError: if an edge with the same src_id, dest_id, name
                  and version already exists, or if the database reports
                  an integrity error on commit.
              KeyError: if a required key is missing from *edge_data*.
          """
          try:
              duplicate = (
                  self.db.query(KGEdge)
                  .filter(
                      KGEdge.src_id == edge_data['src_id'],
                      KGEdge.dest_id == edge_data['dest_id'],
                      KGEdge.name == edge_data['name'],
                      KGEdge.version == edge_data.get('version'),
                  )
                  .first()
              )
              if duplicate:
                  raise ValueError("Edge already exists")
              new_edge = KGEdge(**edge_data)
              self.db.add(new_edge)
              self.db.commit()
          except IntegrityError as e:
              self.db.rollback()
              logger.error(f"创建边失败: {str(e)}")
              raise ValueError("Database integrity error") from e
          except Exception:
              # Any other failure must also leave the session usable.
              self.db.rollback()
              raise
          self._invalidate_cache()
          return new_edge

      def update_edge(self, edge_id: int, update_data: dict):
          """Set each key of *update_data* as an attribute on the edge and commit.

          Returns the updated edge.

          Raises:
              ValueError: if the edge does not exist ("Edge not found") or
                  the update could not be committed ("Update failed").
          """
          edge = self.db.query(KGEdge).get(edge_id)
          if edge is None:
              raise ValueError("Edge not found")
          try:
              for key, value in update_data.items():
                  setattr(edge, key, value)
              self.db.commit()
          except Exception as e:
              self.db.rollback()
              logger.error(f"更新边失败: {str(e)}")
              raise ValueError("Update failed") from e
          self._invalidate_cache()
          return edge

      def delete_edge(self, edge_id: int):
          """Delete the edge whose primary key is *edge_id*; returns ``None``.

          Raises:
              ValueError: if the edge does not exist ("Edge not found") or
                  the delete could not be committed ("Delete failed").
          """
          edge = self.db.query(KGEdge).get(edge_id)
          if edge is None:
              raise ValueError("Edge not found")
          try:
              self.db.delete(edge)
              self.db.commit()
          except Exception as e:
              self.db.rollback()
              logger.error(f"删除边失败: {str(e)}")
              raise ValueError("Delete failed") from e
          self._invalidate_cache()
          return None

      def get_edges_by_nodes(
          self,
          src_id: Optional[int] = None,
          dest_id: Optional[int] = None,
          category: Optional[str] = None,
      ):
          """List edges matching the given source, destination and category.

          At least one of *src_id* / *dest_id* must be given; *category* only
          narrows the match further. Each returned item is a dict with the
          keys ``id``, ``src_id``, ``dest_id``, ``name``, ``version`` and
          ``dest_node`` (the value returned by ``KGNodeService.get_node`` for
          the edge's destination). Edges whose destination lookup raises
          ``ValueError`` are logged and left out.

          Results are memoised per argument combination in the shared cache.

          Raises:
              ValueError: if both *src_id* and *dest_id* are ``None``.
          """
          if src_id is None and dest_id is None:
              raise ValueError("至少需要提供一个有效的查询条件")

          cache_key = f"get_edges_by_nodes_{src_id}_{dest_id}_{category}"
          # Single lookup: a separate "in" test followed by indexing could
          # fail if the entry expires between the two steps. Cached values
          # are always lists, so None reliably means "not cached".
          cached = self._cache.get(cache_key)
          if cached is not None:
              return cached

          try:
              filters = []
              if src_id is not None:
                  filters.append(KGEdge.src_id == src_id)
              if dest_id is not None:
                  filters.append(KGEdge.dest_id == dest_id)
              if category is not None:
                  filters.append(KGEdge.category == category)

              try:
                  edges = self.db.query(KGEdge).filter(*filters).all()
                  self.db.commit()
              except Exception:
                  self.db.rollback()
                  raise

              # Imported inside the method, as in the original
              # (presumably to avoid a circular import -- TODO confirm).
              from service.kg_node_service import KGNodeService
              node_service = KGNodeService(self.db)

              result = []
              for edge in edges:
                  try:
                      dest_node = node_service.get_node(edge.dest_id)
                  except ValueError as e:
                      logger.warning(f"跳过边关系 {edge.id}: {str(e)}")
                      continue
                  # The source node is deliberately not resolved here; that
                  # lookup was already disabled in the original code.
                  result.append({
                      'id': edge.id,
                      'src_id': edge.src_id,
                      'dest_id': edge.dest_id,
                      'name': edge.name,
                      'version': edge.version,
                      'dest_node': dest_node,
                  })

              self._cache[cache_key] = result
              return result
          except Exception as e:
              logger.error(f"查询边失败: {str(e)}")
              raise