yuchengwei 3 ヶ月 前
コミット
5289081809
2 ファイル変更18 行追加8 行削除
  1. 3 3
      agent/cdss/libs/cdss_helper.py
  2. 15 5
      router/graph_router.py

+ 3 - 3
agent/cdss/libs/cdss_helper.py

@@ -413,7 +413,7 @@ class CDSSHelper(GraphHelper):
             queue = [(node, 0, node_id_names[node], {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
 
             # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
-            if input.pat_age.value > 0 and input.pat_age.type == 'year':
+            if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
                 # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
                 input.pat_age.value = input.pat_age.value * 12
                 input.pat_age.type = 'month'
@@ -482,9 +482,9 @@ class CDSSHelper(GraphHelper):
         new_results = {}
 
         for item in results:
-            if self.check_sex_allowed(item, input.pat_sex.value) == False:
+            if input.pat_sex and input.pat_sex.value is not None  and self.check_sex_allowed(item, input.pat_sex.value) == False:
                 continue
-            if self.check_age_allowed(item, input.pat_age.value) == False:
+            if input.pat_age and input.pat_age.value is not None and self.check_age_allowed(item, input.pat_age.value) == False:
                 continue
             new_results[item] = results[item]
         results = new_results

+ 15 - 5
router/graph_router.py

@@ -21,7 +21,10 @@ router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
 
 @router.get("/nodes/recommend", response_model=StandardResponse)
 async def recommend(
-    chief: str
+    chief: str,
+    sex: Optional[str] = None,
+    age: Optional[int] = None,
+    department: Optional[str] = None,
 ):
     start_time = time.time()
     app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
@@ -29,7 +32,7 @@ async def recommend(
     result = call_chat_api(app_id, conversation_id, chief)
     json_data = json.loads(result)
     keyword = " ".join(json_data["symptoms"])
-    result = await neighbor_search(keyword=keyword, neighbor_type='Check', limit=10)
+    result = await neighbor_search(keyword=keyword,sex=sex,age=age, neighbor_type='Check', limit=10)
     end_time = time.time()
     print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
     return result;
@@ -38,6 +41,9 @@ async def recommend(
 @router.get("/nodes/neighbor_search", response_model=StandardResponse)
 async def neighbor_search(
     keyword: str = Query(..., min_length=2),
+    sex: Optional[str] = None,
+    age: Optional[int] = None,
+    department: Optional[str] = None,
     limit: int = Query(10, ge=1, le=100),
     node_type: Optional[str] = Query(None),
     neighbor_type: Optional[str] = Query(None),
@@ -52,14 +58,18 @@ async def neighbor_search(
         keywords = keyword.split(" ")
 
         record = CDSSInput(
-            #pat_age=CDSSInt(type="month", value=24),
-            #pat_sex=CDSSText(type="sex", value="女"),
+            pat_age=CDSSInt(type="month", value=age),
+            pat_sex=CDSSText(type="sex", value=sex),
             chief_complaint=keywords,
+            department=CDSSText(type='department', value=department)
         )
         # 使用从main.py导入的capability实例处理CDSS逻辑
         output = capability.process(input=record)
+        
+        output.diagnosis.value = [{"name":key,"score":value["score"],"count":value["count"],
+            "hasInfo": 1,
+            "type": 1} for key,value in output.diagnosis.value.items()]
 
-        print(output.diagnosis.value)
         return StandardResponse(
             success=True,
             data={"可能诊断":output.diagnosis.value,"症状":keywords,"推荐检验":output.checks.value}