# knowledge_saas.py — FastAPI routes for the SaaS knowledge base (mounted under /saas).
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session

from db.session import get_db
from model.response import StandardResponse
from service.trunks_service import TrunksService
  9. router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
  10. logger = logging.getLogger(__name__)
  11. class PaginatedSearchRequest(BaseModel):
  12. keyword: Optional[str] = None
  13. pageNo: int = 1
  14. limit: int = 10
  15. knowledge_ids: Optional[List[str]] = None
  16. class NodeCreateRequest(BaseModel):
  17. name: str
  18. category: str
  19. layout: Optional[str] = None
  20. version: Optional[str] = None
  21. embedding: Optional[List[float]] = None
  22. class NodeUpdateRequest(BaseModel):
  23. layout: Optional[str] = None
  24. version: Optional[str] = None
  25. status: Optional[int] = None
  26. embedding: Optional[List[float]] = None
  27. @router.post("/paginated_search", response_model=StandardResponse)
  28. async def paginated_search(
  29. payload: PaginatedSearchRequest,
  30. db: Session = Depends(get_db)
  31. ):
  32. try:
  33. service = KGNodeService()
  34. search_params = {
  35. 'keyword': payload.keyword,
  36. 'pageNo': payload.pageNo,
  37. 'limit': payload.limit,
  38. 'knowledge_ids': payload.knowledge_ids
  39. }
  40. return service.paginated_search(search_params)
  41. except Exception as e:
  42. logger.error(f"分页查询失败: {str(e)}")
  43. raise HTTPException(
  44. status_code=500,
  45. detail=StandardResponse(
  46. success=False,
  47. error_code=500,
  48. error_msg=str(e)
  49. )
  50. )
  51. try:
  52. trunks_service = TrunksService()
  53. search_params = {
  54. 'keyword': payload.keyword,
  55. 'pageNo': payload.pageNo,
  56. 'limit': payload.limit,
  57. 'knowledge_ids': payload.knowledge_ids
  58. }
  59. result = trunks_service.paginated_search(search_params)
  60. return StandardResponse(
  61. success=True,
  62. data={
  63. 'records': result['data'],
  64. 'pagination': result['pagination']
  65. }
  66. )
  67. except Exception as e:
  68. logger.error(f"分页查询失败: {str(e)}")
  69. raise HTTPException(
  70. status_code=500,
  71. detail=StandardResponse(
  72. success=False,
  73. error_code=500,
  74. error_msg=str(e)
  75. )
  76. )
  77. @router.post("/nodes", response_model=StandardResponse)
  78. async def create_node(
  79. payload: NodeCreateRequest,
  80. db: Session = Depends(get_db)
  81. ):
  82. try:
  83. service = TrunksService()
  84. return service.create_node(payload.dict())
  85. except Exception as e:
  86. logger.error(f"创建节点失败: {str(e)}")
  87. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  88. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  89. async def get_node(
  90. node_id: int,
  91. db: Session = Depends(get_db)
  92. ):
  93. try:
  94. service = TrunksService()
  95. return service.get_node(node_id)
  96. except Exception as e:
  97. logger.error(f"获取节点失败: {str(e)}")
  98. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  99. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  100. async def update_node(
  101. node_id: int,
  102. payload: NodeUpdateRequest,
  103. db: Session = Depends(get_db)
  104. ):
  105. try:
  106. service = TrunksService()
  107. return service.update_node(node_id, payload.dict(exclude_unset=True))
  108. except Exception as e:
  109. logger.error(f"更新节点失败: {str(e)}")
  110. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  111. @router.delete("/nodes/{node_id}", response_model=StandardResponse)
  112. async def delete_node(
  113. node_id: int,
  114. db: Session = Depends(get_db)
  115. ):
  116. try:
  117. service = TrunksService()
  118. return service.delete_node(node_id)
  119. except Exception as e:
  120. logger.error(f"删除节点失败: {str(e)}")
  121. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  122. saas_kb_router = router