"""Service layer exposing search and CRUD operations for ``KGNode`` rows."""
import logging
from typing import Optional

from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from db.session import get_db
from model.trunks_model import KGNode
from schema.response import StandardResponse

# NOTE(review): ``Vectorizer`` is used by ``KGNodeService.paginated_search``
# but is not imported anywhere in this module. Its defining module is not
# visible from here, so the import must be added by someone who knows where
# it lives; until then ``paginated_search`` raises ``NameError``.

logger = logging.getLogger(__name__)


def _coerce_positive_int(value, default: int) -> int:
    """Return ``value`` as an integer >= 1, or ``default`` when it is unusable.

    ``None``, non-numeric strings and numbers below 1 all map to ``default``.
    """
    try:
        number = int(value)
    except (TypeError, ValueError):
        return default
    return number if number >= 1 else default


class KGNodeService:
    """Search and CRUD operations on ``KGNode`` backed by one database session."""

    def __init__(self):
        # NOTE(review): the generator returned by get_db() is discarded right
        # after its first value is taken, so whatever cleanup it performs is
        # not tied to this service's lifetime -- confirm the intended session
        # lifecycle against db.session.get_db.
        self.db = next(get_db())

    @staticmethod
    def _not_found() -> StandardResponse:
        """Build the 404 response shared by get/update/delete."""
        return StandardResponse(
            success=False,
            error_code=404,
            error_msg="Node not found"
        )

    def paginated_search(self, search_params: dict) -> StandardResponse:
        """Return one page of nodes ordered by L2 distance to the keyword embedding.

        Args:
            search_params: Mapping read for the keys ``keyword`` (default
                ``''``), ``pageNo`` (default 1), ``limit`` (default 10) and the
                optional ``knowledge_ids``, which restricts ``KGNode.version``
                to the given values.

        Returns:
            On success, a ``StandardResponse`` whose ``data`` holds ``records``
            (id, name, category, distance) and ``pagination`` (total, pageNo,
            limit, totalPages). On a query failure, a response with
            ``success=False`` and ``error_code=500``.
        """
        keyword = search_params.get('keyword', '')
        # Missing, non-numeric or < 1 values fall back to the defaults instead
        # of raising TypeError on the comparison.
        page_no = _coerce_positive_int(search_params.get('pageNo', 1), 1)
        limit = _coerce_positive_int(search_params.get('limit', 10), 10)

        embedding = Vectorizer.get_embedding(keyword)
        offset = (page_no - 1) * limit

        try:
            count_query = self.db.query(func.count(KGNode.id))
            query = self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(embedding).label('distance')
            )

            knowledge_ids = search_params.get('knowledge_ids')
            if knowledge_ids:
                # Apply the same restriction to the count so that ``total`` and
                # ``totalPages`` describe the filtered result set.
                version_filter = KGNode.version.in_(knowledge_ids)
                count_query = count_query.filter(version_filter)
                query = query.filter(version_filter)

            total_count = count_query.scalar() or 0
            results = query.order_by('distance').offset(offset).limit(limit).all()

            return StandardResponse(
                success=True,
                data={
                    'records': [{
                        'id': r.id,
                        'name': r.name,
                        'category': r.category,
                        'distance': r.distance
                    } for r in results],
                    'pagination': {
                        'total': total_count,
                        'pageNo': page_no,
                        'limit': limit,
                        # Ceiling division without floats.
                        'totalPages': (total_count + limit - 1) // limit
                    }
                }
            )
        except Exception as e:
            # The session outlives this call; roll back so a failed statement
            # does not leave it in an unusable transaction state.
            self.db.rollback()
            logger.error(f"分页查询失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg=str(e)
            )

    def create_node(self, node_data: dict) -> StandardResponse:
        """Insert a new node unless one with the same name/category/version exists.

        Args:
            node_data: Keyword arguments for ``KGNode``; ``name`` and
                ``category`` are required, ``version`` is optional.

        Returns:
            A success response carrying the new node, a 409 response when a
            matching node already exists, or a 500 response on
            ``IntegrityError``.

        Raises:
            Exception: Any error other than ``IntegrityError`` is re-raised
                after the session has been rolled back.
        """
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()

            if existing:
                return StandardResponse(
                    success=False,
                    error_code=409,
                    error_msg="Node already exists"
                )

            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return StandardResponse(success=True, data=new_node)
        except IntegrityError as e:
            self.db.rollback()
            logger.error(f"创建节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Database integrity error"
            )
        except Exception:
            # Keep propagating the error as before, but do not leave the
            # shared session with a half-finished transaction.
            self.db.rollback()
            raise

    def get_node(self, node_id: int) -> StandardResponse:
        """Return the node with primary key ``node_id``, or a 404 response."""
        node = self.db.query(KGNode).get(node_id)
        if not node:
            return self._not_found()
        return StandardResponse(success=True, data=node)

    def update_node(self, node_id: int, update_data: dict) -> StandardResponse:
        """Set each key of ``update_data`` as an attribute on the node and commit.

        Returns:
            A success response carrying the node, a 404 response when the node
            does not exist, or a 500 response (after rollback) on any error.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            return self._not_found()

        try:
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            return StandardResponse(success=True, data=node)
        except Exception as e:
            self.db.rollback()
            logger.error(f"更新节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Update failed"
            )

    def delete_node(self, node_id: int) -> StandardResponse:
        """Delete the node with primary key ``node_id`` and commit.

        Returns:
            A success response, a 404 response when the node does not exist,
            or a 500 response (after rollback) on any error.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            return self._not_found()

        try:
            self.db.delete(node)
            self.db.commit()
            return StandardResponse(success=True)
        except Exception as e:
            self.db.rollback()
            logger.error(f"删除节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Delete failed"
            )