kb_router.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import sys,os
  2. current_path = os.getcwd()
  3. sys.path.append(current_path)
  4. from config.site import SiteConfig
  5. from fastapi import APIRouter, Depends, Query
  6. from db.database import get_db
  7. from sqlalchemy.orm import Session
  8. from agent.models.web.response import StandardResponse,FAILED,SUCCESS
  9. from agent.models.web.request import BasicRequest
  10. from agent.libs.graph import GraphBusiness
  11. from agent.libs.auth import verify_session_id, SessionValues
  12. import logging
  # Router that groups the endpoints of this module under the "/kb" URL prefix.
  13. router = APIRouter(prefix="/kb", tags=["knowledge build interface"])
  # Module-level logger named after this module.
  14. logger = logging.getLogger(__name__)
  # Site configuration object used for the lookup below.
  15. config = SiteConfig()
  # Value of the "TASK_LOG_DIR" setting; current_path is passed as the second
  # argument (presumably the fallback when the key is unset -- TODO confirm).
  16. LOG_DIR = config.get_config("TASK_LOG_DIR", current_path)
  17. # job_category = Column(String(64), nullable=False)
  18. # job_name = Column(String(64))
  19. # job_details = Column(Text, nullable=False)
  20. # job_creator = Column(String(64), nullable=False)
  21. # job_logs = Column(Text, nullable=True)
  22. # job_files = Column(String(300), nullable=True)
  @router.post('/summary', response_model=StandardResponse)
  def summary_func(
      request: BasicRequest,
      db: Session = Depends(get_db),
      sess: SessionValues = Depends(verify_session_id),
  ) -> StandardResponse:
      """Return the summary of a single graph.

      Only the ``get_summary`` action is served. The graph is selected by the
      ``graph_id`` request parameter, read with ``0`` as the second argument
      to ``get_param``.

      Returns:
          A SUCCESS response whose ``records`` holds the summary as its only
          element, or a FAILED response when the action is not
          ``get_summary`` or the lookup yields a falsy result.
      """
      # Guard clause: this endpoint handles exactly one action.
      if request.action != "get_summary":
          return StandardResponse(code=FAILED, message="invalid action")

      target_graph = request.get_param("graph_id", 0)
      graph_summary = GraphBusiness(db).get_graph_summary(graph_id=target_graph)

      if not graph_summary:
          return StandardResponse(code=FAILED, message="summary not found", records=[])

      logger.info(graph_summary)
      return StandardResponse(code=SUCCESS, message="summary found", records=[graph_summary])
  @router.post('/schemas', response_model=StandardResponse)
  def schemas_func(
      request: BasicRequest,
      db: Session = Depends(get_db),
      sess: SessionValues = Depends(verify_session_id),
  ) -> StandardResponse:
      """Return the node or edge categories of a graph.

      Supported actions:
          * ``get_nodes_schemas`` -- calls ``get_nodes_categories``.
          * ``get_edges_schemas`` -- calls ``get_edges_categories``.

      The graph is selected by the ``graph_id`` request parameter.

      Returns:
          A SUCCESS response carrying the categories in ``records``; a FAILED
          response with message "schemas not found" when a supported action
          yields nothing; a FAILED response with message "invalid action" for
          any other action.
      """
      action = request.action
      if action not in ("get_nodes_schemas", "get_edges_schemas"):
          return StandardResponse(code=FAILED, message="invalid action")

      graph_id = request.get_param("graph_id", 0)
      biz = GraphBusiness(db)
      if action == "get_nodes_schemas":
          schemas = biz.get_nodes_categories(graph_id=graph_id)
      else:
          schemas = biz.get_edges_categories(graph_id=graph_id)

      if schemas:
          return StandardResponse(code=SUCCESS, message="schemas found", records=schemas)

      # A supported action with an empty result used to fall through to the
      # "invalid action" response; report the real cause instead. The code
      # stays FAILED so callers that only inspect ``code`` are unaffected.
      return StandardResponse(code=FAILED, message="schemas not found", records=[])
  @router.post('/nodes', response_model=StandardResponse)
  def nodes_func(
      request: BasicRequest,
      db: Session = Depends(get_db),
      sess: SessionValues = Depends(verify_session_id),
  ) -> StandardResponse:
      """Search, page through, or expand the neighborhood of graph nodes.

      Supported actions:
          * ``search_nodes`` -- requires the ``name``, ``category`` and
            ``graph_id`` parameters and calls ``search_like_node_by_name``.
          * ``get_nodes`` -- calls ``get_nodes_by_page`` with ``graph_id``,
            ``page`` and ``page_size``.
          * ``neighbors`` -- returns the node identified by ``node_id``
            followed by its incoming and outgoing neighbors, each as a dict
            with ``id``, ``name``, ``category`` and ``direction``
            ("self", "in" or "out").

      Returns:
          A SUCCESS response with the matching records, or a FAILED response
          whose message names the reason (missing parameter, nothing found,
          or "invalid action" for an unsupported action).
      """
      action = request.action

      if action == "search_nodes":
          node_name = request.get_param("name", "")
          category = request.get_param("category", "")
          graph_id = request.get_param("graph_id", 0)
          # Reject incomplete queries before calling the graph layer.
          if node_name == "":
              return StandardResponse(code=FAILED, message="node name is empty", records=[])
          if category == "":
              return StandardResponse(code=FAILED, message="category is empty", records=[])
          if graph_id == 0:
              return StandardResponse(code=FAILED, message="graph id is empty", records=[])
          nodes = GraphBusiness(db).search_like_node_by_name(
              graph_id=graph_id, category=category, name=node_name
          )
          if nodes:
              return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
          return StandardResponse(code=FAILED, message="search job failed")

      if action == "get_nodes":
          graph_id = request.get_param("graph_id", 0)
          page = request.get_param("page", 1)
          page_size = request.get_param("page_size", 1)
          nodes = GraphBusiness(db).get_nodes_by_page(
              graph_id=graph_id, page=page, page_size=page_size
          )
          if nodes:
              return StandardResponse(code=SUCCESS, message="nodes found", records=nodes)
          # An empty page used to be reported as "invalid action".
          return StandardResponse(code=FAILED, message="nodes not found", records=[])

      if action == "neighbors":
          graph_id = request.get_param("graph_id", 0)
          node_id = request.get_param("node_id", 0)
          if not node_id > 0:
              # A missing or non-positive node id used to be reported as
              # "invalid action".
              return StandardResponse(code=FAILED, message="invalid node id", records=[])

          biz = GraphBusiness(db)
          center = biz.get_node_by_id(graph_id=graph_id, node_id=node_id)
          if center is None:
              return StandardResponse(code=FAILED, message="node not found", records=[])

          nodes_in = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="in")
          nodes_out = biz.get_neighbors(graph_id=graph_id, node_id=node_id, direction="out")

          # The requested node comes first, then incoming, then outgoing
          # neighbors. Distinct loop names avoid shadowing the requested node.
          records = [
              {"id": center.id, "name": center.name, "category": center.category, "direction": "self"}
          ]
          records.extend(
              {"id": src.id, "name": src.name, "category": src.category, "direction": "in"}
              for src in nodes_in
          )
          records.extend(
              {"id": dst.id, "name": dst.name, "category": dst.category, "direction": "out"}
              for dst in nodes_out
          )
          return StandardResponse(code=SUCCESS, message="nodes found", records=records)

      return StandardResponse(code=FAILED, message="invalid action")
  # Alias of `router` under the name kb_router (presumably the name other
  # modules import -- TODO confirm against callers).
  94. kb_router = router