from sqlalchemy.orm import Session
from typing import Optional
from model.kg_node import KGNode
from db.session import get_db
import logging
from sqlalchemy.exc import IntegrityError
from utils.vectorizer import Vectorizer
from sqlalchemy import func
from service.kg_prop_service import KGPropService
from service.kg_edge_service import KGEdgeService

logger = logging.getLogger(__name__)


class KGNodeService:
    def __init__(self, db: Session):
        self.db = db

    def paginated_search(self, search_params: dict) -> dict:
        """Vector-similarity search over nodes with pagination and optional property loading."""
        load_props = search_params.get('load_props', False)
        prop_service = KGPropService(self.db)
        edge_service = KGEdgeService(self.db)
        keyword = search_params.get('keyword', '')
        page_no = search_params.get('pageNo', 1)
        limit = search_params.get('limit', 10)
        if page_no < 1:
            page_no = 1
        if limit < 1:
            limit = 10

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit

        try:
            total_count = (
                self.db.query(func.count(KGNode.id))
                .filter(KGNode.version.in_(search_params['knowledge_ids']))
                .scalar()
                if search_params.get('knowledge_ids')
                else self.db.query(func.count(KGNode.id)).scalar()
            )

            query = self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(embedding).label('distance')
            )
            if search_params.get('knowledge_ids'):
                query = query.filter(KGNode.version.in_(search_params['knowledge_ids']))

            results = query.order_by('distance').offset(offset).limit(limit).all()

            return {
                'records': [{
                    'id': r.id,
                    'name': r.name,
                    'category': r.category,
                    'props': prop_service.get_props_by_ref_id(r.id) if load_props else [],
                    # 'edges': edge_service.get_edges_by_nodes(r.id, r.id, False) if load_props else [],
                    'distance': r.distance
                } for r in results],
                'pagination': {
                    'total': total_count,
                    'pageNo': page_no,
                    'limit': limit,
                    'totalPages': (total_count + limit - 1) // limit
                }
            }
        except Exception as e:
            logger.error(f"Paginated search failed: {str(e)}")
            raise

    def create_node(self, node_data: dict):
        """Create a node unless one with the same name, category, and version already exists."""
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()
            if existing:
                raise ValueError("Node already exists")

            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return new_node
        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"Failed to create node: {str(e)}")
            raise ValueError("Database integrity error")

    def get_node(self, node_id: int):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        return {
            'id': node.id,
            'name': node.name,
            'category': node.category,
            'version': node.version
        }

    def update_node(self, node_id: int, update_data: dict):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            return node
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to update node: {str(e)}")
            raise ValueError("Update failed")

    def delete_node(self, node_id: int):
        node = self.db.query(KGNode).get(node_id)
        if not node:
            raise ValueError("Node not found")
        try:
            self.db.delete(node)
            self.db.commit()
            return None
        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to delete node: {str(e)}")
            raise ValueError("Delete failed")

    def batch_process_er_nodes(self):
        """Backfill embeddings for 'ER' nodes that do not have one yet, in batches."""
        batch_size = 200
        while True:
            try:
                # Always re-query from the start: rows committed in the previous batch
                # no longer match the embedding-is-NULL filter, so advancing an offset
                # would skip unprocessed rows.
                nodes = self.db.query(KGNode).filter(
                    KGNode.version == 'ER',
                    KGNode.embedding.is_(None)
                ).limit(batch_size).all()
                if not nodes:
                    break
                for node in nodes:
                    node.embedding = Vectorizer.get_embedding(node.name)
                self.db.commit()
            except Exception as e:
                self.db.rollback()
                logger.error(f"Batch processing of ER nodes failed: {str(e)}")
                raise ValueError("Batch process failed")
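

# Example usage: a minimal sketch, not part of the service itself. It assumes
# db.session.get_db (imported above but otherwise unused here) is a generator
# dependency that yields a SQLAlchemy Session; adapt the session handling to
# the project's actual setup.
if __name__ == "__main__":
    db = next(get_db())  # assumption: get_db yields a Session
    try:
        service = KGNodeService(db)
        page = service.paginated_search({
            'keyword': 'customer',   # hypothetical search term
            'pageNo': 1,
            'limit': 10,
            'load_props': True,
        })
        print(page['pagination'], len(page['records']))
    finally:
        db.close()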