kg_node_service.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from sqlalchemy.orm import Session
  2. from typing import Optional
  3. from model.kg_node import KGNode
  4. from db.session import get_db
  5. import logging
  6. from sqlalchemy.exc import IntegrityError
  7. from utils.vectorizer import Vectorizer
  8. from sqlalchemy import func
  9. from service.kg_prop_service import KGPropService
  10. from service.kg_edge_service import KGEdgeService
  11. logger = logging.getLogger(__name__)
  12. class KGNodeService:
  13. def __init__(self, db: Session):
  14. self.db = db
  15. def paginated_search(self, search_params: dict) -> dict:
  16. load_props = search_params.get('load_props', False)
  17. prop_service = KGPropService(self.db)
  18. edge_service = KGEdgeService(self.db)
  19. keyword = search_params.get('keyword', '')
  20. page_no = search_params.get('pageNo', 1)
  21. limit = search_params.get('limit', 10)
  22. if page_no < 1:
  23. page_no = 1
  24. if limit < 1:
  25. limit = 10
  26. embedding = Vectorizer.get_embedding(keyword)
  27. offset = (page_no - 1) * limit
  28. try:
  29. total_count = self.db.query(func.count(KGNode.id)).scalar()
  30. query = self.db.query(
  31. KGNode.id,
  32. KGNode.name,
  33. KGNode.category,
  34. KGNode.embedding.l2_distance(embedding).label('distance')
  35. )
  36. if search_params.get('knowledge_ids'):
  37. query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
  38. results = query.order_by('distance').offset(offset).limit(limit).all()
  39. return {
  40. 'records': [{
  41. 'id': r.id,
  42. 'name': r.name,
  43. 'category': r.category,
  44. 'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
  45. 'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
  46. 'distance': r.distance
  47. } for r in results],
  48. 'pagination': {
  49. 'total': total_count,
  50. 'pageNo': page_no,
  51. 'limit': limit,
  52. 'totalPages': (total_count + limit - 1) // limit
  53. }
  54. }
  55. except Exception as e:
  56. logger.error(f"分页查询失败: {str(e)}")
  57. raise e
  58. def create_node(self, node_data: dict):
  59. try:
  60. existing = self.db.query(KGNode).filter(
  61. KGNode.name == node_data['name'],
  62. KGNode.category == node_data['category'],
  63. KGNode.version == node_data.get('version')
  64. ).first()
  65. if existing:
  66. raise ValueError("Node already exists")
  67. new_node = KGNode(**node_data)
  68. self.db.add(new_node)
  69. self.db.commit()
  70. return new_node
  71. except IntegrityError as e:
  72. self.db.rollback()
  73. logger.error(f"创建节点失败: {str(e)}")
  74. raise ValueError("Database integrity error")
  75. def get_node(self, node_id: int):
  76. node = self.db.query(KGNode).get(node_id)
  77. if not node:
  78. raise ValueError("Node not found")
  79. return {
  80. 'id': node.id,
  81. 'name': node.name,
  82. 'category': node.category,
  83. 'version': node.version
  84. }
  85. def update_node(self, node_id: int, update_data: dict):
  86. node = self.db.query(KGNode).get(node_id)
  87. if not node:
  88. raise ValueError("Node not found")
  89. try:
  90. for key, value in update_data.items():
  91. setattr(node, key, value)
  92. self.db.commit()
  93. return node
  94. except Exception as e:
  95. self.db.rollback()
  96. logger.error(f"更新节点失败: {str(e)}")
  97. raise ValueError("Update failed")
  98. def delete_node(self, node_id: int):
  99. node = self.db.query(KGNode).get(node_id)
  100. if not node:
  101. raise ValueError("Node not found")
  102. try:
  103. self.db.delete(node)
  104. self.db.commit()
  105. return None
  106. except Exception as e:
  107. self.db.rollback()
  108. logger.error(f"删除节点失败: {str(e)}")
  109. raise ValueError("Delete failed")
  110. def batch_process_er_nodes(self):
  111. batch_size = 200
  112. offset = 0
  113. while True:
  114. try:
  115. nodes = self.db.query(KGNode).filter(
  116. KGNode.version == 'ER',
  117. KGNode.embedding == None
  118. ).offset(offset).limit(batch_size).all()
  119. if not nodes:
  120. break
  121. for node in nodes:
  122. if not node.embedding:
  123. embedding = Vectorizer.get_embedding(node.name)
  124. node.embedding = embedding
  125. self.db.commit()
  126. offset += batch_size
  127. except Exception as e:
  128. self.db.rollback()
  129. logger.error(f"批量处理ER节点失败: {str(e)}")
  130. raise ValueError("Batch process failed")