123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
import logging
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session

from db.session import get_db
from model.response import StandardResponse
from service.kg_node_service import KGNodeService
from service.trunks_service import TrunksService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this module is registered under the "/saas" prefix.
router = APIRouter(tags=["SaaS Knowledge Base"], prefix="/saas")
class PaginatedSearchRequest(BaseModel):
    """Request body for the paginated node search endpoint.

    All four fields are copied into the parameter dict that the
    ``paginated_search`` handler passes to ``KGNodeService.paginated_search``.
    """

    # Optional search keyword; omitted/None by default.
    keyword: Optional[str] = None
    # Requested page number (defaults to 1).
    pageNo: int = 1
    # Page size (defaults to 10).
    limit: int = 10
    # Optional list of knowledge ids forwarded to the service as-is.
    knowledge_ids: Optional[List[str]] = None
class NodeCreateRequest(BaseModel):
    """Request body for creating a node.

    The ``create_node`` handler converts the whole model to a dict and hands
    it to ``TrunksService.create_node``.
    """

    # Required fields.
    name: str
    category: str

    # Optional fields; each defaults to None when not supplied.
    layout: Optional[str] = None
    version: Optional[str] = None
    embedding: Optional[List[float]] = None
class NodeUpdateRequest(BaseModel):
    """Request body for updating a node.

    Every field is optional. The ``update_node`` handler serialises this model
    with ``exclude_unset=True``, so only the fields the client actually sent
    are forwarded to the service.
    """

    layout: Optional[str] = None
    version: Optional[str] = None
    # Integer status value; its meaning is defined by the service layer.
    status: Optional[int] = None
    embedding: Optional[List[float]] = None
class VectorSearchRequest(BaseModel):
    """Request body for the trunk vector-search endpoint."""

    # Query text passed as the first argument to the vector search.
    text: str
    # Passed as the second argument to the vector search (defaults to 10).
    limit: int = 10
    # Optional filter value; when truthy the handler sends {'type': <value>}
    # as the third argument, otherwise None.
    type: Optional[str] = None
@router.post("/nodes/paginated_search", response_model=StandardResponse)
async def paginated_search(
    payload: PaginatedSearchRequest,
    db: Session = Depends(get_db),
):
    """Run a paginated node search and wrap the result in a StandardResponse.

    Args:
        payload: Search criteria (keyword, page number, page size and an
            optional list of knowledge ids).
        db: Database session injected by FastAPI and handed to KGNodeService.

    Returns:
        A successful StandardResponse whose ``data`` contains the ``records``
        and ``pagination`` entries returned by the service.

    Raises:
        HTTPException: Status 500 when anything inside the handler fails; the
            ``detail`` carries the JSON-encoded failure StandardResponse.
    """
    try:
        service = KGNodeService(db)
        search_params = {
            'keyword': payload.keyword,
            'pageNo': payload.pageNo,
            'limit': payload.limit,
            'knowledge_ids': payload.knowledge_ids,
            # Always ask the service to load properties for this endpoint.
            'load_props': True,
        }
        result = service.paginated_search(search_params)
        return StandardResponse(
            success=True,
            data={
                'records': result['records'],
                'pagination': result['pagination'],
            },
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("分页查询失败: %s", e)
        failure = StandardResponse(
            success=False,
            error_code=500,
            error_msg=str(e),
        )
        # HTTPException.detail is serialised with plain JSON encoding, so the
        # response object must be converted to JSON-compatible data first.
        raise HTTPException(
            status_code=500,
            detail=jsonable_encoder(failure),
        ) from e
@router.post("/nodes", response_model=StandardResponse)
async def create_node(
    payload: NodeCreateRequest,
    db: Session = Depends(get_db),
):
    """Create a node from the request body via TrunksService.

    Args:
        payload: Node attributes; the full model is converted to a dict and
            passed to ``TrunksService.create_node``.
        db: Database session injected by FastAPI. It is not referenced in the
            handler body.

    Returns:
        A successful StandardResponse whose ``data`` is the service result.

    Raises:
        HTTPException: Status 500 when the service call fails; the ``detail``
            is the JSON-encoded result of ``StandardResponse.error``.
    """
    # NOTE(review): node CRUD goes through TrunksService while the paginated
    # search uses KGNodeService - confirm this is intentional.
    try:
        service = TrunksService()
        result = service.create_node(payload.dict())
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("创建节点失败: %s", e)
        # Encode the error payload so it can be rendered as JSON in `detail`.
        raise HTTPException(
            500,
            detail=jsonable_encoder(StandardResponse.error(str(e))),
        ) from e
@router.get("/nodes/{node_id}", response_model=StandardResponse)
async def get_node(
    node_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single node by id via TrunksService.

    Args:
        node_id: Integer id taken from the URL path.
        db: Database session injected by FastAPI. It is not referenced in the
            handler body.

    Returns:
        A successful StandardResponse whose ``data`` is whatever
        ``TrunksService.get_node`` returned.

    Raises:
        HTTPException: Status 500 when the service call fails; the ``detail``
            is the JSON-encoded result of ``StandardResponse.error``.
    """
    try:
        service = TrunksService()
        result = service.get_node(node_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("获取节点失败: %s", e)
        # Encode the error payload so it can be rendered as JSON in `detail`.
        raise HTTPException(
            500,
            detail=jsonable_encoder(StandardResponse.error(str(e))),
        ) from e
@router.put("/nodes/{node_id}", response_model=StandardResponse)
async def update_node(
    node_id: int,
    payload: NodeUpdateRequest,
    db: Session = Depends(get_db),
):
    """Update a node by id via TrunksService.

    Args:
        node_id: Integer id taken from the URL path.
        payload: Fields to change. Only fields explicitly present in the
            request are forwarded (``exclude_unset=True``).
        db: Database session injected by FastAPI. It is not referenced in the
            handler body.

    Returns:
        A successful StandardResponse whose ``data`` is the service result.

    Raises:
        HTTPException: Status 500 when the service call fails; the ``detail``
            is the JSON-encoded result of ``StandardResponse.error``.
    """
    try:
        service = TrunksService()
        # Skip fields the client did not send so they are not passed as None.
        changes = payload.dict(exclude_unset=True)
        result = service.update_node(node_id, changes)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("更新节点失败: %s", e)
        # Encode the error payload so it can be rendered as JSON in `detail`.
        raise HTTPException(
            500,
            detail=jsonable_encoder(StandardResponse.error(str(e))),
        ) from e
@router.delete("/nodes/{node_id}", response_model=StandardResponse)
async def delete_node(
    node_id: int,
    db: Session = Depends(get_db),
):
    """Delete a node by id via TrunksService.

    Args:
        node_id: Integer id taken from the URL path.
        db: Database session injected by FastAPI. It is not referenced in the
            handler body.

    Returns:
        A StandardResponse with ``success=True`` and no ``data`` argument.

    Raises:
        HTTPException: Status 500 when the service call fails; the ``detail``
            is the JSON-encoded result of ``StandardResponse.error``.
    """
    try:
        service = TrunksService()
        service.delete_node(node_id)
        return StandardResponse(success=True)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("删除节点失败: %s", e)
        # Encode the error payload so it can be rendered as JSON in `detail`.
        raise HTTPException(
            500,
            detail=jsonable_encoder(StandardResponse.error(str(e))),
        ) from e
@router.post('/trunks/vector_search', response_model=StandardResponse)
async def vector_search(
    payload: VectorSearchRequest,
    db: Session = Depends(get_db),
):
    """Search trunks by vector similarity to the supplied text.

    Args:
        payload: Query text, result limit and an optional ``type`` filter.
        db: Database session injected by FastAPI. It is not referenced in the
            handler body.

    Returns:
        A successful StandardResponse whose ``data`` is whatever
        ``TrunksService.search_by_vector`` returned.

    Raises:
        HTTPException: Status 500 when the service call fails; the ``detail``
            is the JSON-encoded result of ``StandardResponse.error``.
    """
    try:
        service = TrunksService()
        # Only build a filter dict when a non-empty type was supplied.
        conditions = {'type': payload.type} if payload.type else None
        result = service.search_by_vector(
            payload.text,
            payload.limit,
            conditions,
        )
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("向量搜索失败: %s", e)
        # Encode the error payload so it can be rendered as JSON in `detail`.
        raise HTTPException(
            500,
            detail=jsonable_encoder(StandardResponse.error(str(e))),
        ) from e
@router.get('/trunks/{trunk_id}', response_model=StandardResponse)
async def get_trunk(
    trunk_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single trunk by id via TrunksService.

    Args:
        trunk_id: Integer id taken from the URL path.
        db: Database session injected by FastAPI. It is not referenced in the
            handler body.

    Returns:
        A successful StandardResponse whose ``data`` is whatever
        ``TrunksService.get_trunk_by_id`` returned.

    Raises:
        HTTPException: Status 500 when the service call fails; the ``detail``
            is the JSON-encoded result of ``StandardResponse.error``.
    """
    try:
        service = TrunksService()
        result = service.get_trunk_by_id(trunk_id)
        return StandardResponse(success=True, data=result)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("获取trunk详情失败: %s", e)
        # Encode the error payload so it can be rendered as JSON in `detail`.
        raise HTTPException(
            500,
            detail=jsonable_encoder(StandardResponse.error(str(e))),
        ) from e
# Alias under which this module's router is also exposed.
saas_kb_router = router
|