# kg_node_service.py

from sqlalchemy.orm import Session
from model.kg_node import KGNode
from db.session import get_db
import logging
from sqlalchemy.exc import IntegrityError
from utils.vectorizer import Vectorizer
from sqlalchemy import func
from service.kg_prop_service import KGPropService
from service.kg_edge_service import KGEdgeService
from cachetools import TTLCache

logger = logging.getLogger(__name__)

# Default L2-distance cutoff applied when a search does not supply one.
DISTANCE_THRESHOLD = 0.65


class KGNodeService:
    def __init__(self, db: Session):
        self.db = db

    # Class-level cache shared by all service instances; entries expire after 30 days.
    _cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)

    def search_title_index(self, index: str, keyword: str, category: str, top_k: int = 3, distance: float = 0.3) -> list:
        cache_key = f"search_title_index_{index}:{keyword}:{category}:{top_k}:{distance}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        query_embedding = Vectorizer.get_embedding(keyword)
        db = next(get_db())
        # Run the vector search.
        results = (
            db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(query_embedding).label('distance')
            )
            .filter(KGNode.status == 0)
            .filter(KGNode.category == category)
            # TODO: could switching to cosine distance improve performance?
            .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
            .order_by('distance').limit(top_k).all()
        )
        results = [
            {
                "id": node.id,
                "title": node.name,
                "text": node.category,
                "score": 2.0 - node.distance
            }
            for node in results
        ]
        self._cache[cache_key] = results
        return results
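
    # Usage sketch (illustrative values; assumes KGNode.embedding is a vector column
    # whose l2_distance() comparator is available, as the query above implies):
    #
    #   service = KGNodeService(next(get_db()))
    #   hits = service.search_title_index("title_index", "hypertension", "Disease", top_k=3, distance=0.3)
    #   # hits -> [{"id": ..., "title": ..., "text": "Disease", "score": 2.0 - distance}, ...]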

    def paginated_search(self, search_params: dict) -> dict:
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        edge_service = KGEdgeService(self.db)
        keyword = search_params.get('keyword', '')
        category = search_params.get('category', None)
        page_no = search_params.get('pageNo', 1)
        # Fall back to the default threshold when distance is None or missing.
        if search_params.get('distance') is None:
            distance = DISTANCE_THRESHOLD
        else:
            distance = search_params.get('distance')
        limit = search_params.get('limit', 10)
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10
        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit
        try:
            # Base query for the total count.
            base_query = self.db.query(func.count(KGNode.id)).filter(
                KGNode.status == 0,
                KGNode.embedding.l2_distance(embedding) <= distance
            )
            # Narrow by category when one is supplied.
            if category:
                base_query = base_query.filter(KGNode.category == category)
            # Narrow by knowledge_ids when supplied.
            if search_params.get('knowledge_ids'):
                total_count = base_query.filter(
                    KGNode.version.in_(search_params['knowledge_ids'])
                ).scalar()
            else:
                total_count = base_query.scalar()
            query = self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(embedding).label('distance')
            )
            query = query.filter(KGNode.status == 0)
            # When a category is supplied, drop nodes in any other category.
            if category:
                query = query.filter(KGNode.category == category)
            if search_params.get('knowledge_ids'):
                query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))
            query = query.filter(KGNode.embedding.l2_distance(embedding) <= distance)
            results = query.order_by('distance').offset(offset).limit(limit).all()
            return {
                'records': [{
                    'id': r.id,
                    'name': r.name,
                    'category': r.category,
                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                    # 'edges': edge_service.get_edges_by_nodes(r.id, r.id, False) if load_props else [],
                    'distance': round(r.distance, 3)
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    'totalPages': (total_count + limit - 1) // limit
                }
            }
        except Exception as e:
            logger.error(f"Paginated node search failed: {str(e)}")
            raise
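
    # Illustrative search_params (the keys below are exactly the ones this method
    # reads; the values are examples only):
    #
    #   params = {
    #       'keyword': 'aspirin',
    #       'category': 'Drug',          # optional
    #       'pageNo': 1,
    #       'limit': 10,
    #       'distance': None,            # None -> DISTANCE_THRESHOLD
    #       'load_props': True,
    #       'knowledge_ids': ['v1.0'],   # optional, matched against KGNode.version
    #   }
    #   page = KGNodeService(db).paginated_search(params)
    #   # page['records'] -> node dicts; page['pagination'] -> total / pageNo / limit / totalPages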

    def create_node(self, node_data: dict):
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()
            if existing:
                raise ValueError("Node already exists")
            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node
        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"Failed to create node: {str(e)}")
            raise ValueError("Database integrity error")
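
    # Example payload (illustrative; 'name', 'category' and 'version' are the keys the
    # duplicate check above uses, and any other KGNode column may be passed as well):
    #
    #   node = KGNodeService(db).create_node({
    #       'name': 'Aspirin',
    #       'category': 'Drug',
    #       'version': 'v1.0',
    #   })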

    def get_node(self, node_id: int):
        if node_id is None:
            raise ValueError("Node ID is required")
        cache_key = f"get_node_{node_id}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
        if not node:
            raise ValueError("Node not found")
        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }
        self._cache[cache_key] = node_data
        return node_data

    def update_node(self, node_id: int, update_data: dict):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            return node
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to update node: {str(e)}")
            raise ValueError("Update failed")

    def delete_node(self, node_id: int):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            self.db.delete(node)
            self.db.commit()
            return None
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to delete node: {str(e)}")
            raise ValueError("Delete failed")

    def batch_process_er_nodes(self):
        batch_size = 200
        last_id = 0
        while True:
            try:
                # Order by id to keep batches stable under concurrency. Paginate by the
                # last seen id rather than OFFSET: committed rows drop out of the
                # "embedding IS NULL" filter, so an advancing offset would skip rows.
                nodes = self.db.query(KGNode).filter(
                    # KGNode.version == 'ER',
                    KGNode.embedding.is_(None),
                    KGNode.id > last_id
                ).order_by(KGNode.id).limit(batch_size).all()
                if not nodes:
                    break
                updated_nodes = []
                for node in nodes:
                    if not node.embedding:
                        node.embedding = Vectorizer.get_embedding(node.name)
                        updated_nodes.append(node)
                if updated_nodes:
                    self.db.commit()
                last_id = nodes[-1].id
            except Exception as e:
                self.db.rollback()
                logger.error(f"Batch processing of ER nodes failed: {str(e)}")
                raise ValueError("Batch process failed")
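
    # Backfill sketch (illustrative): run once from a maintenance script or shell:
    #
    #   KGNodeService(next(get_db())).batch_process_er_nodes()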

    def get_node_by_name_category(self, name: str, category: str):
        if not name or not category:
            raise ValueError("Name and category are required")
        cache_key = f"get_node_by_name_category_{name}:{category}"
        if cache_key in self._cache:
            return self._cache[cache_key]
        node = self.db.query(KGNode).filter(
            KGNode.name == name,
            KGNode.category == category,
            KGNode.status == 0
        ).first()
        if not node:
            return None
        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }
        self._cache[cache_key] = node_data
        return node_data
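

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumes db.session.get_db yields a usable Session and
    # that the vector extension backing KGNode.embedding is configured).
    session = next(get_db())
    service = KGNodeService(session)
    page = service.paginated_search({"keyword": "example", "pageNo": 1, "limit": 5})
    print(page["pagination"])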