|
@@ -0,0 +1,258 @@
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+from typing import Optional
|
|
|
+from model.kg_node import KGNode
|
|
|
+from db.session import get_db
|
|
|
+import logging
|
|
|
+from sqlalchemy.exc import IntegrityError
|
|
|
+
|
|
|
+from utils import vectorizer
|
|
|
+from utils.vectorizer import Vectorizer
|
|
|
+from sqlalchemy import func
|
|
|
+from service.kg_prop_service import KGPropService
|
|
|
+from service.kg_edge_service import KGEdgeService
|
|
|
+from cachetools import TTLCache
|
|
|
+from cachetools.keys import hashkey
|
|
|
+logger = logging.getLogger(__name__)  # module-level logger, named after this module
|
|
|
+DISTANCE_THRESHOLD = 0.65  # default maximum L2 distance used by paginated_search when the caller supplies no 'distance'
|
|
|
class KGNodeService:
    """CRUD helpers and embedding-based similarity search over ``KGNode`` rows.

    Most methods run against the SQLAlchemy session passed to the
    constructor; ``search_title_index`` obtains its own session from
    ``get_db()``. Several read methods memoise their results in ``_cache``.
    """

    # Class attribute: one cache shared by every service instance.
    # Entries expire after 30 days (60*60*24*30 seconds).
    _cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)

    def __init__(self, db: Session):
        """Store the session used by the query and write methods."""
        self.db = db

    @classmethod
    def _evict_node_cache(cls, node_id, name, category):
        """Drop the cached ``get_node`` / ``get_node_by_name_category`` entries."""
        cls._cache.pop(f"get_node_{node_id}", None)
        cls._cache.pop(f"get_node_by_name_category_{name}:{category}", None)

    def search_title_index(self, index: str, keywrod: str, category: str,
                           top_k: int = 3, distance: float = 0.3) -> list:
        """Return the ``top_k`` active nodes of ``category`` closest to a keyword.

        Args:
            index: Only used as part of the cache key.
            keywrod: Text to embed and search for. The misspelled name is
                kept because callers may pass it by keyword.
            category: Only nodes with exactly this category are considered.
            top_k: Maximum number of hits.
            distance: Maximum L2 distance a hit may have.

        Returns:
            A list of dicts with keys ``id``, ``title`` (node name), ``text``
            (node category) and ``score`` (``2.0 - distance``, so closer
            nodes score higher), ordered by ascending distance. The list is
            cached and the same object is returned on later calls.

        Note:
            Cached results are not invalidated when nodes are created,
            updated or deleted; they only expire with the cache TTL.
        """
        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
        # Single lookup: a separate `in` check followed by `[]` can raise
        # KeyError if the entry expires in between.
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        query_embedding = Vectorizer.get_embedding(keywrod)

        # get_db() is consumed with next(); presumably it is a generator that
        # closes the session in its cleanup code (TODO confirm). Closing it
        # explicitly stops the session from lingering until garbage collection.
        db_source = get_db()
        db = next(db_source)
        try:
            # Vector search.
            rows = (
                db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    KGNode.embedding.l2_distance(query_embedding).label('distance'),
                )
                .filter(KGNode.status == 0)
                .filter(KGNode.category == category)
                # TODO: check whether switching to cosine distance improves performance.
                .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
                .order_by('distance')
                .limit(top_k)
                .all()
            )
        finally:
            close = getattr(db_source, "close", None)
            if close is not None:
                close()

        results = [
            {
                "id": row.id,
                "title": row.name,
                "text": row.category,
                "score": 2.0 - row.distance,
            }
            for row in rows
        ]

        self._cache[cache_key] = results
        return results

    def paginated_search(self, search_params: dict) -> dict:
        """Run a paginated similarity search over active nodes.

        Args:
            search_params: Dict of options. Keys read here:
                ``keyword`` (text to embed, default ``''``),
                ``category`` (optional exact-match filter),
                ``knowledge_ids`` (optional list matched against
                ``KGNode.version``),
                ``distance`` (max L2 distance; ``DISTANCE_THRESHOLD`` when
                missing or ``None``),
                ``pageNo`` (1-based page, default 1),
                ``limit`` (page size, default 10),
                ``load_props`` (when truthy, each record carries its props).

        Returns:
            ``{'records': [...], 'pagination': {...}}`` where each record has
            ``id``, ``name``, ``category``, ``props`` and ``distance``
            (rounded to 3 decimals), and pagination has ``total``,
            ``pageNo``, ``limit`` and ``totalPages``.

        Raises:
            Exception: Any error from the queries is logged and re-raised.
        """
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        keyword = search_params.get('keyword', '')
        category = search_params.get('category', None)
        knowledge_ids = search_params.get('knowledge_ids')

        # Fall back to the default when distance is missing or None.
        distance = search_params.get('distance')
        if distance is None:
            distance = DISTANCE_THRESHOLD

        # Missing, None or non-positive paging values fall back to defaults.
        page_no = search_params.get('pageNo')
        if page_no is None or page_no < 1:
            page_no = 1
        limit = search_params.get('limit')
        if limit is None or limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit

        try:
            # Build the filter set once so the count query and the page
            # query can never disagree.
            filters = [
                KGNode.status == 0,
                KGNode.embedding.l2_distance(embedding) <= distance,
            ]
            if category:
                filters.append(KGNode.category == category)
            if knowledge_ids:
                filters.append(KGNode.version.in_(knowledge_ids))

            total_count = (
                self.db.query(func.count(KGNode.id)).filter(*filters).scalar()
            )

            rows = (
                self.db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    KGNode.embedding.l2_distance(embedding).label('distance'),
                )
                .filter(*filters)
                .order_by('distance')
                .offset(offset)
                .limit(limit)
                .all()
            )

            return {
                'records': [
                    {
                        'id': r.id,
                        'name': r.name,
                        'category': r.category,
                        # One props lookup per record when requested.
                        'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                        'distance': round(r.distance, 3),
                    }
                    for r in rows
                ],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    # Ceiling division.
                    'totalPages': (total_count + limit - 1) // limit,
                },
            }
        except Exception as e:
            logger.error(f"分页查询失败: {str(e)}")
            raise

    def create_node(self, node_data: dict):
        """Insert a new node unless one with the same name/category/version exists.

        Args:
            node_data: Keyword arguments for ``KGNode``; must contain
                ``name`` and ``category``; ``version`` is optional.

        Returns:
            The newly added ``KGNode`` instance.

        Raises:
            ValueError: If a matching node already exists, or if the commit
                fails with an ``IntegrityError`` (the session is rolled back).
        """
        try:
            duplicate = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version'),
            ).first()

            if duplicate:
                raise ValueError("Node already exists")

            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node

        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"创建节点失败: {str(e)}")
            raise ValueError("Database integrity error") from e

    def get_node(self, node_id: int):
        """Return a dict (id, name, category, version) for an active node.

        The result is cached under the node id.

        Raises:
            ValueError: If ``node_id`` is ``None`` or no node with
                ``status == 0`` has that id.
        """
        if node_id is None:
            raise ValueError("Node ID is required")

        cache_key = f"get_node_{node_id}"
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        node = self.db.query(KGNode).filter(
            KGNode.id == node_id, KGNode.status == 0
        ).first()

        if not node:
            raise ValueError("Node not found")

        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version,
        }
        self._cache[cache_key] = node_data
        return node_data

    def update_node(self, node_id: int, update_data: dict):
        """Set the given attributes on a node, commit, and return the node.

        Cached ``get_node`` / ``get_node_by_name_category`` entries for the
        node are evicted after a successful commit so later reads do not
        return stale data.

        Raises:
            ValueError: If the node does not exist or the update fails
                (the session is rolled back).
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")

        # Capture identifying fields before they are overwritten.
        stored_id, old_name, old_category = node.id, node.name, node.category

        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.error(f"更新节点失败: {str(e)}")
            raise ValueError("Update failed") from e

        self._evict_node_cache(stored_id, old_name, old_category)
        self._evict_node_cache(
            stored_id,
            update_data.get('name', old_name),
            update_data.get('category', old_category),
        )
        return node

    def delete_node(self, node_id: int):
        """Delete a node by id and evict its cached lookups.

        Returns:
            None.

        Raises:
            ValueError: If the node does not exist or the delete fails
                (the session is rolled back).
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")

        # Read these before the row is deleted.
        stored_id, name, category = node.id, node.name, node.category

        try:
            self.db.delete(node)
            self.db.commit()
        except Exception as e:
            self.db.rollback()
            logger.error(f"删除节点失败: {str(e)}")
            raise ValueError("Delete failed") from e

        self._evict_node_cache(stored_id, name, category)
        return None

    def batch_process_er_nodes(self):
        """Compute and store embeddings for every node that has none.

        Nodes are processed in batches of 200, ordered by id, with one
        commit per batch.

        Raises:
            ValueError: If any batch fails (the session is rolled back).
        """
        batch_size = 200
        last_id = None

        while True:
            try:
                # Page by id rather than OFFSET: each commit removes the
                # processed rows from the "embedding IS NULL" result set, so
                # advancing an offset would skip unprocessed rows.
                query = self.db.query(KGNode).filter(KGNode.embedding.is_(None))
                if last_id is not None:
                    query = query.filter(KGNode.id > last_id)
                nodes = query.order_by(KGNode.id).limit(batch_size).all()

                if not nodes:
                    break

                # Remember the cursor before commit.
                last_id = nodes[-1].id

                for node in nodes:
                    node.embedding = Vectorizer.get_embedding(node.name)
                self.db.commit()
            except Exception as e:
                self.db.rollback()
                logger.error(f"批量处理ER节点失败: {str(e)}")
                raise ValueError("Batch process failed") from e

    def get_node_by_name_category(self, name: str, category: str):
        """Return a dict (id, name, category, version) for the first active
        node with this name and category, or ``None`` when there is none.

        Found nodes are cached; misses are not.

        Raises:
            ValueError: If ``name`` or ``category`` is empty.
        """
        if not name or not category:
            raise ValueError("Name and category are required")

        cache_key = f"get_node_by_name_category_{name}:{category}"
        cached = self._cache.get(cache_key)
        if cached is not None:
            return cached

        node = self.db.query(KGNode).filter(
            KGNode.name == name,
            KGNode.category == category,
            KGNode.status == 0,
        ).first()

        if not node:
            return None

        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version,
        }
        self._cache[cache_key] = node_data
        return node_data
|