import logging
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from model.response import StandardResponse
from db.session import get_db
from sqlalchemy.orm import Session
from service.kg_node_service import KGNodeService
from service.trunks_service import TrunksService
from service.kg_edge_service import KGEdgeService
from service.kg_prop_service import KGPropService

router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
logger = logging.getLogger(__name__)


class PaginatedSearchRequest(BaseModel):
    """Request body for ``POST /saas/nodes/paginated_search2``."""

    keyword: Optional[str] = None  # forwarded to the node service as 'keyword'
    pageNo: int = 1  # page number; presumably 1-based (default is 1) - TODO confirm
    limit: int = 10  # presumably the page size
    knowledge_ids: Optional[List[str]] = None  # forwarded as 'knowledge_ids'


class NodePaginatedSearchRequest(BaseModel):
    """Request body for ``POST /saas/nodes/paginated_search``."""

    name: str  # required; forwarded to the node service as 'keyword'
    category: Optional[str] = None  # optional category filter, forwarded as-is
    pageNo: int = 1  # page number; presumably 1-based (default is 1) - TODO confirm
    limit: int = 10  # presumably the page size


class NodeCreateRequest(BaseModel):
    """Request body for ``POST /saas/nodes``."""

    name: str
    category: str
    layout: Optional[str] = None
    version: Optional[str] = None
    embedding: Optional[List[float]] = None  # optional precomputed vector


class NodeUpdateRequest(BaseModel):
    """Request body for ``PUT /saas/nodes/{node_id}``.

    The update endpoint serializes this with ``exclude_unset=True``, so only
    fields the client actually sends are forwarded.
    """

    layout: Optional[str] = None
    version: Optional[str] = None
    status: Optional[int] = None
    embedding: Optional[List[float]] = None


class VectorSearchRequest(BaseModel):
    """Request body for ``POST /saas/trunks/vector_search``."""

    text: str  # query text passed to the trunk vector search
    limit: int = 10  # maximum number of hits requested
    type: Optional[str] = None  # optional filter, sent as {'type': ...} when set


class NodeRelationshipRequest(BaseModel):
    """Source-node selector.

    NOTE(review): not referenced by any endpoint in this module; the
    relationships endpoint takes ``src_id`` from the path instead.
    """

    src_id: int


@router.post(
    "/nodes/paginated_search2",
    response_model=StandardResponse,
    # The Python function previously reused the name ``paginated_search`` and
    # was shadowed by the next endpoint. Pinning the route name keeps the
    # route name and generated OpenAPI operationId exactly as before.
    name="paginated_search",
)
async def paginated_search2(
    payload: PaginatedSearchRequest,
    db: Session = Depends(get_db),
):
    """Search knowledge-graph nodes by keyword and optional knowledge ids.

    Returns a ``StandardResponse`` whose ``data`` holds the ``records`` and
    ``pagination`` entries produced by ``KGNodeService.paginated_search``
    (called with ``load_props=True``). Any unexpected failure is logged and
    reported as HTTP 500.
    """
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.keyword,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'knowledge_ids': payload.knowledge_ids,
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except HTTPException:
        # Deliberate HTTP errors keep their own status code instead of
        # being flattened into a generic 500.
        raise
    except Exception as e:
        logger.exception("分页查询失败: %s", e)
        # HTTPException.detail is rendered with json.dumps by FastAPI's
        # default handler, which cannot serialize a pydantic model, so it
        # is converted to plain JSON types first.
        raise HTTPException(
            status_code=500,
            detail=jsonable_encoder(
                StandardResponse(
                    success=False,
                    error_code=500,
                    error_msg=str(e),
                )
            ),
        ) from e
def _internal_error(detail) -> HTTPException:
    """Build an HTTP 500 ``HTTPException`` with a JSON-serializable detail.

    FastAPI's default ``HTTPException`` handler renders ``{"detail": ...}``
    with ``json.dumps``, which fails on a pydantic model such as
    ``StandardResponse``. ``jsonable_encoder`` turns the model (or any
    dict/list) into plain JSON types first.
    """
    # Imported here to keep the dependency next to its only use.
    from fastapi.encoders import jsonable_encoder

    return HTTPException(500, detail=jsonable_encoder(detail))


def _title_value_pairs(props):
    """Reduce property records to ``{'prop_title', 'prop_value'}`` dicts."""
    return [
        {'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
        for p in props
    ]


@router.post("/nodes/paginated_search", response_model=StandardResponse)
async def paginated_search(
    payload: NodePaginatedSearchRequest,
    db: Session = Depends(get_db),
):
    """Search knowledge-graph nodes by name and optional category.

    ``payload.name`` is sent to ``KGNodeService.paginated_search`` as the
    ``keyword`` (with ``load_props=True``); the response ``data`` holds its
    ``records`` and ``pagination`` entries. Unexpected failures become 500.
    """
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.name,
            'category': payload.category,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception("分页查询失败: %s", e)
        raise _internal_error(
            StandardResponse(success=False, error_code=500, error_msg=str(e))
        ) from e


@router.post("/nodes", response_model=StandardResponse)
async def create_node(
    payload: NodeCreateRequest,
    db: Session = Depends(get_db),
):
    """Create a node from the full request body via ``TrunksService.create_node``."""
    try:
        service = TrunksService()
        result = service.create_node(payload.dict())
        return StandardResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("创建节点失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


@router.get("/nodes/{node_id}", response_model=StandardResponse)
async def get_node(
    node_id: int,
    db: Session = Depends(get_db),
):
    """Return whatever ``TrunksService.get_node`` yields for ``node_id``."""
    try:
        service = TrunksService()
        result = service.get_node(node_id)
        return StandardResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("获取节点失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


@router.put("/nodes/{node_id}", response_model=StandardResponse)
async def update_node(
    node_id: int,
    payload: NodeUpdateRequest,
    db: Session = Depends(get_db),
):
    """Update a node with only the fields present in the request body.

    ``exclude_unset=True`` means omitted fields are not forwarded, so they
    are not overwritten with ``None``.
    """
    try:
        service = TrunksService()
        result = service.update_node(node_id, payload.dict(exclude_unset=True))
        return StandardResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("更新节点失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


@router.delete("/nodes/{node_id}", response_model=StandardResponse)
async def delete_node(
    node_id: int,
    db: Session = Depends(get_db),
):
    """Delete a node; responds with ``success=True`` and no data."""
    try:
        service = TrunksService()
        service.delete_node(node_id)
        return StandardResponse(success=True)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("删除节点失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


@router.post('/trunks/vector_search', response_model=StandardResponse)
async def vector_search(
    payload: VectorSearchRequest,
    db: Session = Depends(get_db),
):
    """Run ``TrunksService.search_by_vector`` on the request text.

    A ``{'type': ...}`` filter is passed only when ``payload.type`` is
    truthy; otherwise the filter argument is ``None``.
    """
    try:
        service = TrunksService()
        result = service.search_by_vector(
            payload.text,
            payload.limit,
            {'type': payload.type} if payload.type else None,
        )
        return StandardResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("向量搜索失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


@router.get('/trunks/{trunk_id}', response_model=StandardResponse)
async def get_trunk(
    trunk_id: int,
    db: Session = Depends(get_db),
):
    """Return whatever ``TrunksService.get_trunk_by_id`` yields for ``trunk_id``."""
    try:
        service = TrunksService()
        result = service.get_trunk_by_id(trunk_id)
        return StandardResponse(success=True, data=result)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("获取trunk详情失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


@router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
async def get_node_relationships(
    src_id: int,
    db: Session = Depends(get_db),
):
    """List outgoing edges of ``src_id`` together with their target nodes.

    Each relationship carries the edge ``name``, the edge's properties, and
    a ``targetNode`` with its category, stringified id, name and
    properties. Properties are reduced to ``prop_title``/``prop_value``.
    """
    try:
        edge_service = KGEdgeService(db)
        prop_service = KGPropService(db)
        edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)

        relationships = []
        # NOTE(review): two property lookups per edge (N+1 pattern);
        # consider batching if edge counts grow.
        for edge in edges:
            dest_node = edge['dest_node']
            dest_props = _title_value_pairs(
                prop_service.get_props_by_ref_id(dest_node['id'])
            )
            edge_props = _title_value_pairs(
                prop_service.get_props_by_ref_id(edge['id'])
            )
            relationships.append({
                "name": edge['name'],
                "properties": edge_props,
                "targetNode": {
                    "category": dest_node['category'],
                    "id": str(dest_node['id']),
                    "name": dest_node['name'],
                    "properties": dest_props,
                },
            })

        return StandardResponse(
            success=True,
            data={"relationships": relationships},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("获取节点关系失败: %s", e)
        raise _internal_error(StandardResponse.error(str(e))) from e


# Public alias under which this router is exported.
saas_kb_router = router