瀏覽代碼

代码提交

SGTY 1 月之前
父節點
當前提交
e463e90196
共有 3 個文件被更改,包括 25 次插入16 次删除
  1. 19 15
      agent/cdss/libs/cdss_helper2.py
  2. 5 0
      router/graph_router.py
  3. 1 1
      tests/test.py

+ 19 - 15
agent/cdss/libs/cdss_helper2.py

@@ -286,6 +286,7 @@ class CDSSHelper(GraphHelper):
         symptom_same_edge = ['症状同义词', '症状同义词2.0']
         department_edge = ['belongs_to','所属科室']
         allowed_links = symptom_edge+department_edge+symptom_same_edge
+        #allowed_links = symptom_edge + department_edge
         # 将输入的症状名称转换为节点ID
         # 由于可能存在同名节点,转换后的节点ID数量可能大于输入的症状数量
         node_ids = []
@@ -710,32 +711,35 @@ class CDSSHelper(GraphHelper):
                 count = 1
             for disease, data in final_results[department]["diseases"]:
                 total_diags += 1
-                disease_id = disease.split(":")[0]
-                disease = disease.split(":")[1]
                 if disease in diags.keys():
                     diags[disease]["count"] += data["count"]+count
                     diags[disease]["score"] += (data["count"]+count)*0.1 * department_factor
                 else:
-                    symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
-                    if symptoms_data is None:
-                        continue
-                    symptoms = []
-                    for symptom in symptoms_data:
-                        matched = False
-                        if symptom in start_nodes:
-                            matched = True
-                        symptoms.append({"name":symptom,"matched":matched})
-                    #symtoms中matched=true的排在前面,matched=false的排在后面
-                    symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
-                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor,"symptoms":symptoms}
+                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor}
   
         #sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)[:10]
         sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
 
         diags = {}
         for item in sorted_score_diags:
+            disease_info = item[0].split(":");
+            disease_id = disease_info[0]
+            disease = disease_info[1]
+            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
+            if symptoms_data is None:
+                continue
+            symptoms = []
+            for symptom in symptoms_data:
+                matched = False
+                if symptom in start_nodes:
+                    matched = True
+                symptoms.append({"name": symptom, "matched": matched})
+            # symtoms中matched=true的排在前面,matched=false的排在后面
+            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
+
+
             new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])*0.1,"symptoms":symptoms}
-            diags[item[0]] = new_item
+            diags[disease] = new_item
         sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
 
         print(f"STEP 5 finished")

+ 5 - 0
router/graph_router.py

@@ -22,6 +22,7 @@ router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
 @router.get("/nodes/recommend", response_model=StandardResponse)
 async def recommend(
     chief: str,
+    present_illness: Optional[str] = None,
     sex: Optional[str] = None,
     age: Optional[int] = None,
     department: Optional[str] = None,
@@ -29,6 +30,10 @@ async def recommend(
     start_time = time.time()
     app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
     conversation_id = get_conversation_id(app_id)
+
+    # desc = "主诉:"+chief
+    # if present_illness:
+    #     desc+="\n现病史:" + present_illness
     result = call_chat_api(app_id, conversation_id, chief)
     json_data = json.loads(result)
     keyword = " ".join(json_data["symptoms"])

+ 1 - 1
tests/test.py

@@ -6,7 +6,7 @@ capability = CDSSCapability()
 record = CDSSInput(
     pat_age=CDSSInt(type="month", value=24),
     pat_sex=CDSSText(type="sex", value="男"),
-    chief_complaint=["右下腹痛"],
+    chief_complaint=["右下腹痛","恶心","呕吐"],
     #chief_complaint=["呕血", "黑便", "头晕", "心悸"],
     #chief_complaint=["流鼻涕"],