knowledge_saas.py 7.6 KB


import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from db.session import get_db
from model.response import StandardResponse
from service.kg_edge_service import KGEdgeService
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from service.trunks_service import TrunksService
  12. router = APIRouter(prefix="/saas", tags=["SaaS Knowledge Base"])
  13. logger = logging.getLogger(__name__)
  14. class PaginatedSearchRequest(BaseModel):
  15. keyword: Optional[str] = None
  16. pageNo: int = 1
  17. limit: int = 10
  18. knowledge_ids: Optional[List[str]] = None
  19. class NodePaginatedSearchRequest(BaseModel):
  20. name: str
  21. category: Optional[str] = None
  22. pageNo: int = 1
  23. limit: int = 10
  24. class NodeCreateRequest(BaseModel):
  25. name: str
  26. category: str
  27. layout: Optional[str] = None
  28. version: Optional[str] = None
  29. embedding: Optional[List[float]] = None
  30. class NodeUpdateRequest(BaseModel):
  31. layout: Optional[str] = None
  32. version: Optional[str] = None
  33. status: Optional[int] = None
  34. embedding: Optional[List[float]] = None
  35. class VectorSearchRequest(BaseModel):
  36. text: str
  37. limit: int = 10
  38. type: Optional[str] = None
  39. class NodeRelationshipRequest(BaseModel):
  40. src_id: int
  41. @router.post("/nodes/paginated_search2", response_model=StandardResponse)
  42. async def paginated_search(
  43. payload: PaginatedSearchRequest,
  44. db: Session = Depends(get_db)
  45. ):
  46. try:
  47. service = KGNodeService(db)
  48. search_params = {
  49. 'keyword': payload.keyword,
  50. 'pageNo': payload.pageNo,
  51. 'limit': payload.limit,
  52. 'knowledge_ids': payload.knowledge_ids,
  53. 'load_props': True
  54. }
  55. result = service.paginated_search(search_params)
  56. return StandardResponse(
  57. success=True,
  58. data={
  59. 'records': result['records'],
  60. 'pagination': result['pagination']
  61. }
  62. )
  63. except Exception as e:
  64. logger.error(f"分页查询失败: {str(e)}")
  65. raise HTTPException(
  66. status_code=500,
  67. detail=StandardResponse(
  68. success=False,
  69. error_code=500,
  70. error_msg=str(e)
  71. )
  72. )
  73. @router.post("/nodes/paginated_search", response_model=StandardResponse)
  74. async def paginated_search(
  75. payload: NodePaginatedSearchRequest,
  76. db: Session = Depends(get_db)
  77. ):
  78. try:
  79. service = KGNodeService(db)
  80. search_params = {
  81. 'keyword': payload.name,
  82. 'category': payload.category,
  83. 'pageNo': payload.pageNo,
  84. 'limit': payload.limit,
  85. 'load_props': True
  86. }
  87. result = service.paginated_search(search_params)
  88. return StandardResponse(
  89. success=True,
  90. data={
  91. 'records': result['records'],
  92. 'pagination': result['pagination']
  93. }
  94. )
  95. except Exception as e:
  96. logger.error(f"分页查询失败: {str(e)}")
  97. raise HTTPException(
  98. status_code=500,
  99. detail=StandardResponse(
  100. success=False,
  101. error_code=500,
  102. error_msg=str(e)
  103. )
  104. )
  105. @router.post("/nodes", response_model=StandardResponse)
  106. async def create_node(
  107. payload: NodeCreateRequest,
  108. db: Session = Depends(get_db)
  109. ):
  110. try:
  111. service = TrunksService()
  112. result = service.create_node(payload.dict())
  113. return StandardResponse(success=True, data=result)
  114. except Exception as e:
  115. logger.error(f"创建节点失败: {str(e)}")
  116. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  117. @router.get("/nodes/{node_id}", response_model=StandardResponse)
  118. async def get_node(
  119. node_id: int,
  120. db: Session = Depends(get_db)
  121. ):
  122. try:
  123. service = TrunksService()
  124. result = service.get_node(node_id)
  125. return StandardResponse(success=True, data=result)
  126. except Exception as e:
  127. logger.error(f"获取节点失败: {str(e)}")
  128. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  129. @router.put("/nodes/{node_id}", response_model=StandardResponse)
  130. async def update_node(
  131. node_id: int,
  132. payload: NodeUpdateRequest,
  133. db: Session = Depends(get_db)
  134. ):
  135. try:
  136. service = TrunksService()
  137. result = service.update_node(node_id, payload.dict(exclude_unset=True))
  138. return StandardResponse(success=True, data=result)
  139. except Exception as e:
  140. logger.error(f"更新节点失败: {str(e)}")
  141. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  142. @router.delete("/nodes/{node_id}", response_model=StandardResponse)
  143. async def delete_node(
  144. node_id: int,
  145. db: Session = Depends(get_db)
  146. ):
  147. try:
  148. service = TrunksService()
  149. service.delete_node(node_id)
  150. return StandardResponse(success=True)
  151. except Exception as e:
  152. logger.error(f"删除节点失败: {str(e)}")
  153. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  154. @router.post('/trunks/vector_search', response_model=StandardResponse)
  155. async def vector_search(
  156. payload: VectorSearchRequest,
  157. db: Session = Depends(get_db)
  158. ):
  159. try:
  160. service = TrunksService()
  161. result = service.search_by_vector(
  162. payload.text,
  163. payload.limit,
  164. {'type': payload.type} if payload.type else None
  165. )
  166. return StandardResponse(success=True, data=result)
  167. except Exception as e:
  168. logger.error(f"向量搜索失败: {str(e)}")
  169. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  170. @router.get('/trunks/{trunk_id}', response_model=StandardResponse)
  171. async def get_trunk(
  172. trunk_id: int,
  173. db: Session = Depends(get_db)
  174. ):
  175. try:
  176. service = TrunksService()
  177. result = service.get_trunk_by_id(trunk_id)
  178. return StandardResponse(success=True, data=result)
  179. except Exception as e:
  180. logger.error(f"获取trunk详情失败: {str(e)}")
  181. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  182. @router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
  183. async def get_node_relationships(
  184. src_id: int,
  185. db: Session = Depends(get_db)
  186. ):
  187. try:
  188. edge_service = KGEdgeService(db)
  189. prop_service = KGPropService(db)
  190. edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
  191. relationships = []
  192. for edge in edges:
  193. dest_node = edge['dest_node']
  194. dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
  195. for p in prop_service.get_props_by_ref_id(dest_node['id'])]
  196. edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
  197. for p in prop_service.get_props_by_ref_id(edge['id'])]
  198. relationships.append({
  199. "name": edge['name'],
  200. "properties": edge_props,
  201. "targetNode": {
  202. "category": dest_node['category'],
  203. "id": str(dest_node['id']),
  204. "name": dest_node['name'],
  205. "properties": dest_props
  206. }
  207. })
  208. return StandardResponse(
  209. success=True,
  210. data={"relationships": relationships}
  211. )
  212. except Exception as e:
  213. logger.error(f"获取节点关系失败: {str(e)}")
  214. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  215. saas_kb_router = router