Code commit

SGTY committed 4 weeks ago
commit bf94d7456c

+ 1 - 1
agent/cdss/capbility.py

@@ -36,7 +36,7 @@ class CDSSCapability:
 
             for item in result["score_diags"][:10]:
                 output.diagnosis.value[item[0]] = item[1]
-            output.departments.value = result["departments"]
+
 
             # for item in result["checks"][:5]:
             #     item[1]['score'] = item[1]['count'] / result["total_checks"]
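
For reference, a minimal sketch of the data shape this loop assumes for `result["score_diags"]` and how it lands in `output.diagnosis.value`; the field names follow the helper changes later in this commit, and the concrete values are invented:

    # Illustrative only: assumed shape of result["score_diags"]
    result = {
        "score_diags": [
            ("上呼吸道感染", {"count": 12, "score": 0.4,
                              "symptoms": [{"name": "咳嗽", "matched": True},
                                           {"name": "发热", "matched": False}]}),
            # ... more (disease_name, stats) pairs, sorted by score descending
        ]
    }

    diagnosis_value = {}
    for item in result["score_diags"][:10]:   # keep at most the top 10 diseases
        diagnosis_value[item[0]] = item[1]    # disease name -> {count, score, symptoms}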

+ 60 - 58
agent/cdss/libs/cdss_helper2.py

@@ -82,6 +82,9 @@ class CDSSHelper(GraphHelper):
         self.entity_data.set_index("id", inplace=True)
         print("load entity data finished")
 
+    def get_entity_data(self):
+        return self.entity_data
+
     def _append_entity_attribute(self, data, item, attr_name):
         if attr_name in item[1]:
             value = item[1][attr_name].split(":")
@@ -236,6 +239,8 @@ class CDSSHelper(GraphHelper):
         return False
 
     def check_diease_allowed(self, node):
+        if node == 1479768:
+            return True
         is_symptom = self.graph.nodes[node].get('is_symptom', None)
         if is_symptom == "是":
             return False
@@ -285,7 +290,7 @@ class CDSSHelper(GraphHelper):
         symptom_edge = ['has_symptom', '疾病相关症状']
         symptom_same_edge = ['症状同义词', '症状同义词2.0']
         department_edge = ['belongs_to','所属科室']
-        allowed_links = symptom_edge+department_edge+symptom_same_edge
+        allowed_links = symptom_edge+symptom_same_edge
         # allowed_links = symptom_edge + department_edge
 
         # 将输入的症状名称转换为节点ID
@@ -314,56 +319,45 @@ class CDSSHelper(GraphHelper):
                 logger.debug(f"node {node} not found")
         node_ids = node_ids_filtered
 
-        # out_edges = self.graph.out_edges(disease, data=True)
-        # for edge in out_edges:
-        #     src, dest, edge_data = edge
-        #     if edge_data["type"] not in department_edge:
-        #         continue
-        #     dest_data = self.entity_data[self.entity_data.index == dest]
-        #     if dest_data.empty:
-        #         continue
-        #     department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
-        #     department_data.extend([department_name] * results[disease]["count"])
-
+        results = self.step1(node_ids,node_id_names, input, allowed_types, symptom_same_edge,allowed_links,max_hops,DIESEASE)
 
-        results = self.step1(node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE)
-
-        #self.validDisease(results, start_nodes)
         results = self.validDisease(results, start_nodes)
 
-        sorted_score_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
-
-        # 调用step2方法处理科室、检查和药品信息
-        results = self.step2(results,department_edge)
-
-        # STEP 3: 对于结果按照科室维度进行汇总
-        final_results = self.step3(results)
+        sorted_count_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
+        diags = {}
+        target_symptom_names = []
+        for symptom_id in node_ids:
+            target_symptom_name = self.entity_data[self.entity_data.index == symptom_id]['name'].tolist()[0]
+            target_symptom_names.append(target_symptom_name)
+            same_symptoms = self.get_in_edges(symptom_id, symptom_same_edge)
+            #same_symptoms的name属性全部添加到target_symptom_names中
+            for same_symptom in same_symptoms:
+                target_symptom_names.append(same_symptom['name'].tolist()[0])
+        for item in sorted_count_diags:
+            disease_id = item[0]
+            disease_name = self.entity_data[self.entity_data.index == disease_id]['name'].tolist()[0]
+            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
+            if symptoms_data is None:
+                continue
+            symptoms = []
+            for symptom in symptoms_data:
+                matched = False
+                if symptom in target_symptom_names:
+                    matched = True
+                symptoms.append({"name": symptom, "matched": matched})
+            # symptoms中matched=true的排在前面,matched=false的排在后面
+            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
 
-        sorted_final_results = self.step4(final_results)
-        sorted_final_results = sorted_final_results[:5]
-        departments = []
+            start_nodes_size = len(start_nodes)
+            # if start_nodes_size > 1:
+            #     start_nodes_size = start_nodes_size*0.5
+            new_item = {"count": item[1]["count"],
+                        "score": float(item[1]["count"]) / start_nodes_size  * 0.1, "symptoms": symptoms}
+            diags[disease_name] = new_item
+        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
 
-        for temp in sorted_final_results:
-            departments.append({"name": temp[0], "count": temp[1]["count"]})
-        
-        # STEP 5: 对于final_results里面的diseases, checks和durgs统计全局出现的次数并且按照次数降序排序
-        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes,symptom_edge)
-
-        # STEP 6: 整合数据并返回
-        # if "department" in item.keys():
-        #     final_results["department"] = list(set(final_results["department"]+item["department"]))
-        # if "diseases" in item.keys():
-        #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
-        # if "checks" in item.keys():
-        #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
-        # if "drugs" in item.keys():
-        #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
-        # if "symptoms" in item.keys():
-        #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
-
-        return {"details": sorted_final_results,
-                "score_diags": sorted_score_diags,"total_diags": total_diags,
-                "departments":departments
+        return {
+                "score_diags": sorted_score_diags
                 # "checks":sorted_checks, "drugs":sorted_drugs,
                 # "total_checks":total_check, "total_drugs":total_drug
                 }
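
To make the reworked ranking concrete, here is a small self-contained sketch of the score computed above (accumulated count divided by the number of distinct input symptoms, scaled by 0.1) and of the matched-first ordering of a disease's symptoms; all numbers are invented:

    start_nodes = ["咳嗽", "发热", "头痛"]           # de-duplicated input symptoms
    count = 24                                       # accumulated weight for one candidate disease

    score = float(count) / len(start_nodes) * 0.1    # 24 / 3 * 0.1 = 0.8

    symptoms = [{"name": "乏力", "matched": False},
                {"name": "咳嗽", "matched": True},
                {"name": "发热", "matched": True}]
    # True sorts before False when reverse=True, so matched symptoms come first
    symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
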
@@ -410,8 +404,8 @@ class CDSSHelper(GraphHelper):
         content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
         print(f"\n{content}")
         return filtered_results
-
-    def step1(self, node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE):
+    
+    def step1(self, node_ids,node_id_names, input, allowed_types, symptom_same_edge,allowed_links,max_hops,DIESEASE):
         """
         根据症状节点查找相关疾病
         :param node_ids: 症状节点ID列表
@@ -441,12 +435,11 @@ class CDSSHelper(GraphHelper):
                             results[disease_id]["path"].append(path)
                         else:
                             results[disease_id] = temp_results[disease_id]
+                continue
 
 
-                continue
 
             queue = [(node, 0, node_id_names[node],10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
-
             # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
             if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
                 # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
@@ -458,6 +451,8 @@ class CDSSHelper(GraphHelper):
             while queue:
                 temp_node, depth, path, weight, data = queue.pop(0)
                 temp_node = int(temp_node)
+                if temp_node == 1479768:
+                    print(1111)
                 # 这里是通过id去获取节点的name和type
                 entity_data = self.entity_data[self.entity_data.index == temp_node]
                 # 如果节点不存在,那么跳过
@@ -465,6 +460,7 @@ class CDSSHelper(GraphHelper):
                     continue
                 if self.graph.nodes.get(temp_node) is None:
                     continue
+
                 node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
                 node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
                 # print(f"node {node} type {node_type}")
@@ -473,11 +469,11 @@ class CDSSHelper(GraphHelper):
                     count = weight
                     if self.check_diease_allowed(temp_node) == False:
                         continue
-                    if temp_node in temp_results.keys():
-                        temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
-                        results[disease_id]["increase"] = results[disease_id]["increase"] + 1
-                        temp_results[temp_node]["path"].append(path)
-                    else:
+                    if temp_node not in temp_results.keys():
+                    #     temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
+                    #     temp_results[temp_node]["increase"] = temp_results[temp_node]["increase"] + 1
+                    #     temp_results[temp_node]["path"].append(path)
+                    # else:
                         temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
 
                     continue
@@ -491,11 +487,9 @@ class CDSSHelper(GraphHelper):
                 if temp_node not in self.graph:
                     # print(f"node {node} not found in graph")
                     continue
-                # todo 目前是取入边,出边是不是也有用?
                 for edge in self.graph.in_edges(temp_node, data=True):
                     src, dest, edge_data = edge
-                    if src not in visited and depth + 1 < max_hops and edge_data['type'] in allowed_links:
-                        # print(f"put into queue travel from {src} to {dest}")
+                    if src not in visited and depth + 1 <= max_hops and edge_data['type'] in allowed_links:                       
                         weight = edge_data['weight']
                         try :
                             if weight:
@@ -546,6 +540,14 @@ class CDSSHelper(GraphHelper):
         print(f"STEP 1 遍历图谱查找相关疾病 finished")
         return results
 
+    def get_in_edges(self, node_id, allowed_links):
+        results = []
+        for edge in self.graph.in_edges(node_id, data=True):
+            src, dest, edge_data = edge
+            if edge_data['type'] in allowed_links:
+                results.append(self.entity_data[self.entity_data.index == src])
+        return results
+
     def step2(self, results,department_edge):
         """
         查找疾病对应的科室、检查和药品信息
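
A short usage sketch for the new `get_in_edges` helper, assuming a `CDSSHelper` whose graph and `entity_data` have already been built; each element it returns is a one-row `entity_data` slice, which is why the calling code above extracts the name with `['name'].tolist()[0]`. The node id and variable names here are examples only:

    symptom_same_edge = ['症状同义词', '症状同义词2.0']
    synonyms = helper.get_in_edges(123456, symptom_same_edge)      # helper: a built CDSSHelper
    synonym_names = [row['name'].tolist()[0] for row in synonyms]  # one DataFrame slice per in-edge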

+ 800 - 0
agent/cdss/libs/cdss_helper3.py

@@ -0,0 +1,800 @@
+import copy
+from hmac import new
+import os
+import sys
+import logging
+import json
+import time
+
+from sqlalchemy import false
+
+from service.kg_edge_service import KGEdgeService
+from db.session import get_db
+from service.kg_node_service import KGNodeService
+from service.kg_prop_service import KGPropService
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+
+current_path = os.getcwd()
+sys.path.append(current_path)
+
+from community.graph_helper import GraphHelper
+from typing import List
+from agent.cdss.models.schemas import CDSSInput
+from config.site import SiteConfig
+import networkx as nx
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+current_path = os.getcwd()
+sys.path.append(current_path)
+
+# 图谱数据缓存路径(由dump_graph_data.py生成)
+CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
+
+
+class CDSSHelper(GraphHelper):
+
+    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
+
+        kg_node_service = KGNodeService(next(get_db()))
+        es_result = kg_node_service.search_title_index("", node_id,node_type, limit)
+        results = []
+        for item in es_result:
+            n = self.graph.nodes.get(item["id"])
+            score = item["score"]
+            if n:
+                results.append({
+                    'id': item["title"],
+                    'score': score,
+                    "name": item["title"],
+                })
+        return results
+
+    def _load_entity_data(self):
+        config = SiteConfig()
+        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
+        print("load entity data")
+        # 这里设置了读取的属性
+        data = {"id": [], "name": [], "type": [],"is_symptom": [], "sex": [], "age": [],"score": []}
+        if not os.path.exists(os.path.join(CACHED_DATA_PATH, 'entities_med.json')):
+            return []
+        with open(os.path.join(CACHED_DATA_PATH, 'entities_med.json'), "r", encoding="utf-8") as f:
+            entities = json.load(f)
+            for item in entities:
+                #如果id已经存在,则跳过
+                # if item[0] in data["id"]:
+                #     print(f"skip {item[0]}")
+                #     continue
+                data["id"].append(int(item[0]))
+                data["name"].append(item[1]["name"])
+                data["type"].append(item[1]["type"])
+                self._append_entity_attribute(data, item, "sex")
+                self._append_entity_attribute(data, item, "age")
+                self._append_entity_attribute(data, item, "is_symptom")
+                self._append_entity_attribute(data, item, "score")
+                # item[1]["id"] = item[0]
+                # item[1]["name"] = item[0]
+                # attrs = item[1]
+                # self.graph.add_node(item[0], **attrs)
+        self.entity_data = pd.DataFrame(data)
+        self.entity_data.set_index("id", inplace=True)
+        print("load entity data finished")
+
+    def _append_entity_attribute(self, data, item, attr_name):
+        if attr_name in item[1]:
+            value = item[1][attr_name].split(":")
+            if len(value) < 2:
+                data[attr_name].append(value[0])
+            else:
+                data[attr_name].append(value[1])
+        else:
+            data[attr_name].append("")
+
+    def _load_relation_data(self):
+        config = SiteConfig()
+        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
+        print("load relationship data")
+
+        for i in range(62):
+            if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
+                continue
+            if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
+                print(f"load relationship data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
+                with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
+                    data = {"src": [], "dest": [], "type": [], "weight": []}
+                    relations = json.load(f)
+                    for item in relations:
+                        data["src"].append(int(item[0]))
+                        data["dest"].append(int(item[2]))
+                        data["type"].append(item[4]["type"])
+                        if "order" in item[4]:
+                            order = item[4]["order"].split(":")
+                            if len(order) < 2:
+                                data["weight"].append(int(order[0]))
+                            else:
+                                data["weight"].append(int(order[1]))
+                        else:
+                            data["weight"].append(1)
+
+                    self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
+
+    def build_graph(self):
+        self.entity_data = pd.DataFrame(
+            {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
+        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
+        self._load_entity_data()
+        self._load_relation_data()
+        self._load_local_data()
+
+        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
+                                             create_using=nx.DiGraph())
+
+        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
+        # print(self.graph.in_edges('1257357',data=True))
+
+    def _load_local_data(self):
+        # 这里加载update数据和权重数据
+        config = SiteConfig()
+        self.update_data_path = config.get_config('UPDATE_DATA_PATH')
+        self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
+        print(f"load update data from {self.update_data_path}")
+        for root, dirs, files in os.walk(self.update_data_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                if file_path.endswith(".json") and file.startswith("ent"):
+                    self._load_update_entity_json(file_path)
+                if file_path.endswith(".json") and file.startswith("rel"):
+                    self._load_update_relationship_json(file_path)
+
+    def _load_update_entity_json(self, file):
+        '''load json data from file'''
+        print(f"load entity update data from {file}")
+
+        # 这里加载update数据,update数据是一个json文件,格式同cached data如下:
+
+        with open(file, "r", encoding="utf-8") as f:
+            entities = json.load(f)
+            for item in entities:
+                original_data = self.entity_data[self.entity_data.index == item[0]]
+                if original_data.empty:
+                    continue
+                original_data = original_data.iloc[0]
+                id = int(item[0])
+                name = item[1]["name"] if "name" in item[1] else original_data['name']
+                type = item[1]["type"] if "type" in item[1] else original_data['type']
+                allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
+                    'allowed_sex_list']
+                allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
+                    'allowed_age_range']
+
+                self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
+                                                                                                       allowed_sex_list,
+                                                                                                       allowed_age_range]
+
+    def _load_update_relationship_json(self, file):
+        '''load json data from file'''
+        print(f"load relationship update data from {file}")
+
+        with open(file, "r", encoding="utf-8") as f:
+            relations = json.load(f)
+            for item in relations:
+                data = {}
+                data["src"] = int(item[0])
+                data["dest"] = int(item[2])
+                data["type"] = item[4]["type"]
+                original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
+                                                   (self.relation_data['dest'] == data['dest']) &
+                                                   (self.relation_data['type'] == data['type'])]
+                if original_data.empty:
+                    continue
+                original_data = original_data.iloc[0]
+                data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
+
+                self.relation_data.loc[(self.relation_data['src'] == data['src']) &
+                                       (self.relation_data['dest'] == data['dest']) &
+                                       (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
+
+    def check_sex_allowed(self, node, sex):
+
+        # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
+        sex_allowed = self.graph.nodes[node].get('sex', None)
+        #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
+        #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
+            #'prop_value'] != input.pat_sex.value:
+            #continue
+        if sex_allowed:
+            if len(sex_allowed) == 0:
+                # 如果性别列表为空,那么默认允许所有性别
+                return True
+            sex_allowed_list = sex_allowed.split(',')
+            if sex not in sex_allowed_list:
+                # 如果性别不匹配,跳过
+                return False
+        return True
+
+    def check_age_allowed(self, node, age):
+        # 年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
+        # 如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
+        age_allowed = self.graph.nodes[node].get('age', None)
+        if age_allowed:
+            if len(age_allowed) == 0:
+                # 如果年龄范围为空,那么默认允许所有年龄
+                return True
+            age_allowed_list = age_allowed.split('-')
+            age_min = int(age_allowed_list[0])
+            age_max = int(age_allowed_list[-1])
+            if age_max ==0:
+                return True
+            if age >= age_min and age < age_max:
+                # 如果年龄范围正常,那么返回True
+                return True
+        else:
+            # 如果没有设置年龄范围,那么默认返回True
+            return True
+        return False
+
+    def check_diease_allowed(self, node):
+        is_symptom = self.graph.nodes[node].get('is_symptom', None)
+        if is_symptom == "是":
+            return False
+        return True
+
+    propService = KGPropService(next(get_db()))
+    cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
+    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
+        """
+        基于输入的症状节点,在知识图谱中进行遍历,查找相关疾病、科室、检查和药品
+
+        参数:
+            input: CDSSInput对象,包含患者的基本信息(年龄、性别等)
+            start_nodes: 症状节点名称列表,作为遍历的起点
+            max_hops: 最大遍历深度,默认为3
+
+        返回值:
+            返回一个包含以下信息的字典:
+                - details: 按科室汇总的结果
+                - diags: 按相关性排序的疾病列表
+                - checks: 按出现频率排序的检查列表
+                - drugs: 按出现频率排序的药品列表
+                - total_diags: 疾病总数
+                - total_checks: 检查总数
+                - total_drugs: 药品总数
+
+        主要步骤:
+            1. 初始化允许的节点类型和关系类型
+            2. 将症状名称转换为节点ID
+            3. 遍历图谱查找相关疾病(STEP 1)
+            4. 查找疾病对应的科室、检查和药品(STEP 2)
+            5. 按科室汇总结果(STEP 3)
+            6. 对结果进行排序和统计(STEP 4-6)
+        """
+        # 定义允许的节点类型,包括科室、疾病、药品、检查和症状
+        # 这些类型用于后续的节点过滤和路径查找
+        DEPARTMENT = ['科室', 'Department']
+        DIESEASE = ['疾病', 'Disease']
+        DRUG = ['药品', 'Drug']
+        CHECK = ['检查', 'Check']
+        SYMPTOM = ['症状', 'Symptom']
+        #allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
+        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
+        # 定义允许的关系类型,包括has_symptom、need_check、recommend_drug、belongs_to
+        # 这些关系类型用于后续的路径查找和过滤
+
+        symptom_edge = ['has_symptom', '疾病相关症状']
+        symptom_same_edge = ['症状同义词', '症状同义词2.0']
+        department_edge = ['belongs_to','所属科室']
+        allowed_links = symptom_edge+symptom_same_edge
+        # allowed_links = symptom_edge + department_edge
+
+        # 将输入的症状名称转换为节点ID
+        # 由于可能存在同名节点,转换后的节点ID数量可能大于输入的症状数量
+        node_ids = []
+        node_id_names = {}
+
+        # start_nodes里面重复的症状,去重同样的症状
+        start_nodes = list(set(start_nodes))
+        for node in start_nodes:
+            #print(f"searching for node {node}")
+            result = self.entity_data[self.entity_data['name'] == node]
+            #print(f"searching for node {result}")
+            for index, data in result.iterrows():
+                if data["type"] in SYMPTOM or data["type"] in DIESEASE:
+                    node_id_names[index] = data["name"]
+                    node_ids = node_ids + [index]
+        #print(f"start travel from {node_id_names}")
+
+        # 这里是一个队列,用于存储待遍历的症状:
+        node_ids_filtered = []
+        for node in node_ids:
+            if self.graph.has_node(node):
+                node_ids_filtered.append(node)
+            else:
+                logger.debug(f"node {node} not found")
+        node_ids = node_ids_filtered
+
+        # out_edges = self.graph.out_edges(disease, data=True)
+        # for edge in out_edges:
+        #     src, dest, edge_data = edge
+        #     if edge_data["type"] not in department_edge:
+        #         continue
+        #     dest_data = self.entity_data[self.entity_data.index == dest]
+        #     if dest_data.empty:
+        #         continue
+        #     department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+        #     department_data.extend([department_name] * results[disease]["count"])
+
+
+        results = self.step1(node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE)
+
+        #self.validDisease(results, start_nodes)
+        results = self.validDisease(results, start_nodes)
+
+        sorted_score_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
+
+        # 调用step2方法处理科室、检查和药品信息
+        results = self.step2(results,department_edge)
+
+        # STEP 3: 对于结果按照科室维度进行汇总
+        final_results = self.step3(results)
+
+        sorted_final_results = self.step4(final_results)
+        sorted_final_results = sorted_final_results[:5]
+        departments = []
+
+        for temp in sorted_final_results:
+            departments.append({"name": temp[0], "count": temp[1]["count"]})
+        
+        # STEP 5: 对于final_results里面的diseases, checks和drugs统计全局出现的次数并且按照次数降序排序
+        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes,symptom_edge)
+
+        # STEP 6: 整合数据并返回
+        # if "department" in item.keys():
+        #     final_results["department"] = list(set(final_results["department"]+item["department"]))
+        # if "diseases" in item.keys():
+        #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
+        # if "checks" in item.keys():
+        #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
+        # if "drugs" in item.keys():
+        #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
+        # if "symptoms" in item.keys():
+        #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
+
+        return {"details": sorted_final_results,
+                "score_diags": sorted_score_diags,"total_diags": total_diags,
+                "departments":departments
+                # "checks":sorted_checks, "drugs":sorted_drugs,
+                # "total_checks":total_check, "total_drugs":total_drug
+                }
+
+    def validDisease(self, results, start_nodes):
+        """
+        输出有效的疾病信息为Markdown格式
+        :param results: 疾病结果字典
+        :param start_nodes: 起始症状节点列表
+        :return: 格式化后的Markdown字符串
+        """
+        log_data = ["|疾病|症状|出现次数|是否相关"]
+        log_data.append("|--|--|--|--|")
+        filtered_results = {}
+        for item in results:
+            data = results[item]
+            data['relevant'] = False
+            if data["increase"] / len(start_nodes) > 0.5:
+                #cache_key = f'disease_name_ref_id_{data['name']}'
+                data['relevant'] = True
+                filtered_results[item] = data
+                # 初始化疾病的父类疾病
+                # disease_name = data["name"]
+                # key = 'disease_name_parent_' +disease_name
+                # cached_value = self.cache.get(key)
+                # if cached_value is None:
+                #     out_edges = self.graph.out_edges(item, data=True)
+                #
+                #     for edge in out_edges:
+                #         src, dest, edge_data = edge
+                #         if edge_data["type"] != '疾病相关父类':
+                #             continue
+                #         dest_data = self.entity_data[self.entity_data.index == dest]
+                #         if dest_data.empty:
+                #             continue
+                #         dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                #         self.cache.set(key, dest_name)
+                #         break
+            if data['relevant'] == False:
+                continue
+
+            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
+
+        content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
+        print(f"\n{content}")
+        return filtered_results
+
+    def step1(self, node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE):
+        """
+        根据症状节点查找相关疾病
+        :param node_ids: 症状节点ID列表
+        :param input: 患者信息输入
+        :param allowed_types: 允许的节点类型
+        :param allowed_links: 允许的关系类型
+        :return: 过滤后的疾病结果
+        """
+        start_time = time.time()
+        results = {}
+        for node in node_ids:
+            visited = set()
+            temp_results = {}
+            cache_key = f"symptom_ref_disease_{str(node)}"
+            cache_data = self.cache[cache_key] if cache_key in self.cache else None
+            if cache_data:
+                temp_results = copy.deepcopy(cache_data)
+                print(cache_key+":"+node_id_names[node] +':'+ str(len(temp_results)))
+                if results=={}:
+                    results = temp_results
+                else:
+                    for disease_id in temp_results:
+                        path = temp_results[disease_id]["path"][0]
+                        if disease_id in results.keys():
+                            results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
+                            results[disease_id]["increase"] = results[disease_id]["increase"]+1
+                            results[disease_id]["path"].append(path)
+                        else:
+                            results[disease_id] = temp_results[disease_id]
+
+
+                continue
+
+            queue = [(node, 0, node_id_names[node],10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
+
+            # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
+            if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
+                # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
+                input.pat_age.value = input.pat_age.value * 12
+                input.pat_age.type = 'month'
+
+            # STEP 1: 假设start_nodes里面都是症状,第一步我们先找到这些症状对应的疾病
+            # TODO 由于这部分是按照症状逐一去寻找疾病,所以实际应用中可以缓存这些结果
+            while queue:
+                temp_node, depth, path, weight, data = queue.pop(0)
+                temp_node = int(temp_node)
+                # 这里是通过id去获取节点的name和type
+                entity_data = self.entity_data[self.entity_data.index == temp_node]
+                # 如果节点不存在,那么跳过
+                if entity_data.empty:
+                    continue
+                if self.graph.nodes.get(temp_node) is None:
+                    continue
+                node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
+                node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
+                # print(f"node {node} type {node_type}")
+                if node_type in DIESEASE:
+                    # print(f"node {node} type {node_type} is a disease")
+                    count = weight
+                    if self.check_diease_allowed(temp_node) == False:
+                        continue
+                    if temp_node in temp_results.keys():
+                        temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
+                        temp_results[temp_node]["increase"] = temp_results[temp_node]["increase"] + 1
+                        temp_results[temp_node]["path"].append(path)
+                    else:
+                        temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
+
+                    continue
+
+                if temp_node in visited or depth > max_hops:
+                    # print(f"{node} already visited or reach max hops")
+                    continue
+
+                visited.add(temp_node)
+                # print(f"check edges from {node}")
+                if temp_node not in self.graph:
+                    # print(f"node {node} not found in graph")
+                    continue
+                # todo 目前是取入边,出边是不是也有用?
+                for edge in self.graph.in_edges(temp_node, data=True):
+                    src, dest, edge_data = edge
+                    if src not in visited and depth + 1 < max_hops and edge_data['type'] in allowed_links:
+                        # print(f"put into queue travel from {src} to {dest}")
+                        weight = edge_data['weight']
+                        try :
+                            if weight:
+                                if weight < 10:
+                                    weight = 10-weight
+                                else:
+                                    weight = 1
+                            else:
+                                weight = 5
+
+                            if weight>10:
+                                weight = 10
+                        except Exception as e:
+                            print(f'Error processing file {weight}: {str(e)}')
+
+                        queue.append((src, depth + 1, path, int(weight), data))
+                    # else:
+                    # print(f"skip travel from {src} to {dest}")
+            print(cache_key+":"+node_id_names[node]+':'+ str(len(temp_results)))
+            #对temp_results进行深拷贝,然后再进行处理
+            self.cache[cache_key] = copy.deepcopy(temp_results)
+            if results == {}:
+                results = temp_results
+            else:
+                for disease_id in temp_results:
+                    path = temp_results[disease_id]["path"][0]
+                    if disease_id in results.keys():
+                        results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
+                        results[disease_id]["increase"] = results[disease_id]["increase"] + 1
+                        results[disease_id]["path"].append(path)
+                    else:
+                        results[disease_id] = temp_results[disease_id]
+
+        end_time = time.time()
+
+        # 这里我们需要对结果进行过滤,过滤掉不满足条件的疾病
+        new_results = {}
+
+        for item in results:
+            if input.pat_sex and input.pat_sex.value is not None  and self.check_sex_allowed(item, input.pat_sex.value) == False:
+                continue
+            if input.pat_age and input.pat_age.value is not None and self.check_age_allowed(item, input.pat_age.value) == False:
+                continue
+            new_results[item] = results[item]
+        results = new_results
+        print('STEP 1 '+str(len(results)))
+        print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
+        print(f"STEP 1 遍历图谱查找相关疾病 finished")
+        return results
+
+    def step2(self, results,department_edge):
+        """
+        查找疾病对应的科室、检查和药品信息
+        :param results: 包含疾病信息的字典
+        :return: 更新后的results字典
+        """
+        start_time = time.time()
+        print("STEP 2 查找疾病对应的科室、检查和药品 start")
+
+        for disease in results.keys():
+            # cache_key = f"disease_department_{disease}"
+            # cached_data = self.cache.get(cache_key)
+            # if cached_data:
+            #     results[disease]["department"] = cached_data
+            #     continue
+
+            if results[disease]["relevant"] == False:
+                continue
+
+            department_data = []
+            out_edges = self.graph.out_edges(disease, data=True)
+            for edge in out_edges:
+                src, dest, edge_data = edge
+                if edge_data["type"] not in department_edge:
+                    continue
+                dest_data = self.entity_data[self.entity_data.index == dest]
+                if dest_data.empty:
+                    continue
+                department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                department_data.extend([department_name] * results[disease]["count"])
+
+            if department_data:
+                results[disease]["department"] = department_data
+                #self.cache.set(cache_key, department_data)
+
+        print(f"STEP 2 finished")
+        end_time = time.time()
+        print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")
+
+        # 输出日志
+        log_data = ["|disease|count|department|check|drug|"]
+        log_data.append("|--|--|--|--|--|")
+        for item in results.keys():
+            department_data = results[item].get("department", [])
+            count_data = results[item].get("count")
+            check_data = results[item].get("check", [])
+            drug_data = results[item].get("drug", [])
+            log_data.append(
+                f"|{results[item].get('name', item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
+
+        print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
+        return results
+
+    def step3(self, results):
+        print(f"STEP 3 对于结果按照科室维度进行汇总 start")
+        final_results = {}
+        total = 0
+        for disease in results.keys():
+            disease = int(disease)
+            # 由于存在有些疾病没有科室的情况,所以这里需要做一下处理
+            departments = ['DEFAULT']
+            if 'department' in results[disease].keys():
+                departments = results[disease]["department"]
+            else:
+                edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='所属科室')
+                #edges有可能为空,这里需要做一下处理
+                if len(edges) > 0:
+                    departments = [edge['dest_node']['name'] for edge in edges]
+            # 处理查询结果
+            for department in departments:
+                total += 1
+                if not department in final_results.keys():
+                    final_results[department] = {
+                        "diseases": [str(disease)+":"+results[disease].get("name", disease)],
+                        "checks": results[disease].get("check", []),
+                        "drugs": results[disease].get("drug", []),
+                        "count": 1
+                    }
+                else:
+                    final_results[department]["diseases"] = final_results[department]["diseases"] + [str(disease)+":"+
+                        results[disease].get("name", disease)]
+                    final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get(
+                        "check", [])
+                    final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get(
+                        "drug", [])
+                    final_results[department]["count"] += 1
+
+        # 这里是统计科室出现的分布
+        for department in final_results.keys():
+            final_results[department]["score"] = final_results[department]["count"] / total
+
+        print(f"STEP 3 finished")
+        # 这里输出日志
+        log_data = ["|department|disease|check|drug|count|score"]
+        log_data.append("|--|--|--|--|--|--|")
+        for department in final_results.keys():
+            diesease_data = final_results[department].get("diseases", [])
+            check_data = final_results[department].get("checks", [])
+            drug_data = final_results[department].get("drugs", [])
+            count_data = final_results[department].get("count", 0)
+            score_data = final_results[department].get("score", 0)
+            log_data.append(
+                f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
+
+        print("\n" + "\n".join(log_data))
+        return final_results
+
+    def step4(self, final_results):
+        """
+        对final_results中的疾病、检查和药品进行统计和排序
+        
+        参数:
+            final_results: 包含科室、疾病、检查和药品的字典
+        
+        返回值:
+            排序后的final_results
+        """
+        print(f"STEP 4 start")
+        start_time = time.time()
+
+        def sort_data(data, count=10):
+            tmp = {}
+            for item in data:
+                if item in tmp.keys():
+                    tmp[item]["count"] += 1
+                else:
+                    tmp[item] = {"count": 1}
+            sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
+            return sorted_data[:count]
+
+        for department in final_results.keys():
+            final_results[department]['name'] = department
+            final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
+            #final_results[department]["checks"] = sort_data(final_results[department]["checks"])
+            #final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
+
+        # 这里把科室做一个排序,按照出现的次数降序排序
+        sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
+  
+        print(f"STEP 4 finished")
+        end_time = time.time()
+        print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")
+        # 这里输出markdown日志
+        log_data = ["|department|disease|check|drug|count|score"]
+        log_data.append("|--|--|--|--|--|--|")
+        for department in final_results.keys():
+            diesease_data = final_results[department].get("diseases")
+            check_data = final_results[department].get("checks")
+            drug_data = final_results[department].get("drugs")
+            count_data = final_results[department].get("count", 0)
+            score_data = final_results[department].get("score", 0)
+            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
+
+        print("\n" + "\n".join(log_data))
+        return sorted_final_results
+
+    def step5(self, final_results, input, start_nodes, symptom_edge):
+        """
+        按科室汇总结果并排序
+
+        参数:
+            final_results: 各科室的初步结果
+            input: 患者输入信息
+
+        返回值:
+            返回排序后的诊断结果
+        """
+        print(f"STEP 5 start")
+        start_time = time.time()
+        diags = {}
+        total_diags = 0
+
+        for department in final_results.keys():
+            department_factor = 0.1 if department == 'DEFAULT' else final_results[department]["score"]
+            count = 0
+            #当前科室权重增加0.1
+            if input.department.value == department:
+                count = 1
+            for disease, data in final_results[department]["diseases"]:
+                total_diags += 1
+                if disease in diags.keys():
+                    diags[disease]["count"] += data["count"]+count
+                    diags[disease]["score"] += (data["count"]+count)*0.1 * department_factor
+                else:
+                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor}
+  
+        #sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)[:10]
+        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
+
+        diags = {}
+        for item in sorted_score_diags:
+            disease_info = item[0].split(":");
+            disease_id = disease_info[0]
+            disease = disease_info[1]
+            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
+            if symptoms_data is None:
+                continue
+            symptoms = []
+            for symptom in symptoms_data:
+                matched = False
+                if symptom in start_nodes:
+                    matched = True
+                symptoms.append({"name": symptom, "matched": matched})
+            # symptoms中matched=true的排在前面,matched=false的排在后面
+            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
+
+            start_nodes_size = len(start_nodes)
+            # if start_nodes_size > 1:
+            #     start_nodes_size = start_nodes_size*0.5
+            new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])/start_nodes_size/2*0.1,"symptoms":symptoms}
+            diags[disease] = new_item
+        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
+
+        print(f"STEP 5 finished")
+        end_time = time.time()
+        print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
+
+        log_data = ["|department|disease|count|score"]
+        log_data.append("|--|--|--|--|")
+        for department in final_results.keys():
+            diesease_data = final_results[department].get("diseases")
+            count_data = final_results[department].get("count", 0)
+            score_data = final_results[department].get("score", 0)
+            log_data.append(f"|{department}|{diesease_data}|{count_data}|{score_data}|")
+
+        print("这里是经过排序的数据\n" + "\n".join(log_data))
+        return sorted_score_diags, total_diags
+    
+    def get_symptoms_data(self, disease_id, symptom_edge):
+        """
+        获取疾病相关的症状数据
+        :param disease_id: 疾病节点ID
+        :param symptom_edge: 症状关系类型列表
+        :return: 症状数据列表
+        """
+        key = f'disease_{disease_id}_symptom'
+        symptom_data = self.cache[key] if key in self.cache else None
+        if symptom_data is None:
+            out_edges = self.graph.out_edges(int(disease_id), data=True)
+            symptom_data = []
+            for edge in out_edges:
+                src, dest, edge_data = edge
+                if edge_data["type"] not in symptom_edge:
+                    continue
+                dest_data = self.entity_data[self.entity_data.index == dest]
+                if dest_data.empty:
+                    continue
+                dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                if dest_name not in symptom_data:
+                    symptom_data.append(dest_name)
+            self.cache[key]=symptom_data
+        return symptom_data
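
Throughout this file the per-symptom traversal results are cached in a `TTLCache` and copied with `copy.deepcopy` both when writing and when reading, because the later merge mutates `count`, `increase` and `path` in place and would otherwise corrupt the cached entry. A minimal sketch of that pattern, independent of the graph code (`cached_traverse` and `traverse` are illustrative names, not part of the file):

    import copy
    from cachetools import TTLCache

    cache = TTLCache(maxsize=100000, ttl=60 * 60 * 24 * 30)   # ~30 days, as above

    def cached_traverse(symptom_id, traverse):
        key = f"symptom_ref_disease_{symptom_id}"
        if key in cache:
            # deepcopy so the caller's in-place merging cannot touch the cached dicts
            return copy.deepcopy(cache[key])
        result = traverse(symptom_id)
        cache[key] = copy.deepcopy(result)
        return result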

+ 1 - 1
main.py

@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 logger.propagate = True
 
 # 创建FastAPI应用
-app = FastAPI(title="知识图谱")
+app = FastAPI(title="医疗知识",root_path="/knowledge")
 app.include_router(dify_kb_router)
 app.include_router(saas_kb_router)
 app.include_router(text_search_router)
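
`root_path` does not add a prefix to the routes themselves; it tells FastAPI that a reverse proxy serves the app under `/knowledge`, so the generated OpenAPI and docs URLs carry that prefix while handlers keep their original paths. A minimal sketch, assuming a proxy that strips `/knowledge` before forwarding:

    from fastapi import FastAPI

    app = FastAPI(title="医疗知识", root_path="/knowledge")

    @app.get("/disease/recommend")
    async def recommend(chief: str):
        # Matched as /disease/recommend inside the app; reachable externally
        # as /knowledge/disease/recommend through the proxy, docs at /knowledge/docs.
        return {"chief": chief}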

+ 35 - 23
router/graph_router.py

@@ -1,8 +1,11 @@
+import logging
 import sys,os
 
 from agent.cdss.capbility import CDSSCapability
 from agent.cdss.models.schemas import CDSSInput, CDSSInt, CDSSText
 from model.response import StandardResponse
+from service.cdss_service import CdssService
+from service.kg_node_service import KGNodeService
 
 current_path = os.getcwd()
 sys.path.append(current_path)
@@ -17,9 +20,9 @@ import json
 
 
 
-router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
-
-@router.get("/nodes/recommend", response_model=StandardResponse)
+router = APIRouter(prefix="/disease", tags=["Knowledge Graph"])
+logger = logging.getLogger(__name__)
+@router.get("/recommend", response_model=StandardResponse)
 async def recommend(
     chief: str,
     present_illness: Optional[str] = None,
@@ -31,10 +34,10 @@ async def recommend(
     app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
     conversation_id = get_conversation_id(app_id)
 
-    # desc = "主诉:"+chief
-    # if present_illness:
-    #     desc+="\n现病史:" + present_illness
-    result = call_chat_api(app_id, conversation_id, chief)
+    desc = "主诉:"+chief
+    if present_illness:
+        desc+="\n现病史:" + present_illness
+    result = call_chat_api(app_id, conversation_id, desc)
     json_data = json.loads(result)
     keyword = " ".join(json_data["symptoms"])
     result = await neighbor_search(keyword=keyword,sex=sex,age=age, neighbor_type='Check', limit=10)
@@ -42,24 +45,18 @@ async def recommend(
     print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
     return result;
 
-
-@router.get("/nodes/neighbor_search", response_model=StandardResponse)
+@router.get("/neighbor_search", response_model=StandardResponse)
 async def neighbor_search(
     keyword: str = Query(..., min_length=2),
     sex: Optional[str] = None,
     age: Optional[int] = None,
-    department: Optional[str] = None,
-    limit: int = Query(10, ge=1, le=100),
-    node_type: Optional[str] = Query(None),
-    neighbor_type: Optional[str] = Query(None),
-    min_degree: Optional[int] = Query(None)
+    department: Optional[str] = None
 ):
     """
     根据关键词和属性过滤条件搜索图谱节点
     """
     try:
 
-        print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
         keywords = keyword.split(" ")
 
         record = CDSSInput(
@@ -71,13 +68,14 @@ async def neighbor_search(
         # 使用从main.py导入的capability实例处理CDSS逻辑
         output = capability.process(input=record)
 
-        output.diagnosis.value = [{"name":key,"old_score":value["old_score"],"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
-            "hasInfo": 1,
-            "type": 1} for key,value in output.diagnosis.value.items()]
+        output.diagnosis.value = [{"name":key,"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
+            #"hasInfo": 1,
+            #"type": 1
+            } for key,value in output.diagnosis.value.items()]
 
         return StandardResponse(
             success=True,
-            data={"可能诊断":output.diagnosis.value,"症状":keywords,"就诊科室":output.departments.value}
+            data={"可能诊断":output.diagnosis.value,"症状":keywords}
         )
     except Exception as e:
         print(e)
@@ -87,8 +85,22 @@ async def neighbor_search(
             error_code=500,
             error_msg=str(e)
         )
-capability = CDSSCapability()
-#def get_capability():
-    #from main import capability
-    #return capability
+
+@router.get("/{disease_name}/detail", response_model=StandardResponse)
+async def get_disease_detail(
+    disease_name: str
+):
+    try:
+        service = CdssService()
+        result = service.get_disease_detail(disease_name,'疾病')
+        return StandardResponse(success=True, data=result)
+    except Exception as e:
+        logger.error(f"get_disease_detail failed: {str(e)}")
+        return StandardResponse(
+            success=False,
+            error_code=500,
+            error_msg=str(e)
+        )
+
+#capability = CDSSCapability()
 graph_router = router
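
A hedged client-side sketch of the routes defined in this router; the host, port and example values are assumptions, and the externally visible path may additionally carry the `/knowledge` prefix configured in `main.py`:

    import requests

    BASE = "http://localhost:8000"   # assumed deployment address

    # 1. Recommendation from chief complaint (+ optional present illness)
    requests.get(f"{BASE}/disease/recommend",
                 params={"chief": "咳嗽三天", "present_illness": "伴低热", "sex": "女", "age": 30})

    # 2. Direct symptom search; keywords are space-separated
    resp = requests.get(f"{BASE}/disease/neighbor_search",
                        params={"keyword": "咳嗽 发热", "sex": "女", "age": 30})
    # assuming StandardResponse serializes to {"success": ..., "data": ...}
    print(resp.json()["data"]["可能诊断"])

    # 3. Disease detail by name (the name here is made up)
    requests.get(f"{BASE}/disease/支气管炎/detail")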

+ 0 - 13
router/knowledge_saas.py

@@ -146,19 +146,6 @@ async def update_node(
         logger.error(f"更新节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
-@router.delete("/nodes/{node_id}", response_model=StandardResponse)
-async def delete_node(
-    node_id: int,
-    db: Session = Depends(get_db)
-):
-    try:
-        service = TrunksService()
-        service.delete_node(node_id)
-        return StandardResponse(success=True)
-    except Exception as e:
-        logger.error(f"删除节点失败: {str(e)}")
-        raise HTTPException(500, detail=StandardResponse.error(str(e)))
-
 @router.post('/trunks/vector_search', response_model=StandardResponse)
 async def vector_search(
     payload: VectorSearchRequest,

+ 71 - 0
service/cdss_service.py

@@ -0,0 +1,71 @@
+from sqlalchemy.orm import Session
+from typing import Optional
+from model.kg_node import KGNode
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+
+from service.kg_node_service import KGNodeService
+from tests.service.test_kg_node_service import kg_node_service
+from utils import vectorizer
+from utils.vectorizer import Vectorizer
+from sqlalchemy import func
+from service.kg_prop_service import KGPropService
+from service.kg_edge_service import KGEdgeService
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+logger = logging.getLogger(__name__)
+DISTANCE_THRESHOLD = 0.65
+class CdssService:
+    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
+
+
+    def get_disease_detail(self, diease_name: str,category: str):
+        nodeService = KGNodeService(next(get_db()))
+        diseases = nodeService.search_title_index("", diease_name, category, 1, 0)
+        if not diseases:
+            return None
+        disease_id = diseases[0]["id"]
+        edgeService = KGEdgeService(next(get_db()))
+
+        #category='疾病相关实验室检查项目'
+        examinations = edgeService.get_edges_by_nodes(src_id=disease_id, category='疾病相关实验室检查项目')
+        #category='疾病相关辅助检查项目'
+        physical_examinations = edgeService.get_edges_by_nodes(src_id=disease_id, category='疾病相关辅助检查项目')
+        # category='疾病相关鉴别诊断'
+        differential_diagnosis = edgeService.get_edges_by_nodes(src_id=disease_id, category='疾病相关鉴别诊断')
+
+
+        e_nodes = []
+        pe_nodes = []
+        ddx_nodes = []
+
+        if examinations:
+            e_nodes = [exam['dest_node'] for exam in examinations]
+            self.add_props_to_nodes(e_nodes)
+        
+        if physical_examinations:
+            pe_nodes = [pe['dest_node'] for pe in physical_examinations]
+            self.add_props_to_nodes(pe_nodes)
+
+        if differential_diagnosis:
+            ddx_nodes = [pe['dest_node'] for pe in differential_diagnosis]
+            self.add_props_to_nodes(ddx_nodes,['Differential Diagnosis'])
+        
+        return {
+            'disease':diseases[0],
+            'examinations': e_nodes,
+            'physical_examinations': pe_nodes,
+            'differential_diagnosis': ddx_nodes
+        }
+
+    def add_props_to_nodes(self, dest_nodes,prop_names=None):
+        kgPropService = KGPropService(next(get_db()))
+        for node in dest_nodes:
+            props = kgPropService.get_props_by_ref_id(node['id'],prop_names)
+            if props:
+                node['props'] = props
+
+
+
+
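
A small sketch of calling `get_disease_detail` and of the dictionary it returns (keys as defined above); the disease name is an example and the node contents depend on the knowledge graph:

    service = CdssService()                      # assumes database configuration is in place
    detail = service.get_disease_detail("肺炎", "疾病")
    if detail:
        print(detail["disease"])
        for exam in detail["examinations"]:            # 疾病相关实验室检查项目, each with optional 'props'
            print(exam["name"], exam.get("props", []))
        for ddx in detail["differential_diagnosis"]:   # props filtered to 'Differential Diagnosis'
            print(ddx["name"])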

+ 1 - 1
service/kg_node_service.py

@@ -21,7 +21,7 @@ class KGNodeService:
     _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
 
     def search_title_index(self, index: str, keywrod: str,category: str, top_k: int = 3,distance: float = 0.3) -> Optional[int]:
-        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{float}"
+        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
         if cache_key in self._cache:
             return self._cache[cache_key]
 

+ 4 - 3
service/kg_prop_service.py

@@ -11,11 +11,12 @@ class KGPropService:
     def __init__(self, db: Session):
         self.db = db
 
-    def get_props_by_ref_id(self, ref_id: int, prop_name: str = None) -> List[dict]:
+    def get_props_by_ref_id(self, ref_id: int, prop_names: List[str] = None) -> List[dict]:
         try:
             query = self.db.query(KGProp).filter(KGProp.ref_id == ref_id)
-            if prop_name:
-                query = query.filter(KGProp.prop_name == prop_name)
+            if prop_names:
+                #prop_names是一个列表,需要使用in_方法进行查询
+                query = query.filter(KGProp.prop_name.in_(prop_names))
             props = query.all()
             return [{
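
A brief usage sketch of the widened signature: passing a list filters with SQLAlchemy's `in_()`, while passing nothing keeps the old behaviour of returning every property of the node. The ref id and property name below are examples only:

    from db.session import get_db
    from service.kg_prop_service import KGPropService

    svc = KGPropService(next(get_db()))
    all_props = svc.get_props_by_ref_id(123456)                               # every property
    ddx_props = svc.get_props_by_ref_id(123456, ['Differential Diagnosis'])   # filtered via in_()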
                 'id': p.id,