knowledge_nodes_api.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from fastapi import APIRouter, Depends, HTTPException, Request
  2. from typing import Optional, List
  3. from pydantic import BaseModel
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.kg_node_service import KGNodeService
  8. from service.kg_edge_service import KGEdgeService
  9. from service.kg_prop_service import KGPropService
  10. import logging
  11. from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
  12. router = APIRouter(prefix="/v1/knowledge", tags=["SaaS Knowledge Base"])
  13. logger = logging.getLogger(__name__)
class PaginatedSearchRequest(BaseModel):
    # Request body for POST /v1/knowledge/nodes/paginated_search.
    # Field declaration order is kept as-is: it determines the model/schema field order.

    # Search keyword; forwarded to the node service as 'keyword'.
    name: str
    # Client-facing threshold; the endpoint forwards (2.0 - distance) to the service.
    # NOTE(review): presumably a similarity score turned into a distance — confirm
    # against KGNodeService.paginated_search.
    distance: float = 1.45
    # Optional category filter, forwarded unchanged; presumably None means "no filter".
    category: Optional[str] = None
    # Page number, forwarded unchanged as 'pageNo' (default 1; assumed 1-based — verify).
    pageNo: int = 1
    # Page size, forwarded unchanged as 'limit'.
    limit: int = 10
  20. async def get_request_id(request: Request):
  21. return request.state.context["request_id"]
  22. @router.post("/nodes/paginated_search", response_model=StandardResponse)
  23. async def paginated_search(
  24. payload: PaginatedSearchRequest,
  25. db: Session = Depends(get_db),
  26. request_id: str = Depends(get_request_id)
  27. ):
  28. try:
  29. service = KGNodeService(db)
  30. search_params = {
  31. 'keyword': payload.name,
  32. 'category': payload.category,
  33. 'pageNo': payload.pageNo,
  34. 'limit': payload.limit,
  35. 'load_props': True,
  36. 'distance': 2.0-payload.distance,
  37. }
  38. result = service.paginated_search(search_params)
  39. #result[data]的distance属性去掉
  40. for item in result['records']:
  41. item.pop('distance', None)
  42. #result[pagination]去掉
  43. result.pop('pagination', None)
  44. #result[records]改为result[nodes]
  45. result['nodes'] = result['records']
  46. result.pop('records', None)
  47. return StandardResponse(
  48. success=True,
  49. requestId=request_id,
  50. data=ObjectToJsonArrayConverter.convert(result)
  51. )
  52. except Exception as e:
  53. logger.error(f"分页查询失败: {str(e)}")
  54. raise HTTPException(
  55. status_code=500,
  56. detail=StandardResponse(
  57. success=False,
  58. error_code=500,
  59. error_msg=str(e)
  60. )
  61. )
  62. @router.get("/nodes/{src_id}/relationships", response_model=StandardResponse)
  63. async def get_node_relationships(
  64. src_id: int,
  65. db: Session = Depends(get_db),
  66. request_id: str = Depends(get_request_id)
  67. ):
  68. try:
  69. edge_service = KGEdgeService(db)
  70. prop_service = KGPropService(db)
  71. edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
  72. relationships = []
  73. #count = 0
  74. for edge in edges:
  75. #if count >= 2:
  76. #break
  77. dest_node = edge['dest_node']
  78. dest_props = []
  79. edge_props = []
  80. #count += 1
  81. #dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
  82. # for p in prop_service.get_props_by_ref_id(dest_node['id'])]
  83. #edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
  84. # for p in prop_service.get_props_by_ref_id(edge['id'])]
  85. relationships.append({
  86. "name": edge['name'],
  87. "props": edge_props,
  88. "destNode": {
  89. "category": dest_node['category'],
  90. "id": str(dest_node['id']),
  91. "name": dest_node['name'],
  92. "props": dest_props
  93. }
  94. })
  95. return StandardResponse(
  96. success=True,
  97. requestId=request_id,
  98. data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
  99. )
  100. except Exception as e:
  101. logger.error(f"获取节点关系失败: {str(e)}")
  102. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  103. knowledge_nodes_api_router = router