Browse Source

代码提交

SGTY 1 tháng trước cách đây
mục cha
commit
f292101aa9
2 tập tin đã thay đổi với 54 bổ sung33 xóa
  1. 53 32
      agent/cdss/libs/cdss_helper2.py
  2. 1 1
      router/graph_router.py

+ 53 - 32
agent/cdss/libs/cdss_helper2.py

@@ -6,6 +6,8 @@ import logging
 import json
 import time
 
+from sqlalchemy import false
+
 from service.kg_edge_service import KGEdgeService
 from db.session import get_db
 from service.kg_node_service import KGNodeService
@@ -279,7 +281,11 @@ class CDSSHelper(GraphHelper):
         allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
         # 定义允许的关系类型,包括has_symptom、need_check、recommend_drug、belongs_to
         # 这些关系类型用于后续的路径查找和过滤
-        allowed_links = ['has_symptom','疾病相关症状','belongs_to','所属科室']
+
+        symptom_edge = ['has_symptom', '疾病相关症状']
+        symptom_same_edge = ['症状同义词', '症状同义词2.0']
+        department_edge = ['belongs_to','所属科室']
+        allowed_links = symptom_edge+department_edge+symptom_same_edge
         # 将输入的症状名称转换为节点ID
         # 由于可能存在同名节点,转换后的节点ID数量可能大于输入的症状数量
         node_ids = []
@@ -312,7 +318,7 @@ class CDSSHelper(GraphHelper):
         results = self.validDisease(results, start_nodes)
 
         # 调用step2方法处理科室、检查和药品信息
-        results = self.step2(results)
+        results = self.step2(results,department_edge)
 
         # STEP 3: 对于结果按照科室维度进行汇总
         final_results = self.step3(results)
@@ -325,7 +331,7 @@ class CDSSHelper(GraphHelper):
             departments.append({"name": temp[0], "count": temp[1]["count"]})
         
         # STEP 5: 对于final_results里面的diseases, checks和durgs统计全局出现的次数并且按照次数降序排序
-        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes)
+        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes,symptom_edge)
 
         # STEP 6: 整合数据并返回
         # if "department" in item.keys():
@@ -523,7 +529,7 @@ class CDSSHelper(GraphHelper):
         print(f"STEP 1 遍历图谱查找相关疾病 finished")
         return results
 
-    def step2(self, results):
+    def step2(self, results,department_edge):
         """
         查找疾病对应的科室、检查和药品信息
         :param results: 包含疾病信息的字典
@@ -546,7 +552,7 @@ class CDSSHelper(GraphHelper):
             out_edges = self.graph.out_edges(disease, data=True)
             for edge in out_edges:
                 src, dest, edge_data = edge
-                if edge_data["type"] != 'belongs_to' and edge_data["type"] != '所属科室':
+                if edge_data["type"] not in department_edge:
                     continue
                 dest_data = self.entity_data[self.entity_data.index == dest]
                 if dest_data.empty:
@@ -680,7 +686,7 @@ class CDSSHelper(GraphHelper):
         print("\n" + "\n".join(log_data))
         return sorted_final_results
 
-    def step5(self, final_results, input,start_nodes):
+    def step5(self, final_results, input, start_nodes, symptom_edge):
         """
         按科室汇总结果并排序
 
@@ -704,44 +710,34 @@ class CDSSHelper(GraphHelper):
                 count = 1
             for disease, data in final_results[department]["diseases"]:
                 total_diags += 1
-                #key = 'disease_name_parent_' + disease
-                # cached_data = self.cache.get(key)
-                # if cached_data:
-                #     disease = cached_data
+                disease_id = disease.split(":")[0]
                 disease = disease.split(":")[1]
                 if disease in diags.keys():
                     diags[disease]["count"] += data["count"]+count
                     diags[disease]["score"] += (data["count"]+count)*0.1 * department_factor
                 else:
-                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor}
+                    symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
+                    if symptoms_data is None:
+                        continue
+                    symptoms = []
+                    for symptom in symptoms_data:
+                        matched = False
+                        if symptom in start_nodes:
+                            matched = True
+                        symptoms.append({"name":symptom,"matched":matched})
+                    #symtoms中matched=true的排在前面,matched=false的排在后面
+                    symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
+                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor,"symptoms":symptoms}
   
         #sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)[:10]
         sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
 
         diags = {}
         for item in sorted_score_diags:
-            new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])*0.1}
+            new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])*0.1,"symptoms":symptoms}
             diags[item[0]] = new_item
         sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
-        #循环sorted_score_diags,把disease现在是id:name替换name
-        # diags = {}
-        # for item in sorted_score_diags:
-        #     diseaseinfos = item[0].split(":")
-        #     diseaseId = diseaseinfos[0]
-        #     diseaseName = diseaseinfos[1]
-        #     diseaseScore = self.graph.nodes[int(diseaseId)].get('score', None)
-        #     if diseaseScore:
-        #         try:
-        #             # 创建新字典替换原元组
-        #             new_item = {"count": item[1]["count"], "score": float(diseaseScore)*0.1}
-        #             diags[diseaseName] = new_item
-        #         except (ValueError, TypeError):
-        #             # 如果转换失败,直接使用原元组
-        #             diags[diseaseName] = item[1]
-        #     else:
-        #         # 如果没有找到对应的score,直接使用原元组
-        #         diags[diseaseName] = item[1]
-        # sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
+
         print(f"STEP 5 finished")
         end_time = time.time()
         print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
@@ -755,4 +751,29 @@ class CDSSHelper(GraphHelper):
             log_data.append(f"|{department}|{diesease_data}|{count_data}|{score_data}|")
 
         print("这里是经过排序的数据\n" + "\n".join(log_data))
-        return sorted_score_diags, total_diags
+        return sorted_score_diags, total_diags
+    
+    def get_symptoms_data(self, disease_id, symptom_edge):
+        """
+        获取疾病相关的症状数据
+        :param disease_id: 疾病节点ID
+        :param symptom_edge: 症状关系类型列表
+        :return: 症状数据列表
+        """
+        key = f'disease_{disease_id}_symptom'
+        symptom_data = self.cache[key] if key in self.cache else None
+        if symptom_data is None:
+            out_edges = self.graph.out_edges(int(disease_id), data=True)
+            symptom_data = []
+            for edge in out_edges:
+                src, dest, edge_data = edge
+                if edge_data["type"] not in symptom_edge:
+                    continue
+                dest_data = self.entity_data[self.entity_data.index == dest]
+                if dest_data.empty:
+                    continue
+                dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                if dest_name not in symptom_data:
+                    symptom_data.append(dest_name)
+            self.cache[key]=symptom_data
+        return symptom_data

+ 1 - 1
router/graph_router.py

@@ -66,7 +66,7 @@ async def neighbor_search(
         # 使用从main.py导入的capability实例处理CDSS逻辑
         output = capability.process(input=record)
 
-        output.diagnosis.value = [{"name":key,"old_score":value["old_score"],"count":value["count"],"score":value["score"],
+        output.diagnosis.value = [{"name":key,"old_score":value["old_score"],"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
             "hasInfo": 1,
             "type": 1} for key,value in output.diagnosis.value.items()]