# graph_router.py - FastAPI router for the /graph knowledge-graph endpoints.
  1. import sys,os
  2. from agent.cdss.capbility import CDSSCapability
  3. from agent.cdss.models.schemas import CDSSInput, CDSSInt, CDSSText
  4. from model.response import StandardResponse
  5. current_path = os.getcwd()
  6. sys.path.append(current_path)
  7. import time
  8. from fastapi import APIRouter, Depends, Query
  9. from typing import Optional, List
  10. import sys
  11. sys.path.append('..')
  12. from utils.agent import call_chat_api,get_conversation_id
  13. import json
  14. router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
  15. @router.get("/nodes/recommend", response_model=StandardResponse)
  16. async def recommend(
  17. chief: str,
  18. present_illness: Optional[str] = None,
  19. sex: Optional[str] = None,
  20. age: Optional[int] = None,
  21. department: Optional[str] = None,
  22. ):
  23. start_time = time.time()
  24. app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
  25. conversation_id = get_conversation_id(app_id)
  26. # desc = "主诉:"+chief
  27. # if present_illness:
  28. # desc+="\n现病史:" + present_illness
  29. result = call_chat_api(app_id, conversation_id, chief)
  30. json_data = json.loads(result)
  31. keyword = " ".join(json_data["symptoms"])
  32. result = await neighbor_search(keyword=keyword,sex=sex,age=age, neighbor_type='Check', limit=10)
  33. end_time = time.time()
  34. print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
  35. return result;
  36. @router.get("/nodes/neighbor_search", response_model=StandardResponse)
  37. async def neighbor_search(
  38. keyword: str = Query(..., min_length=2),
  39. sex: Optional[str] = None,
  40. age: Optional[int] = None,
  41. department: Optional[str] = None,
  42. limit: int = Query(10, ge=1, le=100),
  43. node_type: Optional[str] = Query(None),
  44. neighbor_type: Optional[str] = Query(None),
  45. min_degree: Optional[int] = Query(None)
  46. ):
  47. """
  48. 根据关键词和属性过滤条件搜索图谱节点
  49. """
  50. try:
  51. print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
  52. keywords = keyword.split(" ")
  53. record = CDSSInput(
  54. pat_age=CDSSInt(type="month", value=age),
  55. pat_sex=CDSSText(type="sex", value=sex),
  56. chief_complaint=keywords,
  57. department=CDSSText(type='department', value=department)
  58. )
  59. # 使用从main.py导入的capability实例处理CDSS逻辑
  60. output = capability.process(input=record)
  61. output.diagnosis.value = [{"name":key,"old_score":value["old_score"],"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
  62. "hasInfo": 1,
  63. "type": 1} for key,value in output.diagnosis.value.items()]
  64. return StandardResponse(
  65. success=True,
  66. data={"可能诊断":output.diagnosis.value,"症状":keywords,"就诊科室":output.departments.value}
  67. )
  68. except Exception as e:
  69. print(e)
  70. raise e
  71. return StandardResponse(
  72. success=False,
  73. error_code=500,
  74. error_msg=str(e)
  75. )
  76. capability = CDSSCapability()
  77. #def get_capability():
  78. #from main import capability
  79. #return capability
  80. graph_router = router