# knowledge_saas.py — SaaS knowledge-base API routes
  1. from fastapi import APIRouter, Depends, HTTPException
  2. from typing import Optional, List
  3. from pydantic import BaseModel
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.kg_node_service import KGNodeService
  8. from service.trunks_service import TrunksService
  9. import logging
  10. from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
  11. router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
  12. logger = logging.getLogger(__name__)
class PaginatedSearchRequest(BaseModel):
    """Request body for the keyword-based paginated node search (POST /nodes/paginated_search2)."""
    # Free-text keyword; forwarded unchanged (including None) to KGNodeService.
    keyword: Optional[str] = None
    # Page number, default 1 (presumably 1-based — confirm in KGNodeService).
    pageNo: int = 1
    # Page size forwarded to the service.
    limit: int = 10
    # Optional knowledge-base ids used to restrict the search; forwarded as-is.
    knowledge_ids: Optional[List[str]] = None
class NodePaginatedSearchRequest(BaseModel):
    """Request body for POST /nodes/query: search nodes by name and optional category."""
    # Search text; the handler passes it to the service under the 'keyword' key.
    name: str
    # Required by the schema, but the /nodes/query handler in this file never
    # reads it. NOTE(review): confirm whether results should be scoped by appid.
    appid: str
    # Optional category filter forwarded to the service.
    category: Optional[str] = None
    # Page number, default 1.
    pageNo: int = 1
    # Page size, default 10.
    limit: int = 10
class NodeCreateRequest(BaseModel):
    """Request body for creating a node (POST /nodes)."""
    # The create handler forwards the whole model via .dict() to
    # TrunksService.create_node, so field names here become the dict keys.
    name: str
    category: str
    layout: Optional[str] = None
    version: Optional[str] = None
    # Optional precomputed embedding vector; its length is not validated here.
    embedding: Optional[List[float]] = None
class NodeUpdateRequest(BaseModel):
    """Partial-update body for PUT /nodes/{node_id}; every field is optional."""
    # The update handler serializes this with exclude_unset=True, so only the
    # fields the client actually sent reach TrunksService.update_node.
    layout: Optional[str] = None
    version: Optional[str] = None
    # Node status code; the allowed values are defined elsewhere — TODO confirm.
    status: Optional[int] = None
    embedding: Optional[List[float]] = None
class VectorSearchRequest(BaseModel):
    """Request body for POST /trunks/vector_search."""
    # Query text, passed as the first argument to TrunksService.search_by_vector.
    text: str
    # Passed as the second (limit) argument to search_by_vector.
    limit: int = 10
    # Optional filter: when truthy, the handler passes {'type': type} as the
    # filter argument, otherwise None. The field name matches the builtin, but
    # renaming it would change the public request schema.
    type: Optional[str] = None
  39. @router.post("/nodes/paginated_search2", response_model=StandardResponse)
  40. async def paginated_search(
  41. payload: PaginatedSearchRequest,
  42. db: Session = Depends(get_db)
  43. ):
  44. try:
  45. service = KGNodeService(db)
  46. search_params = {
  47. 'keyword': payload.keyword,
  48. 'pageNo': payload.pageNo,
  49. 'limit': payload.limit,
  50. 'knowledge_ids': payload.knowledge_ids,
  51. 'load_props': True
  52. }
  53. result = service.paginated_search(search_params)
  54. return StandardResponse(
  55. success=True,
  56. data={
  57. 'records': result['records'],
  58. 'pagination': result['pagination']
  59. }
  60. )
  61. except Exception as e:
  62. logger.error(f"分页查询失败: {str(e)}")
  63. raise HTTPException(
  64. status_code=500,
  65. detail=StandardResponse(
  66. success=False,
  67. error_code=500,
  68. error_msg=str(e)
  69. )
  70. )
  71. @router.post("/nodes/query", response_model=StandardResponse)
  72. async def paginated_search(
  73. payload: NodePaginatedSearchRequest,
  74. db: Session = Depends(get_db)
  75. ):
  76. try:
  77. service = KGNodeService(db)
  78. search_params = {
  79. 'keyword': payload.name,
  80. 'category': payload.category,
  81. 'pageNo': payload.pageNo,
  82. 'limit': payload.limit,
  83. 'load_props': True
  84. }
  85. result = service.paginated_search(search_params)
  86. #result[data]的distance属性去掉
  87. for item in result['records']:
  88. item.pop('distance', None)
  89. #result[pagination]去掉
  90. result.pop('pagination', None)
  91. #result[records]改为result[nodes]
  92. result['nodes'] = result['records']
  93. result.pop('records', None)
  94. return StandardResponse(
  95. success=True,
  96. data=ObjectToJsonArrayConverter.convert(result)
  97. )
  98. except Exception as e:
  99. logger.error(f"分页查询失败: {str(e)}")
  100. raise HTTPException(
  101. status_code=500,
  102. detail=StandardResponse(
  103. success=False,
  104. error_code=500,
  105. error_msg=str(e)
  106. )
  107. )
  108. @router.post("/nodes", response_model=StandardResponse)
  109. async def create_node(
  110. payload: NodeCreateRequest,
  111. db: Session = Depends(get_db)
  112. ):
  113. try:
  114. service = TrunksService()
  115. result = service.create_node(payload.dict())
  116. return StandardResponse(success=True, data=result)
  117. except Exception as e:
  118. logger.error(f"创建节点失败: {str(e)}")
  119. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  120. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  121. async def get_node(
  122. node_id: int,
  123. db: Session = Depends(get_db)
  124. ):
  125. try:
  126. service = TrunksService()
  127. result = service.get_node(node_id)
  128. return StandardResponse(success=True, data=result)
  129. except Exception as e:
  130. logger.error(f"获取节点失败: {str(e)}")
  131. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  132. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  133. async def update_node(
  134. node_id: int,
  135. payload: NodeUpdateRequest,
  136. db: Session = Depends(get_db)
  137. ):
  138. try:
  139. service = TrunksService()
  140. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  141. return StandardResponse(success=True, data=result)
  142. except Exception as e:
  143. logger.error(f"更新节点失败: {str(e)}")
  144. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  145. @router.delete("/nodes/{node_id}", response_model=StandardResponse)
  146. async def delete_node(
  147. node_id: int,
  148. db: Session = Depends(get_db)
  149. ):
  150. try:
  151. service = TrunksService()
  152. service.delete_node(node_id)
  153. return StandardResponse(success=True)
  154. except Exception as e:
  155. logger.error(f"删除节点失败: {str(e)}")
  156. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  157. @router.post('/trunks/vector_search', response_model=StandardResponse)
  158. async def vector_search(
  159. payload: VectorSearchRequest,
  160. db: Session = Depends(get_db)
  161. ):
  162. try:
  163. service = TrunksService()
  164. result = service.search_by_vector(
  165. payload.text,
  166. payload.limit,
  167. {'type': payload.type} if payload.type else None
  168. )
  169. return StandardResponse(success=True, data=result)
  170. except Exception as e:
  171. logger.error(f"向量搜索失败: {str(e)}")
  172. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  173. @router.get('/trunks/{trunk_id}', response_model=StandardResponse)
  174. async def get_trunk(
  175. trunk_id: int,
  176. db: Session = Depends(get_db)
  177. ):
  178. try:
  179. service = TrunksService()
  180. result = service.get_trunk_by_id(trunk_id)
  181. return StandardResponse(success=True, data=result)
  182. except Exception as e:
  183. logger.error(f"获取trunk详情失败: {str(e)}")
  184. raise HTTPException(500, detail=StandardResponse.error(str(e)))
# Public alias for this module's router (presumably the name the application
# imports when mounting these /saas routes — verify against the app setup).
saas_kb_router: APIRouter = router