from sqlalchemy.orm import Session
from typing import Optional
from model.kg_node import KGNode
from db.session import get_db
import logging
from sqlalchemy.exc import IntegrityError
from utils import vectorizer
from utils.vectorizer import Vectorizer
from sqlalchemy import func
from service.kg_prop_service import KGPropService
from service.kg_edge_service import KGEdgeService
from cachetools import TTLCache
from cachetools.keys import hashkey

logger = logging.getLogger(__name__)

# Default L2-distance cutoff for vector-similarity search.
DISTANCE_THRESHOLD = 0.65


class KGNodeService:
    """CRUD and vector-similarity search over knowledge-graph nodes (KGNode).

    Reads/writes go through the SQLAlchemy ``Session`` supplied at
    construction; similarity search uses pgvector's ``l2_distance`` on
    ``KGNode.embedding``.
    """

    # Class-level cache shared by ALL instances: bounded size, 30-day TTL.
    # NOTE(review): TTLCache is not thread-safe; wrap access in a lock if
    # this service is called from multiple threads — TODO confirm.
    _cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)

    def __init__(self, db: Session):
        self.db = db

    def search_title_index(self, index: str, keywrod: str, category: str,
                           top_k: int = 3, distance: float = 0.3) -> list:
        """Vector search for nodes whose name is similar to the query string.

        NOTE: the parameter name ``keywrod`` is a historical typo kept for
        backward compatibility with callers passing it by keyword.

        :param index: logical index name; only used as part of the cache key.
        :param keywrod: query text to embed and search for.
        :param category: exact ``KGNode.category`` to restrict results to.
        :param top_k: maximum number of results.
        :param distance: maximum L2 distance for a match.
        :return: list of dicts with ``id``, ``title``, ``text`` and ``score``
                 (score = 2.0 - L2 distance, so smaller distance → higher score).
        """
        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        query_embedding = Vectorizer.get_embedding(keywrod)

        # BUGFIX: ``next(get_db())`` previously leaked the generator (and the
        # session it owns) — it was never closed. Close it explicitly so the
        # generator's cleanup code runs.
        db_gen = get_db()
        db = next(db_gen)
        try:
            # Vector search: only active nodes (status == 0) of the requested
            # category, within the distance cutoff, nearest first.
            rows = (
                db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    KGNode.embedding.l2_distance(query_embedding).label('distance')
                )
                .filter(KGNode.status == 0)
                .filter(KGNode.category == category)
                # TODO: evaluate cosine distance for better performance
                .filter(KGNode.embedding.l2_distance(query_embedding) <= distance)
                .order_by('distance').limit(top_k).all()
            )
        finally:
            db_gen.close()

        results = [
            {
                "id": node.id,
                "title": node.name,
                "text": node.category,
                "score": 2.0 - node.distance
            }
            for node in rows
        ]
        self._cache[cache_key] = results
        return results

    def _apply_node_filters(self, query, embedding, distance, category, knowledge_ids):
        """Apply the shared visibility/similarity filters used by pagination.

        Keeps the COUNT query and the page query guaranteed-identical (the
        original duplicated this filter chain in two places).
        """
        query = query.filter(
            KGNode.status == 0,
            KGNode.embedding.l2_distance(embedding) < distance,
        )
        if category:
            query = query.filter(KGNode.category == category)
        if knowledge_ids:
            query = query.filter(KGNode.version.in_(knowledge_ids))
        return query

    def paginated_search(self, search_params: dict) -> dict:
        """Paginated vector search over nodes.

        :param search_params: dict with optional keys ``keyword``,
            ``category``, ``pageNo`` (1-based), ``limit``, ``distance``,
            ``knowledge_ids`` (filters ``KGNode.version``) and ``load_props``.
        :return: ``{'records': [...], 'pagination': {...}}``.
        """
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        # Kept although currently unused: edges were fetched here before the
        # 'edges' field below was commented out.
        edge_service = KGEdgeService(self.db)

        keyword = search_params.get('keyword', '')
        category = search_params.get('category', None)
        page_no = search_params.get('pageNo', 1)
        distance = search_params.get('distance', DISTANCE_THRESHOLD)
        limit = search_params.get('limit', 10)
        knowledge_ids = search_params.get('knowledge_ids')

        # Normalize out-of-range paging inputs.
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit
        try:
            total_count = self._apply_node_filters(
                self.db.query(func.count(KGNode.id)),
                embedding, distance, category, knowledge_ids,
            ).scalar()

            query = self._apply_node_filters(
                self.db.query(
                    KGNode.id,
                    KGNode.name,
                    KGNode.category,
                    KGNode.embedding.l2_distance(embedding).label('distance'),
                ),
                embedding, distance, category, knowledge_ids,
            )
            results = query.order_by('distance').offset(offset).limit(limit).all()

            # Among rows with equal distance, rank category == '疾病'
            # (disease) first.
            results = sorted(results, key=lambda r: (r.distance, r.category != '疾病'))

            return {
                'records': [{
                    'id': r.id,
                    'name': r.name,
                    'category': r.category,
                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                    #'edges':edge_service.get_edges_by_nodes(r.id, r.id,False) if load_props else [],
                    'distance': round(r.distance, 3)
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    'totalPages': (total_count + limit - 1) // limit
                }
            }
        except Exception as e:
            logger.error(f"分页查询失败: {str(e)}")
            # Bare raise preserves the original traceback (``raise e`` resets it).
            raise

    def create_node(self, node_data: dict):
        """Create a node unless an identical (name, category, version) exists.

        :raises ValueError: if a duplicate exists or the insert violates a
            database constraint.
        """
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()
            if existing:
                raise ValueError("Node already exists")

            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node
        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"创建节点失败: {str(e)}")
            raise ValueError("Database integrity error")

    def get_node(self, node_id: int) -> dict:
        """Return a cached dict view of an active node.

        :raises ValueError: if ``node_id`` is None or no active node matches.
        """
        if node_id is None:
            raise ValueError("Node ID is required")

        cache_key = f"get_node_{node_id}"
        if cache_key in self._cache:
            return self._cache[cache_key]

        node = self.db.query(KGNode).filter(KGNode.id == node_id, KGNode.status == 0).first()
        if not node:
            raise ValueError("Node not found")

        node_data = {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }
        self._cache[cache_key] = node_data
        return node_data

    def update_node(self, node_id: int, update_data: dict):
        """Apply ``update_data`` attributes to the node and commit.

        :raises ValueError: if the node does not exist or the update fails.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            # BUGFIX: drop any cached copy so get_node() cannot serve stale data.
            self._cache.pop(f"get_node_{node_id}", None)
            return node
        except Exception as e:
            self.db.rollback()
            logger.error(f"更新节点失败: {str(e)}")
            raise ValueError("Update failed")

    def delete_node(self, node_id: int):
        """Hard-delete the node and commit.

        :raises ValueError: if the node does not exist or the delete fails.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            self.db.delete(node)
            self.db.commit()
            # BUGFIX: drop any cached copy so get_node() cannot serve stale data.
            self._cache.pop(f"get_node_{node_id}", None)
            return None
        except Exception as e:
            self.db.rollback()
            logger.error(f"删除节点失败: {str(e)}")
            raise ValueError("Delete failed")

    def batch_process_er_nodes(self):
        """Backfill embeddings for all nodes whose embedding is NULL.

        Processes in batches of 200, ordered by id for stable iteration.

        BUGFIX: the original paged with a growing OFFSET while filtering
        ``embedding == None`` — once a batch is committed, those rows no
        longer match the filter, so the next OFFSET skipped an entire
        batch of unprocessed rows. Since committed rows drop out of the
        result set, we simply re-query from the start until nothing matches.
        """
        batch_size = 200
        while True:
            try:
                nodes = self.db.query(KGNode).filter(
                    #KGNode.version == 'ER',
                    KGNode.embedding == None
                ).order_by(KGNode.id).limit(batch_size).all()

                if not nodes:
                    break

                updated = 0
                for node in nodes:
                    if not node.embedding:
                        node.embedding = Vectorizer.get_embedding(node.name)
                        updated += 1

                if updated:
                    self.db.commit()
                else:
                    # Defensive: the filter guarantees embeddings are NULL, so
                    # this should be unreachable — break rather than loop forever.
                    break
            except Exception as e:
                self.db.rollback()
                logger.error(f"批量处理ER节点失败: {str(e)}")
                raise ValueError("Batch process failed")