"""Knowledge-graph node service: vector-similarity search and CRUD for KGNode."""
  1. from sqlalchemy.orm import Session
  2. from typing import Optional
  3. from model.kg_node import KGNode
  4. from db.session import get_db
  5. import logging
  6. from sqlalchemy.exc import IntegrityError
  7. from utils import vectorizer
  8. from utils.vectorizer import Vectorizer
  9. from sqlalchemy import func
  10. from service.kg_prop_service import KGPropService
  11. from service.kg_edge_service import KGEdgeService
  12. logger = logging.getLogger(__name__)
  13. DISTANCE_THRESHOLD = 0.65
  14. DISTANCE_THRESHOLD2 = 0.3
  15. class KGNodeService:
  16. def __init__(self, db: Session):
  17. self.db = db
  18. _cache = {}
  19. def search_title_index(self, index: str, title: str, top_k: int = 3):
  20. cache_key = f"{index}:{title}:{top_k}"
  21. if cache_key in self._cache:
  22. return self._cache[cache_key]
  23. query_embedding = Vectorizer.get_embedding(title)
  24. db = next(get_db())
  25. # 执行向量搜索
  26. results = (
  27. db.query(
  28. KGNode.id,
  29. KGNode.name,
  30. KGNode.category,
  31. KGNode.embedding.l2_distance(query_embedding).label('distance')
  32. )
  33. .filter(KGNode.embedding.l2_distance(query_embedding) <= DISTANCE_THRESHOLD2)
  34. .order_by('distance').limit(top_k).all()
  35. )
  36. results = [
  37. {
  38. "id": node.id,
  39. "title": node.name,
  40. "text": node.category,
  41. "score": 2.0-node.distance
  42. }
  43. for node in results
  44. ]
  45. self._cache[cache_key] = results
  46. return results
  47. def paginated_search(self, search_params: dict) -> dict:
  48. load_props = search_params.get('load_props', False)
  49. prop_service = KGPropService(self.db)
  50. edge_service = KGEdgeService(self.db)
  51. keyword = search_params.get('keyword', '')
  52. page_no = search_params.get('pageNo', 1)
  53. limit = search_params.get('limit', 10)
  54. if page_no < 1:
  55. page_no = 1
  56. if limit < 1:
  57. limit = 10
  58. embedding = Vectorizer.get_embedding(keyword)
  59. offset = (page_no - 1) * limit
  60. try:
  61. total_count = self.db.query(func.count(KGNode.id)).filter(KGNode.version.in_(search_params['knowledge_ids']), KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD).scalar() if search_params.get('knowledge_ids') else self.db.query(func.count(KGNode.id)).filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD).scalar()
  62. query = self.db.query(
  63. KGNode.id,
  64. KGNode.name,
  65. KGNode.category,
  66. KGNode.embedding.l2_distance(embedding).label('distance')
  67. )
  68. if search_params.get('knowledge_ids'):
  69. query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
  70. query = query.filter(KGNode.embedding.l2_distance(embedding) < DISTANCE_THRESHOLD)
  71. results = query.order_by('distance').offset(offset).limit(limit).all()
  72. return {
  73. 'records': [{
  74. 'id': r.id,
  75. 'name': r.name,
  76. 'category': r.category,
  77. 'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
  78. #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
  79. 'distance': round(r.distance, 3)
  80. } for r in results],
  81. 'pagination': {
  82. 'total': total_count,
  83. 'pageNo': page_no,
  84. 'limit': limit,
  85. 'totalPages': (total_count + limit - 1) // limit
  86. }
  87. }
  88. except Exception as e:
  89. logger.error(f"分页查询失败: {str(e)}")
  90. raise e
  91. def create_node(self, node_data: dict):
  92. try:
  93. existing = self.db.query(KGNode).filter(
  94. KGNode.name == node_data['name'],
  95. KGNode.category == node_data['category'],
  96. KGNode.version == node_data.get('version')
  97. ).first()
  98. if existing:
  99. raise ValueError("Node already exists")
  100. new_node = KGNode(**node_data)
  101. self.db.add(new_node)
  102. self.db.commit()
  103. return new_node
  104. except IntegrityError as e:
  105. self.db.rollback()
  106. logger.error(f"创建节点失败: {str(e)}")
  107. raise ValueError("Database integrity error")
  108. def get_node(self, node_id: int):
  109. node = self.db.query(KGNode).get(node_id)
  110. if not node:
  111. raise ValueError("Node not found")
  112. return {
  113. 'id': node.id,
  114. 'name': node.name,
  115. 'category': node.category,
  116. 'version': node.version
  117. }
  118. def update_node(self, node_id: int, update_data: dict):
  119. node = self.db.query(KGNode).get(node_id)
  120. if not node:
  121. raise ValueError("Node not found")
  122. try:
  123. for key, value in update_data.items():
  124. setattr(node, key, value)
  125. self.db.commit()
  126. return node
  127. except Exception as e:
  128. self.db.rollback()
  129. logger.error(f"更新节点失败: {str(e)}")
  130. raise ValueError("Update failed")
  131. def delete_node(self, node_id: int):
  132. node = self.db.query(KGNode).get(node_id)
  133. if not node:
  134. raise ValueError("Node not found")
  135. try:
  136. self.db.delete(node)
  137. self.db.commit()
  138. return None
  139. except Exception as e:
  140. self.db.rollback()
  141. logger.error(f"删除节点失败: {str(e)}")
  142. raise ValueError("Delete failed")
  143. def batch_process_er_nodes(self):
  144. batch_size = 200
  145. offset = 0
  146. while True:
  147. try:
  148. nodes = self.db.query(KGNode).filter(
  149. #KGNode.version == 'ER',
  150. KGNode.embedding == None
  151. ).offset(offset).limit(batch_size).all()
  152. if not nodes:
  153. break
  154. updated_nodes = []
  155. for node in nodes:
  156. if not node.embedding:
  157. embedding = Vectorizer.get_embedding(node.name)
  158. node.embedding = embedding
  159. updated_nodes.append(node)
  160. if updated_nodes:
  161. self.db.commit()
  162. offset += batch_size
  163. except Exception as e:
  164. self.db.rollback()
  165. print(f"批量处理ER节点失败: {str(e)}")
  166. raise ValueError("Batch process failed")