SGTY 2 mēneši atpakaļ
vecāks
revīzija
bf94d7456c

+ 1 - 1
agent/cdss/capbility.py

@@ -36,7 +36,7 @@ class CDSSCapability:
 
 
             for item in result["score_diags"][:10]:
             for item in result["score_diags"][:10]:
                 output.diagnosis.value[item[0]] = item[1]
                 output.diagnosis.value[item[0]] = item[1]
-            output.departments.value = result["departments"]
+
 
 
             # for item in result["checks"][:5]:
             # for item in result["checks"][:5]:
             #     item[1]['score'] = item[1]['count'] / result["total_checks"]
             #     item[1]['score'] = item[1]['count'] / result["total_checks"]

+ 60 - 58
agent/cdss/libs/cdss_helper2.py

@@ -82,6 +82,9 @@ class CDSSHelper(GraphHelper):
         self.entity_data.set_index("id", inplace=True)
         self.entity_data.set_index("id", inplace=True)
         print("load entity data finished")
         print("load entity data finished")
 
 
+    def get_entity_data(self):
+        return self.entity_data
+
     def _append_entity_attribute(self, data, item, attr_name):
     def _append_entity_attribute(self, data, item, attr_name):
         if attr_name in item[1]:
         if attr_name in item[1]:
             value = item[1][attr_name].split(":")
             value = item[1][attr_name].split(":")
@@ -236,6 +239,8 @@ class CDSSHelper(GraphHelper):
         return False
         return False
 
 
     def check_diease_allowed(self, node):
     def check_diease_allowed(self, node):
+        if node == 1479768:
+            return True
         is_symptom = self.graph.nodes[node].get('is_symptom', None)
         is_symptom = self.graph.nodes[node].get('is_symptom', None)
         if is_symptom == "是":
         if is_symptom == "是":
             return False
             return False
@@ -285,7 +290,7 @@ class CDSSHelper(GraphHelper):
         symptom_edge = ['has_symptom', '疾病相关症状']
         symptom_edge = ['has_symptom', '疾病相关症状']
         symptom_same_edge = ['症状同义词', '症状同义词2.0']
         symptom_same_edge = ['症状同义词', '症状同义词2.0']
         department_edge = ['belongs_to','所属科室']
         department_edge = ['belongs_to','所属科室']
-        allowed_links = symptom_edge+department_edge+symptom_same_edge
+        allowed_links = symptom_edge+symptom_same_edge
         # allowed_links = symptom_edge + department_edge
         # allowed_links = symptom_edge + department_edge
 
 
         # 将输入的症状名称转换为节点ID
         # 将输入的症状名称转换为节点ID
@@ -314,56 +319,45 @@ class CDSSHelper(GraphHelper):
                 logger.debug(f"node {node} not found")
                 logger.debug(f"node {node} not found")
         node_ids = node_ids_filtered
         node_ids = node_ids_filtered
 
 
-        # out_edges = self.graph.out_edges(disease, data=True)
-        # for edge in out_edges:
-        #     src, dest, edge_data = edge
-        #     if edge_data["type"] not in department_edge:
-        #         continue
-        #     dest_data = self.entity_data[self.entity_data.index == dest]
-        #     if dest_data.empty:
-        #         continue
-        #     department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
-        #     department_data.extend([department_name] * results[disease]["count"])
-
+        results = self.step1(node_ids,node_id_names, input, allowed_types, symptom_same_edge,allowed_links,max_hops,DIESEASE)
 
 
-        results = self.step1(node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE)
-
-        #self.validDisease(results, start_nodes)
         results = self.validDisease(results, start_nodes)
         results = self.validDisease(results, start_nodes)
 
 
-        sorted_score_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
-
-        # 调用step2方法处理科室、检查和药品信息
-        results = self.step2(results,department_edge)
-
-        # STEP 3: 对于结果按照科室维度进行汇总
-        final_results = self.step3(results)
+        sorted_count_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
+        diags = {}
+        target_symptom_names = []
+        for symptom_id in node_ids:
+            target_symptom_name = self.entity_data[self.entity_data.index == symptom_id]['name'].tolist()[0]
+            target_symptom_names.append(target_symptom_name)
+            same_symptoms = self.get_in_edges(symptom_id, symptom_same_edge)
+            #same_symptoms的name属性全部添加到target_symptom_names中
+            for same_symptom in same_symptoms:
+                target_symptom_names.append(same_symptom['name'].tolist()[0])
+        for item in sorted_count_diags:
+            disease_id = item[0]
+            disease_name = self.entity_data[self.entity_data.index == disease_id]['name'].tolist()[0]
+            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
+            if symptoms_data is None:
+                continue
+            symptoms = []
+            for symptom in symptoms_data:
+                matched = False
+                if symptom in target_symptom_names:
+                    matched = True
+                symptoms.append({"name": symptom, "matched": matched})
+            # symtoms中matched=true的排在前面,matched=false的排在后面
+            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
 
 
-        sorted_final_results = self.step4(final_results)
-        sorted_final_results = sorted_final_results[:5]
-        departments = []
+            start_nodes_size = len(start_nodes)
+            # if start_nodes_size > 1:
+            #     start_nodes_size = start_nodes_size*0.5
+            new_item = {"count": item[1]["count"],
+                        "score": float(item[1]["count"]) / start_nodes_size  * 0.1, "symptoms": symptoms}
+            diags[disease_name] = new_item
+        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
 
 
-        for temp in sorted_final_results:
-            departments.append({"name": temp[0], "count": temp[1]["count"]})
-        
-        # STEP 5: 对于final_results里面的diseases, checks和durgs统计全局出现的次数并且按照次数降序排序
-        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes,symptom_edge)
-
-        # STEP 6: 整合数据并返回
-        # if "department" in item.keys():
-        #     final_results["department"] = list(set(final_results["department"]+item["department"]))
-        # if "diseases" in item.keys():
-        #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
-        # if "checks" in item.keys():
-        #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
-        # if "drugs" in item.keys():
-        #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
-        # if "symptoms" in item.keys():
-        #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
-
-        return {"details": sorted_final_results,
-                "score_diags": sorted_score_diags,"total_diags": total_diags,
-                "departments":departments
+        return {
+                "score_diags": sorted_score_diags
                 # "checks":sorted_checks, "drugs":sorted_drugs,
                 # "checks":sorted_checks, "drugs":sorted_drugs,
                 # "total_checks":total_check, "total_drugs":total_drug
                 # "total_checks":total_check, "total_drugs":total_drug
                 }
                 }
@@ -410,8 +404,8 @@ class CDSSHelper(GraphHelper):
         content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
         content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
         print(f"\n{content}")
         print(f"\n{content}")
         return filtered_results
         return filtered_results
-
-    def step1(self, node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE):
+    
+    def step1(self, node_ids,node_id_names, input, allowed_types, symptom_same_edge,allowed_links,max_hops,DIESEASE):
         """
         """
         根据症状节点查找相关疾病
         根据症状节点查找相关疾病
         :param node_ids: 症状节点ID列表
         :param node_ids: 症状节点ID列表
@@ -441,12 +435,11 @@ class CDSSHelper(GraphHelper):
                             results[disease_id]["path"].append(path)
                             results[disease_id]["path"].append(path)
                         else:
                         else:
                             results[disease_id] = temp_results[disease_id]
                             results[disease_id] = temp_results[disease_id]
+                continue
 
 
 
 
-                continue
 
 
             queue = [(node, 0, node_id_names[node],10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
             queue = [(node, 0, node_id_names[node],10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
-
             # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
             # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
             if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
             if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
                 # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
                 # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
@@ -458,6 +451,8 @@ class CDSSHelper(GraphHelper):
             while queue:
             while queue:
                 temp_node, depth, path, weight, data = queue.pop(0)
                 temp_node, depth, path, weight, data = queue.pop(0)
                 temp_node = int(temp_node)
                 temp_node = int(temp_node)
+                if temp_node == 1479768:
+                    print(1111)
                 # 这里是通过id去获取节点的name和type
                 # 这里是通过id去获取节点的name和type
                 entity_data = self.entity_data[self.entity_data.index == temp_node]
                 entity_data = self.entity_data[self.entity_data.index == temp_node]
                 # 如果节点不存在,那么跳过
                 # 如果节点不存在,那么跳过
@@ -465,6 +460,7 @@ class CDSSHelper(GraphHelper):
                     continue
                     continue
                 if self.graph.nodes.get(temp_node) is None:
                 if self.graph.nodes.get(temp_node) is None:
                     continue
                     continue
+
                 node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
                 node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
                 node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
                 node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
                 # print(f"node {node} type {node_type}")
                 # print(f"node {node} type {node_type}")
@@ -473,11 +469,11 @@ class CDSSHelper(GraphHelper):
                     count = weight
                     count = weight
                     if self.check_diease_allowed(temp_node) == False:
                     if self.check_diease_allowed(temp_node) == False:
                         continue
                         continue
-                    if temp_node in temp_results.keys():
-                        temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
-                        results[disease_id]["increase"] = results[disease_id]["increase"] + 1
-                        temp_results[temp_node]["path"].append(path)
-                    else:
+                    if temp_node not in temp_results.keys():
+                    #     temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
+                    #     temp_results[temp_node]["increase"] = temp_results[temp_node]["increase"] + 1
+                    #     temp_results[temp_node]["path"].append(path)
+                    # else:
                         temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
                         temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
 
 
                     continue
                     continue
@@ -491,11 +487,9 @@ class CDSSHelper(GraphHelper):
                 if temp_node not in self.graph:
                 if temp_node not in self.graph:
                     # print(f"node {node} not found in graph")
                     # print(f"node {node} not found in graph")
                     continue
                     continue
-                # todo 目前是取入边,出边是不是也有用?
                 for edge in self.graph.in_edges(temp_node, data=True):
                 for edge in self.graph.in_edges(temp_node, data=True):
                     src, dest, edge_data = edge
                     src, dest, edge_data = edge
-                    if src not in visited and depth + 1 < max_hops and edge_data['type'] in allowed_links:
-                        # print(f"put into queue travel from {src} to {dest}")
+                    if src not in visited and depth + 1 <= max_hops and edge_data['type'] in allowed_links:                       
                         weight = edge_data['weight']
                         weight = edge_data['weight']
                         try :
                         try :
                             if weight:
                             if weight:
@@ -546,6 +540,14 @@ class CDSSHelper(GraphHelper):
         print(f"STEP 1 遍历图谱查找相关疾病 finished")
         print(f"STEP 1 遍历图谱查找相关疾病 finished")
         return results
         return results
 
 
+    def get_in_edges(self, node_id, allowed_links):
+        results = []
+        for edge in self.graph.in_edges(node_id, data=True):
+            src, dest, edge_data = edge
+            if edge_data['type'] in allowed_links:
+                results.append(self.entity_data[self.entity_data.index == src])
+        return results
+
     def step2(self, results,department_edge):
     def step2(self, results,department_edge):
         """
         """
         查找疾病对应的科室、检查和药品信息
         查找疾病对应的科室、检查和药品信息

+ 800 - 0
agent/cdss/libs/cdss_helper3.py

@@ -0,0 +1,800 @@
+import copy
+from hmac import new
+import os
+import sys
+import logging
+import json
+import time
+
+from sqlalchemy import false
+
+from service.kg_edge_service import KGEdgeService
+from db.session import get_db
+from service.kg_node_service import KGNodeService
+from service.kg_prop_service import KGPropService
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+
+current_path = os.getcwd()
+sys.path.append(current_path)
+
+from community.graph_helper import GraphHelper
+from typing import List
+from agent.cdss.models.schemas import CDSSInput
+from config.site import SiteConfig
+import networkx as nx
+import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+current_path = os.getcwd()
+sys.path.append(current_path)
+
+# 图谱数据缓存路径(由dump_graph_data.py生成)
+CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
+
+
+class CDSSHelper(GraphHelper):
+
+    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
+
+        kg_node_service = KGNodeService(next(get_db()))
+        es_result = kg_node_service.search_title_index("", node_id,node_type, limit)
+        results = []
+        for item in es_result:
+            n = self.graph.nodes.get(item["id"])
+            score = item["score"]
+            if n:
+                results.append({
+                    'id': item["title"],
+                    'score': score,
+                    "name": item["title"],
+                })
+        return results
+
+    def _load_entity_data(self):
+        config = SiteConfig()
+        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
+        print("load entity data")
+        # 这里设置了读取的属性
+        data = {"id": [], "name": [], "type": [],"is_symptom": [], "sex": [], "age": [],"score": []}
+        if not os.path.exists(os.path.join(CACHED_DATA_PATH, 'entities_med.json')):
+            return []
+        with open(os.path.join(CACHED_DATA_PATH, 'entities_med.json'), "r", encoding="utf-8") as f:
+            entities = json.load(f)
+            for item in entities:
+                #如果id已经存在,则跳过
+                # if item[0] in data["id"]:
+                #     print(f"skip {item[0]}")
+                #     continue
+                data["id"].append(int(item[0]))
+                data["name"].append(item[1]["name"])
+                data["type"].append(item[1]["type"])
+                self._append_entity_attribute(data, item, "sex")
+                self._append_entity_attribute(data, item, "age")
+                self._append_entity_attribute(data, item, "is_symptom")
+                self._append_entity_attribute(data, item, "score")
+                # item[1]["id"] = item[0]
+                # item[1]["name"] = item[0]
+                # attrs = item[1]
+                # self.graph.add_node(item[0], **attrs)
+        self.entity_data = pd.DataFrame(data)
+        self.entity_data.set_index("id", inplace=True)
+        print("load entity data finished")
+
+    def _append_entity_attribute(self, data, item, attr_name):
+        if attr_name in item[1]:
+            value = item[1][attr_name].split(":")
+            if len(value) < 2:
+                data[attr_name].append(value[0])
+            else:
+                data[attr_name].append(value[1])
+        else:
+            data[attr_name].append("")
+
+    def _load_relation_data(self):
+        config = SiteConfig()
+        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
+        print("load relationship data")
+
+        for i in range(62):
+            if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
+                continue
+            if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
+                print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
+                with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
+                    data = {"src": [], "dest": [], "type": [], "weight": []}
+                    relations = json.load(f)
+                    for item in relations:
+                        data["src"].append(int(item[0]))
+                        data["dest"].append(int(item[2]))
+                        data["type"].append(item[4]["type"])
+                        if "order" in item[4]:
+                            order = item[4]["order"].split(":")
+                            if len(order) < 2:
+                                data["weight"].append(int(order[0]))
+                            else:
+                                data["weight"].append(int(order[1]))
+                        else:
+                            data["weight"].append(1)
+
+                    self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
+
+    def build_graph(self):
+        self.entity_data = pd.DataFrame(
+            {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
+        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
+        self._load_entity_data()
+        self._load_relation_data()
+        self._load_local_data()
+
+        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
+                                             create_using=nx.DiGraph())
+
+        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
+        # print(self.graph.in_edges('1257357',data=True))
+
+    def _load_local_data(self):
+        # 这里加载update数据和权重数据
+        config = SiteConfig()
+        self.update_data_path = config.get_config('UPDATE_DATA_PATH')
+        self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
+        print(f"load update data from {self.update_data_path}")
+        for root, dirs, files in os.walk(self.update_data_path):
+            for file in files:
+                file_path = os.path.join(root, file)
+                if file_path.endswith(".json") and file.startswith("ent"):
+                    self._load_update_entity_json(file_path)
+                if file_path.endswith(".json") and file.startswith("rel"):
+                    self._load_update_relationship_json(file_path)
+
+    def _load_update_entity_json(self, file):
+        '''load json data from file'''
+        print(f"load entity update data from {file}")
+
+        # 这里加载update数据,update数据是一个json文件,格式同cached data如下:
+
+        with open(file, "r", encoding="utf-8") as f:
+            entities = json.load(f)
+            for item in entities:
+                original_data = self.entity_data[self.entity_data.index == item[0]]
+                if original_data.empty:
+                    continue
+                original_data = original_data.iloc[0]
+                id = int(item[0])
+                name = item[1]["name"] if "name" in item[1] else original_data['name']
+                type = item[1]["type"] if "type" in item[1] else original_data['type']
+                allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
+                    'allowed_sex_list']
+                allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
+                    'allowed_age_range']
+
+                self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
+                                                                                                       allowed_sex_list,
+                                                                                                       allowed_age_range]
+
+    def _load_update_relationship_json(self, file):
+        '''load json data from file'''
+        print(f"load relationship update data from {file}")
+
+        with open(file, "r", encoding="utf-8") as f:
+            relations = json.load(f)
+            for item in relations:
+                data = {}
+                original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
+                                                   (self.relation_data['dest'] == data['dest']) &
+                                                   (self.relation_data['type'] == data['type'])]
+                if original_data.empty:
+                    continue
+                original_data = original_data.iloc[0]
+                data["src"] = int(item[0])
+                data["dest"] = int(item[2])
+                data["type"] = item[4]["type"]
+                data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
+
+                self.relation_data.loc[(self.relation_data['src'] == data['src']) &
+                                       (self.relation_data['dest'] == data['dest']) &
+                                       (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
+
+    def check_sex_allowed(self, node, sex):
+
+        # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
+        sex_allowed = self.graph.nodes[node].get('sex', None)
+        #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
+        #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
+            #'prop_value'] != input.pat_sex.value:
+            #continue
+        if sex_allowed:
+            if len(sex_allowed) == 0:
+                # 如果性别列表为空,那么默认允许所有性别
+                return True
+            sex_allowed_list = sex_allowed.split(',')
+            if sex not in sex_allowed_list:
+                # 如果性别不匹配,跳过
+                return False
+        return True
+
+    def check_age_allowed(self, node, age):
+        # 年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
+        # 如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
+        age_allowed = self.graph.nodes[node].get('age', None)
+        if age_allowed:
+            if len(age_allowed) == 0:
+                # 如果年龄范围为空,那么默认允许所有年龄
+                return True
+            age_allowed_list = age_allowed.split('-')
+            age_min = int(age_allowed_list[0])
+            age_max = int(age_allowed_list[-1])
+            if age_max ==0:
+                return True
+            if age >= age_min and age < age_max:
+                # 如果年龄范围正常,那么返回True
+                return True
+        else:
+            # 如果没有设置年龄范围,那么默认返回True
+            return True
+        return False
+
+    def check_diease_allowed(self, node):
+        is_symptom = self.graph.nodes[node].get('is_symptom', None)
+        if is_symptom == "是":
+            return False
+        return True
+
+    propService = KGPropService(next(get_db()))
+    cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
+    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
+        """
+        基于输入的症状节点,在知识图谱中进行遍历,查找相关疾病、科室、检查和药品
+
+        参数:
+            input: CDSSInput对象,包含患者的基本信息(年龄、性别等)
+            start_nodes: 症状节点名称列表,作为遍历的起点
+            max_hops: 最大遍历深度,默认为3
+
+        返回值:
+            返回一个包含以下信息的字典:
+                - details: 按科室汇总的结果
+                - diags: 按相关性排序的疾病列表
+                - checks: 按出现频率排序的检查列表
+                - drugs: 按出现频率排序的药品列表
+                - total_diags: 疾病总数
+                - total_checks: 检查总数
+                - total_drugs: 药品总数
+
+        主要步骤:
+            1. 初始化允许的节点类型和关系类型
+            2. 将症状名称转换为节点ID
+            3. 遍历图谱查找相关疾病(STEP 1)
+            4. 查找疾病对应的科室、检查和药品(STEP 2)
+            5. 按科室汇总结果(STEP 3)
+            6. 对结果进行排序和统计(STEP 4-6)
+        """
+        # 定义允许的节点类型,包括科室、疾病、药品、检查和症状
+        # 这些类型用于后续的节点过滤和路径查找
+        DEPARTMENT = ['科室', 'Department']
+        DIESEASE = ['疾病', 'Disease']
+        DRUG = ['药品', 'Drug']
+        CHECK = ['检查', 'Check']
+        SYMPTOM = ['症状', 'Symptom']
+        #allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
+        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
+        # 定义允许的关系类型,包括has_symptom、need_check、recommend_drug、belongs_to
+        # 这些关系类型用于后续的路径查找和过滤
+
+        symptom_edge = ['has_symptom', '疾病相关症状']
+        symptom_same_edge = ['症状同义词', '症状同义词2.0']
+        department_edge = ['belongs_to','所属科室']
+        allowed_links = symptom_edge+symptom_same_edge
+        # allowed_links = symptom_edge + department_edge
+
+        # 将输入的症状名称转换为节点ID
+        # 由于可能存在同名节点,转换后的节点ID数量可能大于输入的症状数量
+        node_ids = []
+        node_id_names = {}
+
+        # start_nodes里面重复的症状,去重同样的症状
+        start_nodes = list(set(start_nodes))
+        for node in start_nodes:
+            #print(f"searching for node {node}")
+            result = self.entity_data[self.entity_data['name'] == node]
+            #print(f"searching for node {result}")
+            for index, data in result.iterrows():
+                if data["type"] in SYMPTOM or data["type"] in DIESEASE:
+                    node_id_names[index] = data["name"]
+                    node_ids = node_ids + [index]
+        #print(f"start travel from {node_id_names}")
+
+        # 这里是一个队列,用于存储待遍历的症状:
+        node_ids_filtered = []
+        for node in node_ids:
+            if self.graph.has_node(node):
+                node_ids_filtered.append(node)
+            else:
+                logger.debug(f"node {node} not found")
+        node_ids = node_ids_filtered
+
+        # out_edges = self.graph.out_edges(disease, data=True)
+        # for edge in out_edges:
+        #     src, dest, edge_data = edge
+        #     if edge_data["type"] not in department_edge:
+        #         continue
+        #     dest_data = self.entity_data[self.entity_data.index == dest]
+        #     if dest_data.empty:
+        #         continue
+        #     department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+        #     department_data.extend([department_name] * results[disease]["count"])
+
+
+        results = self.step1(node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE)
+
+        #self.validDisease(results, start_nodes)
+        results = self.validDisease(results, start_nodes)
+
+        sorted_score_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
+
+        # 调用step2方法处理科室、检查和药品信息
+        results = self.step2(results,department_edge)
+
+        # STEP 3: 对于结果按照科室维度进行汇总
+        final_results = self.step3(results)
+
+        sorted_final_results = self.step4(final_results)
+        sorted_final_results = sorted_final_results[:5]
+        departments = []
+
+        for temp in sorted_final_results:
+            departments.append({"name": temp[0], "count": temp[1]["count"]})
+        
+        # STEP 5: 对于final_results里面的diseases, checks和durgs统计全局出现的次数并且按照次数降序排序
+        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes,symptom_edge)
+
+        # STEP 6: 整合数据并返回
+        # if "department" in item.keys():
+        #     final_results["department"] = list(set(final_results["department"]+item["department"]))
+        # if "diseases" in item.keys():
+        #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
+        # if "checks" in item.keys():
+        #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
+        # if "drugs" in item.keys():
+        #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
+        # if "symptoms" in item.keys():
+        #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
+
+        return {"details": sorted_final_results,
+                "score_diags": sorted_score_diags,"total_diags": total_diags,
+                "departments":departments
+                # "checks":sorted_checks, "drugs":sorted_drugs,
+                # "total_checks":total_check, "total_drugs":total_drug
+                }
+
+    def validDisease(self, results, start_nodes):
+        """
+        输出有效的疾病信息为Markdown格式
+        :param results: 疾病结果字典
+        :param start_nodes: 起始症状节点列表
+        :return: 格式化后的Markdown字符串
+        """
+        log_data = ["|疾病|症状|出现次数|是否相关"]
+        log_data.append("|--|--|--|--|")
+        filtered_results = {}
+        for item in results:
+            data = results[item]
+            data['relevant'] = False
+            if data["increase"] / len(start_nodes) > 0.5:
+                #cache_key = f'disease_name_ref_id_{data['name']}'
+                data['relevant'] = True
+                filtered_results[item] = data
+                # 初始化疾病的父类疾病
+                # disease_name = data["name"]
+                # key = 'disease_name_parent_' +disease_name
+                # cached_value = self.cache.get(key)
+                # if cached_value is None:
+                #     out_edges = self.graph.out_edges(item, data=True)
+                #
+                #     for edge in out_edges:
+                #         src, dest, edge_data = edge
+                #         if edge_data["type"] != '疾病相关父类':
+                #             continue
+                #         dest_data = self.entity_data[self.entity_data.index == dest]
+                #         if dest_data.empty:
+                #             continue
+                #         dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
+                #         self.cache.set(key, dest_name)
+                #         break
+            if data['relevant'] == False:
+                continue
+
+            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
+
+        content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
+        print(f"\n{content}")
+        return filtered_results
+
def step1(self, node_ids, node_id_names, input, allowed_types, allowed_links, max_hops, DIESEASE):
    """
    Find candidate diseases reachable from the patient's symptom nodes.

    For each symptom node a BFS over *incoming* edges is performed (bounded
    by ``max_hops``); every disease-typed node reached is scored by the edge
    weights along the way. Per-symptom traversal results are cached under
    ``symptom_ref_disease_<id>``. Finally, diseases that conflict with the
    patient's sex or age are dropped.

    :param node_ids: symptom node ids to start from
    :param node_id_names: mapping node id -> display name
    :param input: patient info (pat_age, pat_sex, ...) — pat_age is
                  normalized to months in place
    :param allowed_types: node types the traversal may visit
    :param allowed_links: edge types the traversal may follow
    :param max_hops: maximum BFS depth
    :param DIESEASE: node types considered to be diseases
    :return: dict disease_id -> {"type", "count", "increase", "name", "path"}
    """
    start_time = time.time()

    def merge(acc, part):
        # Merge one symptom's disease hits into the accumulated results.
        # "increase" counts how many distinct start symptoms reached the
        # disease, so it grows by exactly 1 per merged symptom.
        for disease_id, info in part.items():
            path = info["path"][0]
            if disease_id in acc:
                acc[disease_id]["count"] += info["count"]
                acc[disease_id]["increase"] += 1
                acc[disease_id]["path"].append(path)
            else:
                acc[disease_id] = info

    # Normalize the patient age to months once (graph ages are in months).
    # Hoisted out of the per-node loop: it is idempotent but belongs here.
    if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
        input.pat_age.value = input.pat_age.value * 12
        input.pat_age.type = 'month'

    results = {}
    for node in node_ids:
        cache_key = f"symptom_ref_disease_{str(node)}"
        cache_data = self.cache[cache_key] if cache_key in self.cache else None
        if cache_data:
            temp_results = copy.deepcopy(cache_data)
            print(cache_key + ":" + node_id_names[node] + ':' + str(len(temp_results)))
            merge(results, temp_results)
            continue

        visited = set()
        temp_results = {}
        queue = [(node, 0, node_id_names[node], 10,
                  {'allowed_types': allowed_types, 'allowed_links': allowed_links})]

        while queue:
            temp_node, depth, path, weight, data = queue.pop(0)
            temp_node = int(temp_node)
            entity_data = self.entity_data[self.entity_data.index == temp_node]
            # Skip ids that exist in neither the entity table nor the graph.
            if entity_data.empty:
                continue
            if self.graph.nodes.get(temp_node) is None:
                continue
            node_type = entity_data['type'].tolist()[0]
            node_name = entity_data['name'].tolist()[0]

            if node_type in DIESEASE:
                if self.check_diease_allowed(temp_node) == False:
                    continue
                if temp_node in temp_results:
                    # BUG FIX: the original incremented
                    # results[disease_id]["increase"] here — disease_id is
                    # stale (or undefined on the first cache miss) and the
                    # per-symptom "increase" must stay 1; it is incremented
                    # once per symptom in merge() instead.
                    temp_results[temp_node]["count"] += weight
                    temp_results[temp_node]["path"].append(path)
                else:
                    temp_results[temp_node] = {"type": node_type, "count": weight,
                                               "increase": 1, "name": node_name, 'path': [path]}
                continue

            if temp_node in visited or depth > max_hops:
                continue
            visited.add(temp_node)
            if temp_node not in self.graph:
                continue

            # NOTE(review): only in-edges are followed here — the original
            # carried a TODO asking whether out-edges are useful too.
            for src, dest, edge_data in self.graph.in_edges(temp_node, data=True):
                if src in visited or depth + 1 >= max_hops or edge_data['type'] not in allowed_links:
                    continue
                edge_weight = edge_data['weight']
                try:
                    # Invert small weights so "closer" edges score higher;
                    # missing weights get a middle score of 5, cap at 10.
                    if edge_weight:
                        edge_weight = 10 - edge_weight if edge_weight < 10 else 1
                    else:
                        edge_weight = 5
                    if edge_weight > 10:
                        edge_weight = 10
                except Exception as e:
                    print(f'Error processing file {edge_weight}: {str(e)}')
                queue.append((src, depth + 1, path, int(edge_weight), data))

        print(cache_key + ":" + node_id_names[node] + ':' + str(len(temp_results)))
        # Deep-copy before caching so later merges cannot mutate the cache.
        self.cache[cache_key] = copy.deepcopy(temp_results)
        merge(results, temp_results)

    end_time = time.time()

    # Drop diseases that conflict with the patient's sex or age.
    new_results = {}
    for item in results:
        if input.pat_sex and input.pat_sex.value is not None and self.check_sex_allowed(item, input.pat_sex.value) == False:
            continue
        if input.pat_age and input.pat_age.value is not None and self.check_age_allowed(item, input.pat_age.value) == False:
            continue
        new_results[item] = results[item]
    results = new_results
    print('STEP 1 ' + str(len(results)))
    print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
    print(f"STEP 1 遍历图谱查找相关疾病 finished")
    return results
+
def step2(self, results, department_edge):
    """
    Attach the departments linked to each relevant disease.

    For every disease marked relevant by validDisease, follow its outgoing
    department edges and record the department name repeated "count" times,
    so the later per-department aggregation is weighted by hit count.

    :param results: dict disease_id -> disease info (needs "relevant"/"count")
    :param department_edge: edge types representing disease -> department links
    :return: the same results dict, with a "department" list added where found
    """
    start_time = time.time()
    print("STEP 2 查找疾病对应的科室、检查和药品 start")

    for disease in results.keys():
        if results[disease]["relevant"] == False:
            continue

        department_data = []
        for src, dest, edge_data in self.graph.out_edges(disease, data=True):
            if edge_data["type"] not in department_edge:
                continue
            dest_data = self.entity_data[self.entity_data.index == dest]
            if dest_data.empty:
                continue
            department_name = dest_data['name'].tolist()[0]
            # Weight the department by how often the disease was hit.
            department_data.extend([department_name] * results[disease]["count"])

        if department_data:
            results[disease]["department"] = department_data

    print(f"STEP 2 finished")
    end_time = time.time()
    print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")

    # Markdown log of what was collected per disease.
    log_data = ["|disease|count|department|check|drug|", "|--|--|--|--|--|"]
    for item in results.keys():
        department_data = results[item].get("department", [])
        count_data = results[item].get("count")
        check_data = results[item].get("check", [])
        drug_data = results[item].get("drug", [])
        # BUG FIX: the original reused double quotes inside a double-quoted
        # f-string (get("name", item)) — a SyntaxError before Python 3.12.
        log_data.append(
            f"|{results[item].get('name', item)}|{count_data}|{','.join(department_data)}|"
            f"{','.join(check_data)}|{','.join(drug_data)}|")

    print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
    return results
+
def step3(self, results):
    """
    Group the per-disease results by hospital department.

    Diseases without a "department" list from step2 are looked up in the DB
    via KGEdgeService ('所属科室' edges); if nothing is found there either,
    they fall into the synthetic 'DEFAULT' department. Each department
    bucket gets score = its share of all (disease, department) pairs.

    :param results: dict disease_id -> disease info
    :return: dict department -> {"diseases", "checks", "drugs", "count", "score"}
    """
    print(f"STEP 3 对于结果按照科室维度进行汇总 start")
    final_results = {}
    total = 0
    for disease in results.keys():
        disease = int(disease)
        info = results[disease]
        # Fall back to a DB lookup when step2 attached no department;
        # the edge query may return nothing, hence the DEFAULT fallback.
        departments = ['DEFAULT']
        if 'department' in info:
            departments = info["department"]
        else:
            edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='所属科室')
            if len(edges) > 0:
                departments = [edge['dest_node']['name'] for edge in edges]

        label = str(disease) + ":" + info.get("name", disease)
        for department in departments:
            total += 1
            bucket = final_results.get(department)
            if bucket is None:
                final_results[department] = {
                    "diseases": [label],
                    "checks": info.get("check", []),
                    "drugs": info.get("drug", []),
                    "count": 1
                }
            else:
                bucket["diseases"] = bucket["diseases"] + [label]
                bucket["checks"] = bucket["checks"] + info.get("check", [])
                bucket["drugs"] = bucket["drugs"] + info.get("drug", [])
                bucket["count"] += 1

    # Score each department by its share of all (disease, department) pairs.
    for department in final_results.keys():
        final_results[department]["score"] = final_results[department]["count"] / total

    print(f"STEP 3 finished")
    log_data = ["|department|disease|check|drug|count|score", "|--|--|--|--|--|--|"]
    for department, summary in final_results.items():
        log_data.append(
            f"|{department}|{','.join(summary.get('diseases', []))}|{','.join(summary.get('checks', []))}|"
            f"{','.join(summary.get('drugs', []))}|{summary.get('count', 0)}|{summary.get('score', 0)}|")
    print("\n" + "\n".join(log_data))
    return final_results
+
def step4(self, final_results):
    """
    Rank diseases within each department and order departments by count.

    Each department's disease list is collapsed into the top-10 most
    frequent entries as (label, {"count": n}) tuples; departments are then
    sorted by their total hit count, descending.

    :param final_results: dict department -> {"diseases", "checks", "drugs",
                          "count", "score"} (mutated in place)
    :return: list of (department, summary) tuples sorted by count desc
    """
    print(f"STEP 4 start")
    start_time = time.time()

    def rank_top(items, limit=10):
        # Count duplicates, keep the top `limit` by frequency
        # (stable: ties keep first-seen order).
        freq = {}
        for entry in items:
            freq[entry] = freq.get(entry, 0) + 1
        ranked = sorted(freq.items(), key=lambda kv: kv[1], reverse=True)
        return [(entry, {"count": n}) for entry, n in ranked[:limit]]

    for department, summary in final_results.items():
        summary['name'] = department
        summary["diseases"] = rank_top(summary["diseases"])

    # Order departments by how many diseases landed in them.
    sorted_final_results = sorted(final_results.items(), key=lambda kv: kv[1]["count"], reverse=True)

    print(f"STEP 4 finished")
    end_time = time.time()
    print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")

    # Markdown log of the ranked buckets.
    rows = ["|department|disease|check|drug|count|score", "|--|--|--|--|--|--|"]
    for department, summary in final_results.items():
        rows.append(
            f"|{department}|{summary.get('diseases')}|{summary.get('checks')}|{summary.get('drugs')}|"
            f"{summary.get('count', 0)}|{summary.get('score', 0)}|")
    print("\n" + "\n".join(rows))
    return sorted_final_results
+
def step5(self, final_results, input, start_nodes, symptom_edge):
    """
    Merge the per-department rankings into a final diagnosis list.

    Every (disease, count) pair from step4 is accumulated across
    departments, weighted by the department score (DEFAULT departments get
    a fixed 0.1 factor) and boosted by 1 when the department matches the
    patient's current department. The top 10 by count are then enriched
    with the disease's symptom list (matched symptoms first) and re-scored
    relative to the number of reported symptoms.

    :param final_results: dict department -> summary produced by step4
    :param input: patient input (input.department may be None)
    :param start_nodes: symptom names the patient reported
    :param symptom_edge: edge types representing disease -> symptom links
    :return: (list of (disease_name, info) sorted by score desc, total candidates)
    """
    print(f"STEP 5 start")
    start_time = time.time()
    diags = {}
    total_diags = 0

    for department in final_results.keys():
        # DEFAULT buckets only carry a small fixed weight.
        department_factor = 0.1 if department == 'DEFAULT' else final_results[department]["score"]
        # Boost diseases coming from the department the patient is visiting.
        boost = 0
        if input.department and input.department.value == department:
            boost = 1
        for disease, data in final_results[department]["diseases"]:
            total_diags += 1
            weighted = data["count"] + boost
            if disease in diags:
                diags[disease]["count"] += weighted
                diags[disease]["score"] += weighted * 0.1 * department_factor
            else:
                diags[disease] = {"count": weighted, "score": weighted * 0.1 * department_factor}

    sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]

    # Guard divide-by-zero when the patient reported no symptoms.
    start_nodes_size = len(start_nodes) or 1
    diags = {}
    for key, stats in sorted_score_diags:
        # key is "<id>:<name>" (built in step3). BUG FIX: split only on the
        # first ':' so disease names containing a colon are not truncated.
        disease_id, _, disease = key.partition(":")
        symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
        if symptoms_data is None:
            continue
        symptoms = [{"name": s, "matched": s in start_nodes} for s in symptoms_data]
        # Matched symptoms go first.
        symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
        diags[disease] = {
            "old_score": stats["score"],
            "count": stats["count"],
            "score": float(stats["count"]) / start_nodes_size / 2 * 0.1,
            "symptoms": symptoms,
        }
    sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)

    print(f"STEP 5 finished")
    end_time = time.time()
    print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")

    log_data = ["|department|disease|count|score", "|--|--|--|--|"]
    for department in final_results.keys():
        log_data.append(
            f"|{department}|{final_results[department].get('diseases')}|"
            f"{final_results[department].get('count', 0)}|{final_results[department].get('score', 0)}|")
    print("这里是经过排序的数据\n" + "\n".join(log_data))
    return sorted_score_diags, total_diags
+    
def get_symptoms_data(self, disease_id, symptom_edge):
    """
    Return the (cached) list of symptom names linked to a disease.

    :param disease_id: disease node id (str or int)
    :param symptom_edge: edge types that count as disease -> symptom links
    :return: list of unique symptom names (possibly empty)
    """
    cache_key = f'disease_{disease_id}_symptom'
    cached = self.cache[cache_key] if cache_key in self.cache else None
    if cached is not None:
        return cached

    names = []
    for _, dest, edge_data in self.graph.out_edges(int(disease_id), data=True):
        if edge_data["type"] not in symptom_edge:
            continue
        dest_rows = self.entity_data[self.entity_data.index == dest]
        if dest_rows.empty:
            continue
        name = dest_rows['name'].tolist()[0]
        if name not in names:
            names.append(name)
    self.cache[cache_key] = names
    return names

+ 1 - 1
main.py

@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 logger.propagate = True
 logger.propagate = True
 
 
 # 创建FastAPI应用
 # 创建FastAPI应用
-app = FastAPI(title="知识图谱")
+app = FastAPI(title="医疗知识",root_path="/knowledge")
 app.include_router(dify_kb_router)
 app.include_router(dify_kb_router)
 app.include_router(saas_kb_router)
 app.include_router(saas_kb_router)
 app.include_router(text_search_router)
 app.include_router(text_search_router)

+ 35 - 23
router/graph_router.py

@@ -1,8 +1,11 @@
+import logging
 import sys,os
 import sys,os
 
 
 from agent.cdss.capbility import CDSSCapability
 from agent.cdss.capbility import CDSSCapability
 from agent.cdss.models.schemas import CDSSInput, CDSSInt, CDSSText
 from agent.cdss.models.schemas import CDSSInput, CDSSInt, CDSSText
 from model.response import StandardResponse
 from model.response import StandardResponse
+from service.cdss_service import CdssService
+from service.kg_node_service import KGNodeService
 
 
 current_path = os.getcwd()
 current_path = os.getcwd()
 sys.path.append(current_path)
 sys.path.append(current_path)
@@ -17,9 +20,9 @@ import json
 
 
 
 
 
 
-router = APIRouter(prefix="/graph", tags=["Knowledge Graph"])
-
-@router.get("/nodes/recommend", response_model=StandardResponse)
+router = APIRouter(prefix="/disease", tags=["Knowledge Graph"])
+logger = logging.getLogger(__name__)
+@router.get("/recommend", response_model=StandardResponse)
 async def recommend(
 async def recommend(
     chief: str,
     chief: str,
     present_illness: Optional[str] = None,
     present_illness: Optional[str] = None,
@@ -31,10 +34,10 @@ async def recommend(
     app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
     app_id = "256fd853-60b0-4357-b11b-8114b4e90ae0"
     conversation_id = get_conversation_id(app_id)
     conversation_id = get_conversation_id(app_id)
 
 
-    # desc = "主诉:"+chief
-    # if present_illness:
-    #     desc+="\n现病史:" + present_illness
-    result = call_chat_api(app_id, conversation_id, chief)
+    desc = "主诉:"+chief
+    if present_illness:
+        desc+="\n现病史:" + present_illness
+    result = call_chat_api(app_id, conversation_id, desc)
     json_data = json.loads(result)
     json_data = json.loads(result)
     keyword = " ".join(json_data["symptoms"])
     keyword = " ".join(json_data["symptoms"])
     result = await neighbor_search(keyword=keyword,sex=sex,age=age, neighbor_type='Check', limit=10)
     result = await neighbor_search(keyword=keyword,sex=sex,age=age, neighbor_type='Check', limit=10)
@@ -42,24 +45,18 @@ async def recommend(
     print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
     print(f"recommend执行完成,耗时:{end_time - start_time:.2f}秒")
     return result;
     return result;
 
 
-
-@router.get("/nodes/neighbor_search", response_model=StandardResponse)
+@router.get("/neighbor_search", response_model=StandardResponse)
 async def neighbor_search(
 async def neighbor_search(
     keyword: str = Query(..., min_length=2),
     keyword: str = Query(..., min_length=2),
     sex: Optional[str] = None,
     sex: Optional[str] = None,
     age: Optional[int] = None,
     age: Optional[int] = None,
-    department: Optional[str] = None,
-    limit: int = Query(10, ge=1, le=100),
-    node_type: Optional[str] = Query(None),
-    neighbor_type: Optional[str] = Query(None),
-    min_degree: Optional[int] = Query(None)
+    department: Optional[str] = None
 ):
 ):
     """
     """
     根据关键词和属性过滤条件搜索图谱节点
     根据关键词和属性过滤条件搜索图谱节点
     """
     """
     try:
     try:
 
 
-        print(f"开始执行neighbor_search,参数:keyword={keyword}, limit={limit}, node_type={node_type}, neighbor_type={neighbor_type}, min_degree={min_degree}")
         keywords = keyword.split(" ")
         keywords = keyword.split(" ")
 
 
         record = CDSSInput(
         record = CDSSInput(
@@ -71,13 +68,14 @@ async def neighbor_search(
         # 使用从main.py导入的capability实例处理CDSS逻辑
         # 使用从main.py导入的capability实例处理CDSS逻辑
         output = capability.process(input=record)
         output = capability.process(input=record)
 
 
-        output.diagnosis.value = [{"name":key,"old_score":value["old_score"],"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
-            "hasInfo": 1,
-            "type": 1} for key,value in output.diagnosis.value.items()]
+        output.diagnosis.value = [{"name":key,"count":value["count"],"score":value["score"],"symptoms":value["symptoms"],
+            #"hasInfo": 1,
+            #"type": 1
+            } for key,value in output.diagnosis.value.items()]
 
 
         return StandardResponse(
         return StandardResponse(
             success=True,
             success=True,
-            data={"可能诊断":output.diagnosis.value,"症状":keywords,"就诊科室":output.departments.value}
+            data={"可能诊断":output.diagnosis.value,"症状":keywords}
         )
         )
     except Exception as e:
     except Exception as e:
         print(e)
         print(e)
@@ -87,8 +85,22 @@ async def neighbor_search(
             error_code=500,
             error_code=500,
             error_msg=str(e)
             error_msg=str(e)
         )
         )
-capability = CDSSCapability()
-#def get_capability():
-    #from main import capability
-    #return capability
+
+@router.get("/{disease_name}/detail", response_model=StandardResponse)
+async def get_disease_detail(
+    disease_name: str
+):
+    try:
+        service = CdssService()
+        result = service.get_disease_detail(disease_name,'疾病')
+        return StandardResponse(success=True, data=result)
+    except Exception as e:
+        logger.error(f"get_disease_detail failed: {str(e)}")
+        return StandardResponse(
+            success=False,
+            error_code=500,
+            error_msg=str(e)
+        )
+
+#capability = CDSSCapability()
 graph_router = router
 graph_router = router

+ 0 - 13
router/knowledge_saas.py

@@ -146,19 +146,6 @@ async def update_node(
         logger.error(f"更新节点失败: {str(e)}")
         logger.error(f"更新节点失败: {str(e)}")
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
         raise HTTPException(500, detail=StandardResponse.error(str(e)))
 
 
-@router.delete("/nodes/{node_id}", response_model=StandardResponse)
-async def delete_node(
-    node_id: int,
-    db: Session = Depends(get_db)
-):
-    try:
-        service = TrunksService()
-        service.delete_node(node_id)
-        return StandardResponse(success=True)
-    except Exception as e:
-        logger.error(f"删除节点失败: {str(e)}")
-        raise HTTPException(500, detail=StandardResponse.error(str(e)))
-
 @router.post('/trunks/vector_search', response_model=StandardResponse)
 @router.post('/trunks/vector_search', response_model=StandardResponse)
 async def vector_search(
 async def vector_search(
     payload: VectorSearchRequest,
     payload: VectorSearchRequest,

+ 71 - 0
service/cdss_service.py

@@ -0,0 +1,71 @@
+from sqlalchemy.orm import Session
+from typing import Optional
+from model.kg_node import KGNode
+from db.session import get_db
+import logging
+from sqlalchemy.exc import IntegrityError
+
+from service.kg_node_service import KGNodeService
+from tests.service.test_kg_node_service import kg_node_service
+from utils import vectorizer
+from utils.vectorizer import Vectorizer
+from sqlalchemy import func
+from service.kg_prop_service import KGPropService
+from service.kg_edge_service import KGEdgeService
+from cachetools import TTLCache
+from cachetools.keys import hashkey
+logger = logging.getLogger(__name__)
+DISTANCE_THRESHOLD = 0.65
class CdssService:
    # 30-day TTL cache shared by all instances.
    # NOTE(review): currently unused in this class — kept for parity with
    # the other services; confirm before removing.
    _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)

    def get_disease_detail(self, diease_name: str, category: str):
        """
        Look up a disease by name and collect its related nodes.

        Gathers lab examinations ('疾病相关实验室检查项目'), auxiliary /
        physical examinations ('疾病相关辅助检查项目') and differential
        diagnoses ('疾病相关鉴别诊断'), each enriched with its KG properties.

        :param diease_name: disease name to search for
        :param category: node category to restrict the search (e.g. '疾病')
        :return: dict with the disease node and the related node lists,
                 or None when no disease matches
        """
        nodeService = KGNodeService(next(get_db()))
        diseases = nodeService.search_title_index("", diease_name, category, 1, 0)
        # BUG FIX: the search may return an empty list, not only None;
        # guard both to avoid IndexError on diseases[0].
        if not diseases:
            return None
        disease_id = diseases[0]["id"]
        edgeService = KGEdgeService(next(get_db()))

        examinations = edgeService.get_edges_by_nodes(src_id=disease_id, category='疾病相关实验室检查项目')
        physical_examinations = edgeService.get_edges_by_nodes(src_id=disease_id, category='疾病相关辅助检查项目')
        differential_diagnosis = edgeService.get_edges_by_nodes(src_id=disease_id, category='疾病相关鉴别诊断')

        e_nodes = []
        pe_nodes = []
        ddx_nodes = []

        if examinations:
            e_nodes = [exam['dest_node'] for exam in examinations]
            self.add_props_to_nodes(e_nodes)

        if physical_examinations:
            pe_nodes = [pe['dest_node'] for pe in physical_examinations]
            self.add_props_to_nodes(pe_nodes)

        if differential_diagnosis:
            ddx_nodes = [ddx['dest_node'] for ddx in differential_diagnosis]
            # Only the 'Differential Diagnosis' property is relevant here.
            self.add_props_to_nodes(ddx_nodes, ['Differential Diagnosis'])

        return {
            'disease': diseases[0],
            'examinations': e_nodes,
            'physical_examinations': pe_nodes,
            'differential_diagnosis': ddx_nodes
        }

    def add_props_to_nodes(self, dest_nodes, prop_names=None):
        """Attach each node's KG properties (optionally filtered by name) in place."""
        kgPropService = KGPropService(next(get_db()))
        for node in dest_nodes:
            props = kgPropService.get_props_by_ref_id(node['id'], prop_names)
            if props:
                node['props'] = props
+ 1 - 1
service/kg_node_service.py

@@ -21,7 +21,7 @@ class KGNodeService:
     _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
     _cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
 
 
     def search_title_index(self, index: str, keywrod: str,category: str, top_k: int = 3,distance: float = 0.3) -> Optional[int]:
     def search_title_index(self, index: str, keywrod: str,category: str, top_k: int = 3,distance: float = 0.3) -> Optional[int]:
-        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{float}"
+        cache_key = f"search_title_index_{index}:{keywrod}:{category}:{top_k}:{distance}"
         if cache_key in self._cache:
         if cache_key in self._cache:
             return self._cache[cache_key]
             return self._cache[cache_key]
 
 

+ 4 - 3
service/kg_prop_service.py

@@ -11,11 +11,12 @@ class KGPropService:
     def __init__(self, db: Session):
     def __init__(self, db: Session):
         self.db = db
         self.db = db
 
 
-    def get_props_by_ref_id(self, ref_id: int, prop_name: str = None) -> List[dict]:
+    def get_props_by_ref_id(self, ref_id: int, prop_names: List[str] = None) -> List[dict]:
         try:
         try:
             query = self.db.query(KGProp).filter(KGProp.ref_id == ref_id)
             query = self.db.query(KGProp).filter(KGProp.ref_id == ref_id)
-            if prop_name:
-                query = query.filter(KGProp.prop_name == prop_name)
+            if prop_names:
+                #prop_names是一个列表,需要使用in_方法进行查询
+                query = query.filter(KGProp.prop_name.in_(prop_names))
             props = query.all()
             props = query.all()
             return [{
             return [{
                 'id': p.id,
                 'id': p.id,