knowledge_saas.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
import logging
from typing import Optional, List

from fastapi import APIRouter, Depends, HTTPException
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy.orm import Session

from db.session import get_db
from model.response import StandardResponse
from service.kg_node_service import KGNodeService
from service.trunks_service import TrunksService
  10. router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
  11. logger = logging.getLogger(__name__)
  12. class PaginatedSearchRequest(BaseModel):
  13. keyword: Optional[str] = None
  14. pageNo: int = 1
  15. limit: int = 10
  16. knowledge_ids: Optional[List[str]] = None
  17. class NodeCreateRequest(BaseModel):
  18. name: str
  19. category: str
  20. layout: Optional[str] = None
  21. version: Optional[str] = None
  22. embedding: Optional[List[float]] = None
  23. class NodeUpdateRequest(BaseModel):
  24. layout: Optional[str] = None
  25. version: Optional[str] = None
  26. status: Optional[int] = None
  27. embedding: Optional[List[float]] = None
  28. class VectorSearchRequest(BaseModel):
  29. text: str
  30. limit: int = 10
  31. type: Optional[str] = None
  32. @router.post("/nodes/paginated_search", response_model=StandardResponse)
  33. async def paginated_search(
  34. payload: PaginatedSearchRequest,
  35. db: Session = Depends(get_db)
  36. ):
  37. try:
  38. service = KGNodeService(db)
  39. search_params = {
  40. 'keyword': payload.keyword,
  41. 'pageNo': payload.pageNo,
  42. 'limit': payload.limit,
  43. 'knowledge_ids': payload.knowledge_ids,
  44. 'load_props': True
  45. }
  46. result = service.paginated_search(search_params)
  47. return StandardResponse(
  48. success=True,
  49. data={
  50. 'records': result['records'],
  51. 'pagination': result['pagination']
  52. }
  53. )
  54. except Exception as e:
  55. logger.error(f"分页查询失败: {str(e)}")
  56. raise HTTPException(
  57. status_code=500,
  58. detail=StandardResponse(
  59. success=False,
  60. error_code=500,
  61. error_msg=str(e)
  62. )
  63. )
  64. @router.post("/nodes", response_model=StandardResponse)
  65. async def create_node(
  66. payload: NodeCreateRequest,
  67. db: Session = Depends(get_db)
  68. ):
  69. try:
  70. service = TrunksService()
  71. result = service.create_node(payload.dict())
  72. return StandardResponse(success=True, data=result)
  73. except Exception as e:
  74. logger.error(f"创建节点失败: {str(e)}")
  75. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  76. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  77. async def get_node(
  78. node_id: int,
  79. db: Session = Depends(get_db)
  80. ):
  81. try:
  82. service = TrunksService()
  83. result = service.get_node(node_id)
  84. return StandardResponse(success=True, data=result)
  85. except Exception as e:
  86. logger.error(f"获取节点失败: {str(e)}")
  87. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  88. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  89. async def update_node(
  90. node_id: int,
  91. payload: NodeUpdateRequest,
  92. db: Session = Depends(get_db)
  93. ):
  94. try:
  95. service = TrunksService()
  96. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  97. return StandardResponse(success=True, data=result)
  98. except Exception as e:
  99. logger.error(f"更新节点失败: {str(e)}")
  100. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  101. @router.delete("/nodes/{node_id}", response_model=StandardResponse)
  102. async def delete_node(
  103. node_id: int,
  104. db: Session = Depends(get_db)
  105. ):
  106. try:
  107. service = TrunksService()
  108. service.delete_node(node_id)
  109. return StandardResponse(success=True)
  110. except Exception as e:
  111. logger.error(f"删除节点失败: {str(e)}")
  112. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  113. @router.post('/trunks/vector_search', response_model=StandardResponse)
  114. async def vector_search(
  115. payload: VectorSearchRequest,
  116. db: Session = Depends(get_db)
  117. ):
  118. try:
  119. service = TrunksService()
  120. result = service.search_by_vector(
  121. payload.text,
  122. payload.limit,
  123. {'type': payload.type} if payload.type else None
  124. )
  125. return StandardResponse(success=True, data=result)
  126. except Exception as e:
  127. logger.error(f"向量搜索失败: {str(e)}")
  128. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  129. @router.get('/trunks/{trunk_id}', response_model=StandardResponse)
  130. async def get_trunk(
  131. trunk_id: int,
  132. db: Session = Depends(get_db)
  133. ):
  134. try:
  135. service = TrunksService()
  136. result = service.get_trunk_by_id(trunk_id)
  137. return StandardResponse(success=True, data=result)
  138. except Exception as e:
  139. logger.error(f"获取trunk详情失败: {str(e)}")
  140. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  141. saas_kb_router = router