graph_router.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import logging
  2. import sys,os
  3. import traceback
  4. from agent.cdss.capbility import CDSSCapability
  5. from agent.cdss.models.schemas import CDSSInput, CDSSInt, CDSSText
  6. from model.response import StandardResponse
  7. from service.cdss_service import CdssService
  8. from service.kg_node_service import KGNodeService
  9. current_path = os.getcwd()
  10. sys.path.append(current_path)
  11. import time
  12. from fastapi import APIRouter, Depends, Query
  13. from typing import Optional, List
  14. import sys
  15. sys.path.append('..')
  16. from utils.agent import call_chat_api,get_conversation_id
  17. import json
  18. router = APIRouter(prefix="/disease", tags=["Knowledge Graph"])
  19. logger = logging.getLogger(__name__)
  20. @router.get("/recommend", response_model=StandardResponse)
  21. async def recommend(
  22. chief: str,
  23. present_illness: Optional[str] = None,
  24. sex: Optional[str] = None,
  25. age: Optional[int] = None,
  26. department: Optional[str] = None,
  27. ):
  28. start_time = time.time()
  29. app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
  30. conversation_id = get_conversation_id(app_id)
  31. desc = "主诉:"+chief
  32. if present_illness:
  33. desc+="\n现病史:" + present_illness
  34. result = call_chat_api(app_id, conversation_id, desc)
  35. json_data = json.loads(result)
  36. keyword = " ".join(json_data["symptoms"])
  37. result = await neighbor_search(keyword=keyword,sex=sex,age=age, neighbor_type='Check', limit=10)
  38. end_time = time.time()
  39. print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
  40. return result;
  41. @router.get("/neighbor_search", response_model=StandardResponse)
  42. async def neighbor_search(
  43. keyword: str = Query(..., min_length=2),
  44. sex: Optional[str] = None,
  45. age: Optional[int] = None,
  46. department: Optional[str] = None
  47. ):
  48. """
  49. 根据关键词和属性过滤条件搜索图谱节点
  50. """
  51. try:
  52. keywords = keyword.split(" ")
  53. record = CDSSInput(
  54. pat_age=CDSSInt(type="month", value=age),
  55. pat_sex=CDSSText(type="sex", value=sex),
  56. chief_complaint=keywords,
  57. department=CDSSText(type='department', value=department)
  58. )
  59. # 使用从main.py导入的capability实例处理CDSS逻辑
  60. output = capability.process(input=record)
  61. output.diagnosis.value = [{"name":key,"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
  62. #"hasInfo": 1,
  63. #"type": 1
  64. } for key,value in output.diagnosis.value.items()]
  65. return StandardResponse(
  66. success=True,
  67. data={"可能诊断":output.diagnosis.value,"症状":keywords}
  68. )
  69. except Exception as e:
  70. traceback.print_exc()
  71. logger.error(f"get_disease_detail failed: {str(e)}")
  72. return StandardResponse(
  73. success=False,
  74. errorCode=500,
  75. errorMsg=str(e)
  76. )
  77. @router.get("/{disease_name}/detail", response_model=StandardResponse)
  78. async def get_disease_detail(
  79. disease_name: str
  80. ):
  81. try:
  82. service = CdssService()
  83. result = service.get_disease_detail(disease_name,'疾病')
  84. return StandardResponse(success=True, data=result)
  85. except Exception as e:
  86. traceback.print_exc()
  87. logger.error(f"get_disease_detail failed: {str(e)}")
  88. return StandardResponse(
  89. success=False,
  90. errorCode=500,
  91. errorMsg=str(e)
  92. )
  93. capability = CDSSCapability()
  94. graph_router = router