123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
import logging
from typing import Optional

from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from db.session import get_db
from model.trunks_model import KGNode
from schema.response import StandardResponse
# Module-level logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
class KGNodeService:
    """Service layer for ``KGNode`` records: vector similarity search and CRUD.

    Public methods return a ``StandardResponse``. The failure cases each
    method handles are reported through ``success=False`` plus
    ``error_code`` / ``error_msg`` rather than by raising.
    """

    def __init__(self, db: Optional[Session] = None):
        """Attach the service to a database session.

        Args:
            db: Session used for all queries. When omitted, the first value
                produced by ``get_db()`` is used, as before. Passing a session
                explicitly lets the caller control when it is closed.
        """
        # NOTE(review): the fallback never resumes the get_db() iterator, so
        # any cleanup get_db() performs after yielding does not run here --
        # prefer passing ``db`` explicitly.
        self.db = db if db is not None else next(get_db())

    @staticmethod
    def _positive_int(value, default: int) -> int:
        """Coerce a paging parameter to an int >= 1, else return *default*."""
        try:
            number = int(value)
        except (TypeError, ValueError):
            return default
        return number if number >= 1 else default

    @staticmethod
    def _not_found() -> StandardResponse:
        """Build the 404 response shared by the single-node lookups."""
        return StandardResponse(
            success=False,
            error_code=404,
            error_msg="Node not found"
        )

    def paginated_search(self, search_params: dict) -> StandardResponse:
        """Return one page of nodes ordered by L2 distance to the keyword embedding.

        Args:
            search_params: Request parameters. Keys read:
                ``keyword`` (default ``''``): text passed to the embedder;
                ``pageNo`` (default 1): 1-based page number; values that are
                not integers >= 1 fall back to 1;
                ``limit`` (default 10): page size; values that are not
                integers >= 1 fall back to 10;
                ``knowledge_ids`` (optional): if truthy, only nodes whose
                ``version`` is in this collection are considered.

        Returns:
            On success, ``data`` holds ``records`` (id, name, category,
            distance) and ``pagination`` (total, pageNo, limit, totalPages).
            ``total`` counts only the rows that pass the ``knowledge_ids``
            filter. On any error the session is rolled back and
            ``success=False`` with ``error_code=500`` is returned.
        """
        keyword = search_params.get('keyword', '')
        # int() coercion lets string query parameters ("2") work; bad values
        # fall back to the same defaults the range check uses.
        page_no = self._positive_int(search_params.get('pageNo', 1), 1)
        limit = self._positive_int(search_params.get('limit', 10), 10)
        offset = (page_no - 1) * limit
        try:
            # NOTE(review): Vectorizer is not imported in this module -- add
            # its import once the defining module is confirmed. It is called
            # inside the try so an embedding failure becomes a 500 response.
            embedding = Vectorizer.get_embedding(keyword)
            knowledge_ids = search_params.get('knowledge_ids')

            count_query = self.db.query(func.count(KGNode.id))
            query = self.db.query(
                KGNode.id,
                KGNode.name,
                KGNode.category,
                KGNode.embedding.l2_distance(embedding).label('distance')
            )
            if knowledge_ids:
                # Filter the count the same way as the page query, so the
                # pagination totals describe the filtered result set.
                version_filter = KGNode.version.in_(knowledge_ids)
                count_query = count_query.filter(version_filter)
                query = query.filter(version_filter)

            total_count = count_query.scalar()
            results = query.order_by('distance').offset(offset).limit(limit).all()
            return StandardResponse(
                success=True,
                data={
                    'records': [{
                        'id': r.id,
                        'name': r.name,
                        'category': r.category,
                        'distance': r.distance
                    } for r in results],
                    'pagination': {
                        'total': total_count,
                        'pageNo': page_no,
                        'limit': limit,
                        'totalPages': (total_count + limit - 1) // limit
                    }
                }
            )
        except Exception as e:
            # The session is shared by every method; a failed statement must
            # be rolled back or later queries on it will keep failing.
            self.db.rollback()
            logger.exception(f"分页查询失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg=str(e)
            )

    def create_node(self, node_data: dict) -> StandardResponse:
        """Insert a node unless one with the same name/category/version exists.

        Args:
            node_data: Column values passed as keyword arguments to ``KGNode``.
                Must contain ``name`` and ``category`` (a missing key raises
                ``KeyError``); ``version`` is optional.

        Returns:
            ``success=True`` with the new node, ``error_code=409`` if a
            matching node already exists, or ``error_code=500`` (after
            rolling back) if the database rejects the insert.
        """
        try:
            existing = self.db.query(KGNode).filter(
                KGNode.name == node_data['name'],
                KGNode.category == node_data['category'],
                KGNode.version == node_data.get('version')
            ).first()

            if existing:
                return StandardResponse(
                    success=False,
                    error_code=409,
                    error_msg="Node already exists"
                )
            new_node = KGNode(**node_data)
            self.db.add(new_node)
            self.db.commit()
            return StandardResponse(success=True, data=new_node)
        except IntegrityError as e:
            self.db.rollback()
            logger.exception(f"创建节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Database integrity error"
            )
        except SQLAlchemyError as e:
            # Any other database failure also leaves the shared session in a
            # failed state; roll back so subsequent calls still work.
            self.db.rollback()
            logger.exception(f"创建节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Create failed"
            )

    def get_node(self, node_id: int) -> StandardResponse:
        """Fetch a node by primary key.

        Returns ``success=True`` with the node, or ``error_code=404`` when no
        node has ``node_id``.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            return self._not_found()
        return StandardResponse(success=True, data=node)

    def update_node(self, node_id: int, update_data: dict) -> StandardResponse:
        """Assign every ``update_data`` item onto the node and commit.

        Returns ``success=True`` with the updated node, ``error_code=404`` if
        no node has ``node_id``, or ``error_code=500`` (after rolling back)
        if applying or committing the changes raises.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            return self._not_found()
        try:
            # NOTE(review): keys are applied verbatim via setattr, so a caller
            # can overwrite any attribute (including the primary key) --
            # consider whitelisting the updatable columns.
            for key, value in update_data.items():
                setattr(node, key, value)
            self.db.commit()
            return StandardResponse(success=True, data=node)
        except Exception as e:
            self.db.rollback()
            logger.exception(f"更新节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Update failed"
            )

    def delete_node(self, node_id: int) -> StandardResponse:
        """Delete a node by primary key and commit.

        Returns ``success=True`` on deletion, ``error_code=404`` if no node
        has ``node_id``, or ``error_code=500`` (after rolling back) if the
        delete or commit raises.
        """
        node = self.db.query(KGNode).get(node_id)
        if not node:
            return self._not_found()
        try:
            self.db.delete(node)
            self.db.commit()
            return StandardResponse(success=True)
        except Exception as e:
            self.db.rollback()
            logger.exception(f"删除节点失败: {str(e)}")
            return StandardResponse(
                success=False,
                error_code=500,
                error_msg="Delete failed"
            )
|