nodes_api.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from fastapi import APIRouter, Depends, HTTPException, Request
  2. from typing import Optional, List
  3. from pydantic import BaseModel
  4. from model.response import StandardResponse
  5. from db.session import get_db
  6. from sqlalchemy.orm import Session
  7. from service.kg_node_service import KGNodeService
  8. from service.kg_edge_service import KGEdgeService
  9. from service.kg_prop_service import KGPropService
  10. import logging
  11. from utils.ObjectToJsonArrayConverter import ObjectToJsonArrayConverter
  12. router = APIRouter(prefix="/v1/knowledge", tags=["SaaS Knowledge Base"])
  13. logger = logging.getLogger(__name__)
  14. class PaginatedSearchRequest(BaseModel):
  15. name: str
  16. distance: float = 1.45
  17. category: Optional[str] = None
  18. pageNo: int = 1
  19. limit: int = 10
  20. class GetNodeRelationshipsRequest(BaseModel):
  21. relationName: str
  22. async def get_request_id(request: Request):
  23. return request.state.context["request_id"]
  24. @router.post("/nodes/paginated_search", response_model=StandardResponse)
  25. async def paginated_search(
  26. payload: PaginatedSearchRequest,
  27. db: Session = Depends(get_db),
  28. request_id: str = Depends(get_request_id)
  29. ):
  30. try:
  31. service = KGNodeService(db)
  32. search_params = {
  33. 'keyword': payload.name,
  34. 'category': payload.category,
  35. 'pageNo': payload.pageNo,
  36. 'limit': payload.limit,
  37. 'load_props': True,
  38. 'distance': 2.0-payload.distance,
  39. }
  40. result = service.paginated_search(search_params)
  41. #result[data]的distance属性去掉
  42. for item in result['records']:
  43. item.pop('distance', None)
  44. #result[pagination]去掉
  45. result.pop('pagination', None)
  46. #result[records]改为result[nodes]
  47. result['nodes'] = result['records']
  48. result.pop('records', None)
  49. return StandardResponse(
  50. success=True,
  51. requestId=request_id,
  52. data=ObjectToJsonArrayConverter.convert(result)
  53. )
  54. except Exception as e:
  55. logger.error(f"分页查询失败: {str(e)}")
  56. raise HTTPException(
  57. status_code=500,
  58. detail=StandardResponse(
  59. success=False,
  60. error_code=500,
  61. error_msg=str(e)
  62. )
  63. )
  64. @router.post("/nodes/{src_id}/relationships", response_model=StandardResponse)
  65. async def get_node_relationships(
  66. src_id: int,
  67. payload: GetNodeRelationshipsRequest,
  68. db: Session = Depends(get_db),
  69. request_id: str = Depends(get_request_id)
  70. ):
  71. try:
  72. edge_service = KGEdgeService(db)
  73. prop_service = KGPropService(db)
  74. edges = edge_service.get_edges_by_nodes(src_id=src_id, dest_id=None)
  75. relationships = []
  76. #count = 0
  77. for edge in edges:
  78. #if count >= 2:
  79. #break
  80. dest_node = edge['dest_node']
  81. dest_props = []
  82. edge_props = []
  83. #count += 1
  84. #dest_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
  85. # for p in prop_service.get_props_by_ref_id(dest_node['id'])]
  86. #edge_props = [{'prop_title': p['prop_title'], 'prop_value': p['prop_value']}
  87. # for p in prop_service.get_props_by_ref_id(edge['id'])]
  88. relationships.append({
  89. "name": edge['name'],
  90. "props": edge_props,
  91. "destNode": {
  92. "category": dest_node['category'],
  93. "id": str(dest_node['id']),
  94. "name": dest_node['name'],
  95. "props": dest_props
  96. }
  97. })
  98. return StandardResponse(
  99. success=True,
  100. requestId=request_id,
  101. data=ObjectToJsonArrayConverter.convert({"relationships": relationships})
  102. )
  103. except Exception as e:
  104. logger.error(f"获取节点关系失败: {str(e)}")
  105. raise HTTPException(500, detail=StandardResponse.error(str(e)))
  106. nodes_api_router = router