kg_node_service.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from sqlalchemy.orm import Session
  2. from typing import Optional
  3. from model.kg_node import KGNode
  4. from db.session import get_db
  5. import logging
  6. from sqlalchemy.exc import IntegrityError
  7. from utils.vectorizer import Vectorizer
  8. from sqlalchemy import func
  9. from service.kg_prop_service import KGPropService
  10. from service.kg_edge_service import KGEdgeService
  11. logger = logging.getLogger(__name__)
  12. DISTANCE_THRESHOLD = 0.65
  13. class KGNodeService:
  14. def __init__(self, db: Session):
  15. self.db = db
  16. def paginated_search(self, search_params: dict) -> dict:
  17. load_props = search_params.get('load_props', False)
  18. prop_service = KGPropService(self.db)
  19. edge_service = KGEdgeService(self.db)
  20. keyword = search_params.get('keyword', '')
  21. page_no = search_params.get('pageNo', 1)
  22. limit = search_params.get('limit', 10)
  23. if page_no < 1:
  24. page_no = 1
  25. if limit < 1:
  26. limit = 10
  27. embedding = Vectorizer.get_embedding(keyword)
  28. offset = (page_no - 1) * limit
  29. try:
  30. total_count = self.db.query(func.count(KGNode.id)).filter(KGNode.version.in_(search_params['knowledge_ids']), KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD).scalar() if search_params.get('knowledge_ids') else self.db.query(func.count(KGNode.id)).filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD).scalar()
  31. query = self.db.query(
  32. KGNode.id,
  33. KGNode.name,
  34. KGNode.category,
  35. KGNode.embedding.l2_distance(embedding).label('distance')
  36. )
  37. if search_params.get('knowledge_ids'):
  38. query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
  39. query = query.filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD)
  40. results = query.order_by('distance').offset(offset).limit(limit).all()
  41. return {
  42. 'records': [{
  43. 'id': r.id,
  44. 'name': r.name,
  45. 'category': r.category,
  46. 'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
  47. #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
  48. 'distance': round(r.distance, 3)
  49. } for r in results],
  50. 'pagination': {
  51. 'total': total_count,
  52. 'pageNo': page_no,
  53. 'limit': limit,
  54. 'totalPages': (total_count + limit - 1) // limit
  55. }
  56. }
  57. except Exception as e:
  58. logger.error(f"分页查询失败: {str(e)}")
  59. raise e
  60. def create_node(self, node_data: dict):
  61. try:
  62. existing = self.db.query(KGNode).filter(
  63. KGNode.name == node_data['name'],
  64. KGNode.category == node_data['category'],
  65. KGNode.version == node_data.get('version')
  66. ).first()
  67. if existing:
  68. raise ValueError("Node already exists")
  69. new_node = KGNode(**node_data)
  70. self.db.add(new_node)
  71. self.db.commit()
  72. return new_node
  73. except IntegrityError as e:
  74. self.db.rollback()
  75. logger.error(f"创建节点失败: {str(e)}")
  76. raise ValueError("Database integrity error")
  77. def get_node(self, node_id: int):
  78. node = self.db.query(KGNode).get(node_id)
  79. if not node:
  80. raise ValueError("Node not found")
  81. return {
  82. 'id': node.id,
  83. 'name': node.name,
  84. 'category': node.category,
  85. 'version': node.version
  86. }
  87. def update_node(self, node_id: int, update_data: dict):
  88. node = self.db.query(KGNode).get(node_id)
  89. if not node:
  90. raise ValueError("Node not found")
  91. try:
  92. for key, value in update_data.items():
  93. setattr(node, key, value)
  94. self.db.commit()
  95. return node
  96. except Exception as e:
  97. self.db.rollback()
  98. logger.error(f"更新节点失败: {str(e)}")
  99. raise ValueError("Update failed")
  100. def delete_node(self, node_id: int):
  101. node = self.db.query(KGNode).get(node_id)
  102. if not node:
  103. raise ValueError("Node not found")
  104. try:
  105. self.db.delete(node)
  106. self.db.commit()
  107. return None
  108. except Exception as e:
  109. self.db.rollback()
  110. logger.error(f"删除节点失败: {str(e)}")
  111. raise ValueError("Delete failed")
  112. def batch_process_er_nodes(self):
  113. batch_size = 200
  114. offset = 0
  115. while True:
  116. try:
  117. nodes = self.db.query(KGNode).filter(
  118. #KGNode.version == 'ER',
  119. KGNode.embedding == None
  120. ).offset(offset).limit(batch_size).all()
  121. if not nodes:
  122. break
  123. updated_nodes = []
  124. for node in nodes:
  125. if not node.embedding:
  126. embedding = Vectorizer.get_embedding(node.name)
  127. node.embedding = embedding
  128. updated_nodes.append(node)
  129. if updated_nodes:
  130. self.db.commit()
  131. offset += batch_size
  132. except Exception as e:
  133. self.db.rollback()
  134. logger.error(f"批量处理ER节点失败: {str(e)}")
  135. raise ValueError("Batch process failed")