knowledge_saas.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
import logging
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session

from model.response import StandardResponse
from db.session import get_db
from service.kg_node_service import KGNodeService
from service.trunks_service import TrunksService
  10. router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
  11. logger = logging.getLogger(__name__)
  12. class PaginatedSearchRequest(BaseModel):
  13. keyword: Optional[str] = None
  14. pageNo: int = 1
  15. limit: int = 10
  16. knowledge_ids: Optional[List[str]] = None
  17. class NodeCreateRequest(BaseModel):
  18. name: str
  19. category: str
  20. layout: Optional[str] = None
  21. version: Optional[str] = None
  22. embedding: Optional[List[float]] = None
  23. class NodeUpdateRequest(BaseModel):
  24. layout: Optional[str] = None
  25. version: Optional[str] = None
  26. status: Optional[int] = None
  27. embedding: Optional[List[float]] = None
  28. @router.post("/paginated_search", response_model=StandardResponse)
  29. async def paginated_search(
  30. payload: PaginatedSearchRequest,
  31. db: Session = Depends(get_db)
  32. ):
  33. try:
  34. service = KGNodeService(db)
  35. search_params = {
  36. 'keyword': payload.keyword,
  37. 'pageNo': payload.pageNo,
  38. 'limit': payload.limit,
  39. 'knowledge_ids': payload.knowledge_ids,
  40. 'load_props': True
  41. }
  42. result = service.paginated_search(search_params)
  43. return StandardResponse(
  44. success=True,
  45. data={
  46. 'records': result['records'],
  47. 'pagination': result['pagination']
  48. }
  49. )
  50. except Exception as e:
  51. logger.error(f"分页查询失败: {str(e)}")
  52. raise HTTPException(
  53. status_code=500,
  54. detail=StandardResponse(
  55. success=False,
  56. error_code=500,
  57. error_msg=str(e)
  58. )
  59. )
  60. @router.post("/nodes", response_model=StandardResponse)
  61. async def create_node(
  62. payload: NodeCreateRequest,
  63. db: Session = Depends(get_db)
  64. ):
  65. try:
  66. service = TrunksService()
  67. result = service.create_node(payload.dict())
  68. return StandardResponse(success=True, data=result)
  69. except Exception as e:
  70. logger.error(f"创建节点失败: {str(e)}")
  71. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  72. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  73. async def get_node(
  74. node_id: int,
  75. db: Session = Depends(get_db)
  76. ):
  77. try:
  78. service = TrunksService()
  79. result = service.get_node(node_id)
  80. return StandardResponse(success=True, data=result)
  81. except Exception as e:
  82. logger.error(f"获取节点失败: {str(e)}")
  83. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  84. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  85. async def update_node(
  86. node_id: int,
  87. payload: NodeUpdateRequest,
  88. db: Session = Depends(get_db)
  89. ):
  90. try:
  91. service = TrunksService()
  92. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  93. return StandardResponse(success=True, data=result)
  94. except Exception as e:
  95. logger.error(f"更新节点失败: {str(e)}")
  96. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  97. @router.delete("/nodes/{node_id}", response_model=StandardResponse)
  98. async def delete_node(
  99. node_id: int,
  100. db: Session = Depends(get_db)
  101. ):
  102. try:
  103. service = TrunksService()
  104. service.delete_node(node_id)
  105. return StandardResponse(success=True)
  106. except Exception as e:
  107. logger.error(f"删除节点失败: {str(e)}")
  108. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  109. saas_kb_router = router