kg_node_service.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
import logging
from typing import Optional

from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from model.trunks_model import KGNode
from db.session import get_db
from schema.response import StandardResponse
  8. logger = logging.getLogger(__name__)
  9. class KGNodeService:
  10. def __init__(self):
  11. self.db = next(get_db())
  12. def paginated_search(self, search_params: dict) -> StandardResponse:
  13. keyword = search_params.get('keyword', '')
  14. page_no = search_params.get('pageNo', 1)
  15. limit = search_params.get('limit', 10)
  16. if page_no < 1:
  17. page_no = 1
  18. if limit < 1:
  19. limit = 10
  20. embedding = Vectorizer.get_embedding(keyword)
  21. offset = (page_no - 1) * limit
  22. try:
  23. total_count = self.db.query(func.count(KGNode.id)).scalar()
  24. query = self.db.query(
  25. KGNode.id,
  26. KGNode.name,
  27. KGNode.category,
  28. KGNode.embedding.l2_distance(embedding).label('distance')
  29. )
  30. if search_params.get('knowledge_ids'):
  31. query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
  32. results = query.order_by('distance').offset(offset).limit(limit).all()
  33. return StandardResponse(
  34. success=True,
  35. data={
  36. 'records': [{
  37. 'id': r.id,
  38. 'name': r.name,
  39. 'category': r.category,
  40. 'distance': r.distance
  41. } for r in results],
  42. 'pagination': {
  43. 'total': total_count,
  44. 'pageNo': page_no,
  45. 'limit': limit,
  46. 'totalPages': (total_count + limit - 1) // limit
  47. }
  48. }
  49. )
  50. except Exception as e:
  51. logger.error(f"分页查询失败: {str(e)}")
  52. return StandardResponse(
  53. success=False,
  54. error_code=500,
  55. error_msg=str(e)
  56. )
  57. def create_node(self, node_data: dict) -> StandardResponse:
  58. try:
  59. existing = self.db.query(KGNode).filter(
  60. KGNode.name == node_data['name'],
  61. KGNode.category == node_data['category'],
  62. KGNode.version == node_data.get('version')
  63. ).first()
  64. if existing:
  65. return StandardResponse(
  66. success=False,
  67. error_code=409,
  68. error_msg="Node already exists"
  69. )
  70. new_node = KGNode(**node_data)
  71. self.db.add(new_node)
  72. self.db.commit()
  73. return StandardResponse(success=True, data=new_node)
  74. except IntegrityError as e:
  75. self.db.rollback()
  76. logger.error(f"创建节点失败: {str(e)}")
  77. return StandardResponse(
  78. success=False,
  79. error_code=500,
  80. error_msg="Database integrity error"
  81. )
  82. def get_node(self, node_id: int) -> StandardResponse:
  83. node = self.db.query(KGNode).get(node_id)
  84. if not node:
  85. return StandardResponse(
  86. success=False,
  87. error_code=404,
  88. error_msg="Node not found"
  89. )
  90. return StandardResponse(success=True, data=node)
  91. def update_node(self, node_id: int, update_data: dict) -> StandardResponse:
  92. node = self.db.query(KGNode).get(node_id)
  93. if not node:
  94. return StandardResponse(
  95. success=False,
  96. error_code=404,
  97. error_msg="Node not found"
  98. )
  99. try:
  100. for key, value in update_data.items():
  101. setattr(node, key, value)
  102. self.db.commit()
  103. return StandardResponse(success=True, data=node)
  104. except Exception as e:
  105. self.db.rollback()
  106. logger.error(f"更新节点失败: {str(e)}")
  107. return StandardResponse(
  108. success=False,
  109. error_code=500,
  110. error_msg="Update failed"
  111. )
  112. def delete_node(self, node_id: int) -> StandardResponse:
  113. node = self.db.query(KGNode).get(node_id)
  114. if not node:
  115. return StandardResponse(
  116. success=False,
  117. error_code=404,
  118. error_msg="Node not found"
  119. )
  120. try:
  121. self.db.delete(node)
  122. self.db.commit()
  123. return StandardResponse(success=True)
  124. except Exception as e:
  125. self.db.rollback()
  126. logger.error(f"删除节点失败: {str(e)}")
  127. return StandardResponse(
  128. success=False,
  129. error_code=500,
  130. error_msg="Delete failed"
  131. )