"""SaaS knowledge-base HTTP endpoints: KG node search plus node/trunk CRUD and vector search."""
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session

from model.response import StandardResponse
from db.session import get_db
from service.kg_node_service import KGNodeService
from service.trunks_service import TrunksService
from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter

router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
logger = logging.getLogger(__name__)

# NOTE(review): the handlers below are `async def` but call service methods
# synchronously; if those services do blocking DB/network I/O (likely, given
# the SQLAlchemy `Session`), they block the event loop. Converting them to
# plain `def` would let FastAPI run them in its threadpool — confirm before
# changing, since it alters the handlers' call signature (coroutine vs. sync).


class PaginatedSearchRequest(BaseModel):
    """Body for POST /saas/nodes/paginated_search2."""

    # Optional keyword filter, forwarded as 'keyword'.
    keyword: Optional[str] = None
    # Page number, forwarded as 'pageNo' (presumably 1-based; verify in KGNodeService).
    pageNo: int = 1
    # Page size.
    limit: int = 10
    # Optional knowledge-base id filter, forwarded as 'knowledge_ids'.
    knowledge_ids: Optional[List[str]] = None


class NodePaginatedSearchRequest(BaseModel):
    """Body for POST /saas/nodes/query."""

    # Search text; forwarded to the service as 'keyword'.
    name: str
    # Required by the schema but not read by the /nodes/query handler.
    appid: str
    # Optional category filter.
    category: Optional[str] = None
    pageNo: int = 1
    limit: int = 10


class NodeCreateRequest(BaseModel):
    """Body for POST /saas/nodes; passed to TrunksService.create_node as a dict."""

    name: str
    category: str
    layout: Optional[str] = None
    version: Optional[str] = None
    embedding: Optional[List[float]] = None


class NodeUpdateRequest(BaseModel):
    """Body for PUT /saas/nodes/{node_id}; only fields the client set are forwarded."""

    layout: Optional[str] = None
    version: Optional[str] = None
    status: Optional[int] = None
    embedding: Optional[List[float]] = None


class VectorSearchRequest(BaseModel):
    """Body for POST /saas/trunks/vector_search."""

    # Query text handed to TrunksService.search_by_vector.
    text: str
    # Maximum number of results.
    limit: int = 10
    # Optional filter; sent to the service as {'type': ...} when provided.
    type: Optional[str] = None


def _error_detail(response: StandardResponse) -> dict:
    """Convert an error ``StandardResponse`` into a JSON-serializable ``HTTPException`` detail.

    FastAPI's default ``HTTPException`` handler passes ``detail`` to
    ``JSONResponse`` (plain ``json.dumps``), which cannot serialize a pydantic
    model. Without this conversion the error handler itself fails and the
    client receives a bare "Internal Server Error" instead of this payload.
    """
    return jsonable_encoder(response)


@router.post("/nodes/paginated_search2", response_model=StandardResponse)
async def paginated_search(
    payload: PaginatedSearchRequest,
    db: Session = Depends(get_db),
):
    """Paginated keyword search over knowledge-graph nodes, returning records and pagination."""
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.keyword,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'knowledge_ids': payload.knowledge_ids,
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except Exception as e:
        # logger.exception keeps the traceback; the message text is unchanged.
        logger.exception("分页查询失败: %s", e)
        raise HTTPException(
            status_code=500,
            detail=_error_detail(
                StandardResponse(success=False, error_code=500, error_msg=str(e))
            ),
        ) from e


# Previously also named `paginated_search`, which rebound the handler above.
# `name=` keeps the route name / OpenAPI operationId identical to before.
@router.post("/nodes/query", response_model=StandardResponse, name="paginated_search")
async def query_nodes(
    payload: NodePaginatedSearchRequest,
    db: Session = Depends(get_db),
):
    """Search knowledge-graph nodes by name and optional category; results are returned under 'nodes'."""
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.name,
            'category': payload.category,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        # Drop the 'distance' field from every record.
        for item in result['records']:
            item.pop('distance', None)
        # This endpoint does not expose pagination metadata.
        result.pop('pagination', None)
        # Rename result['records'] to result['nodes'].
        result['nodes'] = result.pop('records')
        return StandardResponse(
            success=True,
            data=ObjectToJsonArrayConverter.convert(result),
        )
    except Exception as e:
        logger.exception("分页查询失败: %s", e)
        raise HTTPException(
            status_code=500,
            detail=_error_detail(
                StandardResponse(success=False, error_code=500, error_msg=str(e))
            ),
        ) from e


# NOTE(review): the TrunksService handlers below declare `db` but never use it
# (TrunksService() is constructed without it), so each request still opens a
# session via get_db. Kept to preserve the handler signatures.


@router.post("/nodes", response_model=StandardResponse)
async def create_node(
    payload: NodeCreateRequest,
    db: Session = Depends(get_db),
):
    """Create a node from the request body."""
    try:
        service = TrunksService()
        result = service.create_node(payload.dict())
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.exception("创建节点失败: %s", e)
        raise HTTPException(500, detail=_error_detail(StandardResponse.error(str(e)))) from e


@router.get("/nodes/{node_id}", response_model=StandardResponse)
async def get_node(
    node_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single node by id."""
    try:
        service = TrunksService()
        result = service.get_node(node_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.exception("获取节点失败: %s", e)
        raise HTTPException(500, detail=_error_detail(StandardResponse.error(str(e)))) from e


@router.put("/nodes/{node_id}", response_model=StandardResponse)
async def update_node(
    node_id: int,
    payload: NodeUpdateRequest,
    db: Session = Depends(get_db),
):
    """Partially update a node; only fields present in the request body are sent to the service."""
    try:
        service = TrunksService()
        result = service.update_node(node_id, payload.dict(exclude_unset=True))
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.exception("更新节点失败: %s", e)
        raise HTTPException(500, detail=_error_detail(StandardResponse.error(str(e)))) from e


@router.delete("/nodes/{node_id}", response_model=StandardResponse)
async def delete_node(
    node_id: int,
    db: Session = Depends(get_db),
):
    """Delete a node by id."""
    try:
        service = TrunksService()
        service.delete_node(node_id)
        return StandardResponse(success=True)
    except Exception as e:
        logger.exception("删除节点失败: %s", e)
        raise HTTPException(500, detail=_error_detail(StandardResponse.error(str(e)))) from e


@router.post('/trunks/vector_search', response_model=StandardResponse)
async def vector_search(
    payload: VectorSearchRequest,
    db: Session = Depends(get_db),
):
    """Vector search over trunks for the given text, optionally filtered by type."""
    try:
        service = TrunksService()
        # Only send a filter dict when the client supplied a type.
        type_filter = {'type': payload.type} if payload.type else None
        result = service.search_by_vector(payload.text, payload.limit, type_filter)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.exception("向量搜索失败: %s", e)
        raise HTTPException(500, detail=_error_detail(StandardResponse.error(str(e)))) from e


@router.get('/trunks/{trunk_id}', response_model=StandardResponse)
async def get_trunk(
    trunk_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single trunk by id."""
    try:
        service = TrunksService()
        result = service.get_trunk_by_id(trunk_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        logger.exception("获取trunk详情失败: %s", e)
        raise HTTPException(500, detail=_error_detail(StandardResponse.error(str(e)))) from e


# Public alias under which this router is exported.
saas_kb_router = router