import copy
from hmac import new  # NOTE(review): appears unused in this module
import os
import sys
import logging
import json
import time
from collections import Counter, deque

from sqlalchemy import false  # NOTE(review): appears unused in this module
from service.kg_edge_service import KGEdgeService
from db.session import get_db
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from cachetools import TTLCache
from cachetools.keys import hashkey  # NOTE(review): appears unused in this module

current_path = os.getcwd()
sys.path.append(current_path)

from community.graph_helper import GraphHelper
from typing import List
from agent.cdss.models.schemas import CDSSInput
from config.site import SiteConfig
import networkx as nx
import pandas as pd

logger = logging.getLogger(__name__)

current_path = os.getcwd()
sys.path.append(current_path)

# Directory of the cached graph dumps (generated by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')


class CDSSHelper(GraphHelper):
    """Clinical-decision-support helper on top of the knowledge graph.

    build_graph() loads entity/relationship dumps from CACHED_DATA_PATH (plus any
    update files under UPDATE_DATA_PATH) into pandas frames and a networkx DiGraph.
    cdss_travel() then walks the graph from symptom nodes to rank candidate diseases.
    """

    # Disease ids that are never offered as a diagnosis (used by check_diease_allowed).
    # NOTE(review): 132923 is the only 6-digit id here - confirm it is not a typo.
    _EXCLUDED_DISEASE_IDS = frozenset({
        1489628, 1685808, 1272212, 1303510, 1303336, 1299948, 1916891, 1993268,
        1613483, 1783840, 1353546, 1491526, 132923, 1287322, 1921248, 1303271,
        1811668,
    })

    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search the title index and keep hits whose id is a node in self.graph.

        `filters` and `min_degree` are accepted for interface compatibility but unused.

        Returns a list of dicts with keys 'id', 'score' and 'name'; both 'id' and
        'name' are filled from the hit's "title".
        """
        kg_node_service = KGNodeService(next(get_db()))
        es_result = kg_node_service.search_title_index("", node_id, node_type, limit)
        results = []
        for item in es_result:
            # Truthiness check (not `is None`): nodes with an empty attribute dict are skipped too.
            if self.graph.nodes.get(item["id"]):
                results.append({
                    'id': item["title"],
                    'score': item["score"],
                    "name": item["title"],
                })
        return results

    def _first_entity_row(self, entity_id):
        """Return the first row of self.entity_data whose index equals entity_id, or None."""
        rows = self.entity_data[self.entity_data.index == entity_id]
        if rows.empty:
            return None
        return rows.iloc[0]

    def _load_entity_data(self):
        """Load entities_med.json into self.entity_data, indexed by integer id.

        Each JSON item is read as item[0] = id and item[1] = attribute dict.
        Returns [] (leaving self.entity_data untouched) when the file is missing.
        """
        print("load entity data")
        entity_file = os.path.join(CACHED_DATA_PATH, 'entities_med.json')
        # Columns read for every entity; missing optional attributes become "".
        data = {"id": [], "name": [], "type": [], "is_symptom": [], "sex": [], "age": [], "score": [], "TCM": []}
        if not os.path.exists(entity_file):
            return []
        with open(entity_file, "r", encoding="utf-8") as f:
            entities = json.load(f)
        for item in entities:
            attrs = item[1]
            data["id"].append(int(item[0]))
            data["name"].append(attrs["name"])
            data["type"].append(attrs["type"])
            for attr_name in ("sex", "age", "is_symptom", "score", "TCM"):
                self._append_entity_attribute(data, item, attr_name)
        self.entity_data = pd.DataFrame(data)
        self.entity_data.set_index("id", inplace=True)
        print("load entity data finished")

    def get_entity_data(self):
        """Return the entity DataFrame built by build_graph()."""
        return self.entity_data

    def _append_entity_attribute(self, data, item, attr_name):
        """Append item[1][attr_name] to data[attr_name], or "" when the attribute is absent.

        Values are split on ':'; when there are at least two fields the second one
        is kept, otherwise the whole value is kept.
        """
        attrs = item[1]
        if attr_name not in attrs:
            data[attr_name].append("")
            return
        # str() so non-string values (e.g. numbers) do not crash on .split().
        parts = str(attrs[attr_name]).split(":")
        data[attr_name].append(parts[0] if len(parts) < 2 else parts[1])

    @staticmethod
    def _parse_order_weight(attrs):
        """Return the edge weight from an "order" attribute ("n" or "prefix:n"), default 1."""
        if "order" not in attrs:
            return 1
        parts = attrs["order"].split(":")
        return int(parts[0]) if len(parts) < 2 else int(parts[1])

    def _load_relation_data(self):
        """Append relationship_med_{i}.json files (i = 0, 1, ... up to 499) to self.relation_data.

        Loading stops at the first missing index. Each item is read as
        item[0] = src id, item[2] = dest id, item[4] = attribute dict.
        """
        print("load relationship data")
        frames = []
        for i in range(500):
            relation_file = os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")
            if not os.path.exists(relation_file):
                break
            print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
            with open(relation_file, "r", encoding="utf-8") as f:
                relations = json.load(f)
            data = {"src": [], "dest": [], "type": [], "weight": []}
            for item in relations:
                attrs = item[4]
                data["src"].append(int(item[0]))
                data["dest"].append(int(item[2]))
                data["type"].append(attrs["type"])
                data["weight"].append(self._parse_order_weight(attrs))
            frames.append(pd.DataFrame(data))
        if frames:
            # One concat instead of one per file (avoids quadratic copying).
            self.relation_data = pd.concat([self.relation_data] + frames, ignore_index=True)

    def build_graph(self):
        """Build self.graph (a DiGraph) from cached dumps plus update files.

        Edge attributes come from relation_data columns; node attributes come from
        the entity_data rows.
        """
        self.entity_data = pd.DataFrame(
            {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
        self._load_entity_data()
        self._load_relation_data()
        self._load_local_data()
        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
                                             create_using=nx.DiGraph())
        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))

    def _load_local_data(self):
        """Apply entity ("ent*.json") and relationship ("rel*.json") update files.

        Files are found recursively under the UPDATE_DATA_PATH setting. The
        FACTOR_DATA_PATH setting is stored on self but not used here.
        """
        config = SiteConfig()
        self.update_data_path = config.get_config('UPDATE_DATA_PATH')
        self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
        print(f"load update data from {self.update_data_path}")
        if not self.update_data_path:
            # os.walk(None) raises TypeError; an unset path simply means "no updates".
            logger.warning("UPDATE_DATA_PATH is not configured; skipping update data")
            return
        for root, _dirs, files in os.walk(self.update_data_path):
            for file in files:
                file_path = os.path.join(root, file)
                if file_path.endswith(".json") and file.startswith("ent"):
                    self._load_update_entity_json(file_path)
                if file_path.endswith(".json") and file.startswith("rel"):
                    self._load_update_relationship_json(file_path)

    def _load_update_entity_json(self, file):
        '''Overwrite attributes of already-loaded entities from an update JSON file.

        The file uses the cached-data format (item[0] = id, item[1] = attributes).
        Ids not present in self.entity_data are ignored. Attributes missing from
        the update keep their current value ("" if the column does not exist yet).
        '''
        print(f"load entity update data from {file}")
        with open(file, "r", encoding="utf-8") as f:
            entities = json.load(f)
        for item in entities:
            # int() matches how _load_entity_data builds the index; a raw string id never matched.
            entity_id = int(item[0])
            original = self._first_entity_row(entity_id)
            if original is None:
                continue
            attrs = item[1]
            name = attrs["name"] if "name" in attrs else original['name']
            entity_type = attrs["type"] if "type" in attrs else original['type']
            # These columns may not exist in entity_data, so fall back via .get().
            allowed_sex_list = attrs["allowed_sex_list"] if "allowed_sex_list" in attrs \
                else original.get('allowed_sex_list', "")
            allowed_age_range = attrs["allowed_age_range"] if "allowed_age_range" in attrs \
                else original.get('allowed_age_range', "")
            # NOTE(review): check_sex_allowed/check_age_allowed read 'sex'/'age', not these
            # columns - confirm which attribute names update files are meant to target.
            self.entity_data.loc[entity_id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [
                name, entity_type, allowed_sex_list, allowed_age_range]

    def _load_update_relationship_json(self, file):
        '''Overwrite the weight of existing (src, dest, type) relations from an update JSON file.

        Relations that do not already exist are ignored. When an update item carries
        no "weight", the first matching row's current weight is written to all matching rows.
        '''
        print(f"load relationship update data from {file}")
        with open(file, "r", encoding="utf-8") as f:
            relations = json.load(f)
        for item in relations:
            attrs = item[4]
            # Build the key BEFORE filtering; the original read from an empty dict (KeyError).
            src = int(item[0])
            dest = int(item[2])
            rel_type = attrs["type"]
            mask = ((self.relation_data['src'] == src)
                    & (self.relation_data['dest'] == dest)
                    & (self.relation_data['type'] == rel_type))
            matched = self.relation_data[mask]
            if matched.empty:
                continue
            weight = attrs["weight"] if "weight" in attrs else matched.iloc[0]['weight']
            self.relation_data.loc[mask, 'weight'] = weight

    def check_sex_allowed(self, node, sex):
        """Return False only when the node restricts sex and `sex` is not allowed.

        The node's 'sex' attribute is a comma-separated string (per the original
        comment, codes "0,1,2" mean unknown, male, female). A missing or empty
        value allows every sex.
        """
        sex_allowed = self.graph.nodes[node].get('sex', None)
        if not sex_allowed:
            return True
        # NOTE(review): `sex` is compared against string tokens; confirm callers pass a str.
        return sex in sex_allowed.split(',')

    def check_age_allowed(self, node, age):
        """Return whether `age` falls in the node's "min-max" 'age' range (min inclusive, max exclusive).

        Per the original comments, ages are in months (e.g. "6-88"). A missing or
        empty range, a max of 0, or an unparsable range allows every age.
        """
        age_allowed = self.graph.nodes[node].get('age', None)
        if not age_allowed:
            return True
        bounds = age_allowed.split('-')
        try:
            age_min = int(bounds[0])
            age_max = int(bounds[-1])
        except ValueError:
            # A malformed range used to crash the whole request; treat it as unrestricted.
            logger.debug(f"node {node} has unparsable age range {age_allowed!r}")
            return True
        if age_max == 0:
            return True
        return age_min <= age < age_max

    def check_diease_allowed(self, node):
        """Return False for excluded ids and for nodes flagged as symptom or TCM ("是" = yes)."""
        if node in self._EXCLUDED_DISEASE_IDS:
            return False
        attrs = self.graph.nodes[node]
        if attrs.get('is_symptom', None) == "是":
            return False
        if attrs.get('TCM', None) == "是":
            return False
        return True

    propService = KGPropService(next(get_db()))
    # Shared by all instances; ttl=60*60*24*30 is 30 days with cachetools' default seconds timer.
    cache = TTLCache(maxsize=100000, ttl=60*60*24*30)

    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
        """Walk the knowledge graph from the given symptom names and rank related diseases.

        Args:
            input: CDSSInput carrying patient info (age/sex are used for filtering).
            start_nodes: symptom (or disease) names used as traversal starting points.
            max_hops: maximum traversal depth, default 3.

        Returns:
            {"score_diags": [(disease_name, {"count", "score", "symptoms"}), ...]}
            sorted by score, descending. At most the 10 highest-count relevant
            diseases are considered. "symptoms" lists the disease's symptoms with
            matched ones first.
        """
        # Node types allowed during traversal.
        DEPARTMENT = ['科室', 'Department']
        DIESEASE = ['疾病', 'Disease']
        SYMPTOM = ['症状', 'Symptom']
        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
        # Relation types followed during traversal.
        symptom_edge = ['has_symptom', '疾病相关症状']
        symptom_same_edge = ['症状同义词', '症状同义词2.0']
        allowed_links = symptom_edge + symptom_same_edge

        # De-duplicate while keeping input order; set() iteration order varies between
        # processes, which made tie-breaking in the top-10 non-deterministic.
        start_nodes = list(dict.fromkeys(start_nodes))

        # Map names to node ids; homonyms can yield more ids than names.
        node_ids = []
        node_id_names = {}
        for node in start_nodes:
            matches = self.entity_data[self.entity_data['name'] == node]
            for index, data in matches.iterrows():
                if data["type"] in SYMPTOM or data["type"] in DIESEASE:
                    node_id_names[index] = data["name"]
                    node_ids.append(index)

        graph_node_ids = []
        for node in node_ids:
            if self.graph.has_node(node):
                graph_node_ids.append(node)
            else:
                logger.debug(f"node {node} not found")
        node_ids = graph_node_ids

        results = self.step1(node_ids, node_id_names, input, allowed_types, symptom_same_edge,
                             allowed_links, max_hops, DIESEASE)
        results = self.validDisease(results, start_nodes)
        sorted_count_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]

        # Names that count as "matched": the start symptoms plus their synonyms.
        target_symptom_names = set()
        for symptom_id in node_ids:
            row = self._first_entity_row(symptom_id)
            if row is not None:
                target_symptom_names.add(row['name'])
            for same_symptom in self.get_in_edges(symptom_id, symptom_same_edge):
                if not same_symptom.empty:  # synonym source may be missing from entity_data
                    target_symptom_names.add(same_symptom['name'].tolist()[0])

        start_nodes_size = len(start_nodes)
        diags = {}
        for disease_id, info in sorted_count_diags:
            row = self._first_entity_row(disease_id)
            if row is None:
                continue
            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
            if symptoms_data is None:
                continue
            symptoms = [{"name": symptom, "matched": symptom in target_symptom_names}
                        for symptom in symptoms_data if symptom != '发烧']
            # Matched symptoms first (stable sort keeps the original order otherwise).
            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
            diags[row['name']] = {"count": info["count"],
                                  "score": float(info["count"]) / start_nodes_size * 0.1,
                                  "symptoms": symptoms}
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        return {
            "score_diags": sorted_score_diags
        }

    def validDisease(self, results, start_nodes):
        """Keep diseases reached from more than half of the start symptoms.

        Sets data['relevant'] on every entry of `results` (mutated in place) and
        prints a Markdown table of the relevant ones.

        :param results: disease id -> info dict with "increase", "name", "path", "count"
        :param start_nodes: list of starting symptom names
        :return: dict containing only the relevant diseases
        """
        log_data = ["|疾病|症状|出现次数|是否相关", "|--|--|--|--|"]
        filtered_results = {}
        symptom_total = len(start_nodes)
        for disease_id, data in results.items():
            data['relevant'] = data["increase"] / symptom_total > 0.5
            if not data['relevant']:
                continue
            filtered_results[disease_id] = data
            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
        content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
        print(f"\n{content}")
        return filtered_results

    @staticmethod
    def _normalize_edge_weight(weight):
        """Map a raw edge weight to a traversal score (lower raw weight -> higher score, capped at 10)."""
        try:
            if weight:
                weight = 10 - weight if weight < 10 else 1
            else:
                weight = 5
            if weight > 10:
                weight = 10
        except TypeError as e:
            print(f'Error processing file {weight}: {str(e)}')
        return int(weight)

    @staticmethod
    def _merge_disease_results(results, temp_results):
        """Merge one symptom's disease hits into the running totals (mutates `results`)."""
        for disease_id, info in temp_results.items():
            if disease_id in results:
                merged = results[disease_id]
                merged["count"] = merged["count"] + info["count"]
                merged["increase"] = merged["increase"] + 1
                merged["path"].append(info["path"][0])
            else:
                results[disease_id] = info

    def _collect_diseases_from_symptom(self, node, node_name, allowed_links, max_hops, DIESEASE):
        """Breadth-first walk over in-edges from one start node, collecting allowed diseases.

        Disease nodes are recorded on first reach (count = the weight of the edge
        that led there; the start node itself uses 10) and are not expanded further.
        """
        visited = set()
        found = {}
        queue = deque([(node, 0, node_name, 10)])
        while queue:
            current, depth, path, weight = queue.popleft()
            current = int(current)
            row = self._first_entity_row(current)
            if row is None:
                continue
            if self.graph.nodes.get(current) is None:
                continue
            node_type = row['type']
            if node_type in DIESEASE:
                if not self.check_diease_allowed(current):
                    continue
                if current not in found:
                    found[current] = {"type": node_type, "count": weight, "increase": 1,
                                      "name": row['name'], 'path': [path]}
                continue
            if current in visited or depth > max_hops:
                continue
            visited.add(current)
            if current not in self.graph:
                continue
            for src, _dest, edge_data in self.graph.in_edges(current, data=True):
                if src in visited or depth + 1 > max_hops or edge_data['type'] not in allowed_links:
                    continue
                queue.append((src, depth + 1, path, self._normalize_edge_weight(edge_data['weight'])))
        return found

    def step1(self, node_ids, node_id_names, input, allowed_types, symptom_same_edge, allowed_links, max_hops, DIESEASE):
        """Find diseases related to the start symptom nodes.

        Per-symptom hits are cached in self.cache under "symptom_ref_disease_<id>".
        Results are filtered by the patient's sex and age.

        :param node_ids: start node ids
        :param node_id_names: node id -> name
        :param input: patient info; pat_age is converted from years to months IN PLACE
        :param allowed_types: accepted for interface compatibility; not consulted
        :param symptom_same_edge: accepted for interface compatibility; not consulted
        :param allowed_links: relation types followed during traversal
        :return: disease id -> {"type", "count", "increase", "name", "path"}
        """
        start_time = time.time()
        # Normalize age BEFORE the loop: the original did this after the cache-hit
        # `continue`, so fully cached queries filtered with an age still in years.
        if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
            # Ages in the graph are stored in months.
            input.pat_age.value = input.pat_age.value * 12
            input.pat_age.type = 'month'

        results = {}
        for node in node_ids:
            cache_key = f"symptom_ref_disease_{str(node)}"
            cache_data = self.cache.get(cache_key)
            if cache_data:
                temp_results = copy.deepcopy(cache_data)
            else:
                temp_results = self._collect_diseases_from_symptom(
                    node, node_id_names[node], allowed_links, max_hops, DIESEASE)
                # Store a deep copy so merging below never mutates the cached entry.
                self.cache[cache_key] = copy.deepcopy(temp_results)
            print(cache_key + ":" + node_id_names[node] + ':' + str(len(temp_results)))
            self._merge_disease_results(results, temp_results)
        end_time = time.time()

        # Drop diseases that do not match the patient's sex/age.
        new_results = {}
        for item in results:
            if input.pat_sex and input.pat_sex.value is not None and not self.check_sex_allowed(item, input.pat_sex.value):
                continue
            if input.pat_age and input.pat_age.value is not None and not self.check_age_allowed(item, input.pat_age.value):
                continue
            new_results[item] = results[item]
        results = new_results
        print('STEP 1 ' + str(len(results)))
        print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
        print(f"STEP 1 遍历图谱查找相关疾病 finished")
        return results

    def get_in_edges(self, node_id, allowed_links):
        """Return, for each in-edge of node_id whose type is in allowed_links, the
        entity_data rows (a possibly empty DataFrame) of the edge's source node."""
        return [self.entity_data[self.entity_data.index == src]
                for src, _dest, edge_data in self.graph.in_edges(node_id, data=True)
                if edge_data['type'] in allowed_links]

    def step2(self, results, department_edge):
        """Attach departments to relevant diseases.

        Each department name is repeated `count` times so later tallies are
        count-weighted.

        :param results: disease id -> info dict (must contain "relevant" and "count")
        :param department_edge: relation types that point to a department
        :return: the same `results` dict, updated in place
        """
        start_time = time.time()
        print("STEP 2 查找疾病对应的科室、检查和药品 start")
        for disease, info in results.items():
            if not info["relevant"]:
                continue
            department_data = []
            for _src, dest, edge_data in self.graph.out_edges(disease, data=True):
                if edge_data["type"] not in department_edge:
                    continue
                row = self._first_entity_row(dest)
                if row is None:
                    continue
                department_data.extend([row['name']] * info["count"])
            if department_data:
                info["department"] = department_data
        print("STEP 2 finished")
        end_time = time.time()
        print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")

        log_data = ["|disease|count|department|check|drug|", "|--|--|--|--|--|"]
        for item, info in results.items():
            department_data = info.get("department", [])
            count_data = info.get("count")
            check_data = info.get("check", [])
            drug_data = info.get("drug", [])
            # Inner quotes differ from the outer ones: same-quote reuse is a SyntaxError before Python 3.12.
            log_data.append(
                f"|{info.get('name', item)}|{count_data}|{','.join(department_data)}|"
                f"{','.join(check_data)}|{','.join(drug_data)}|")
        print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
        return results

    def step3(self, results):
        """Group diseases by department; each department's score is its share of all (disease, department) pairs."""
        print(f"STEP 3 对于结果按照科室维度进行汇总 start")
        final_results = {}
        total = 0
        edge_service = None  # created lazily, once, instead of a new DB session per disease
        for disease_key in results.keys():
            disease = int(disease_key)
            info = results[disease]
            # Diseases without a known department fall back to 'DEFAULT'.
            departments = ['DEFAULT']
            if 'department' in info:
                departments = info["department"]
            else:
                if edge_service is None:
                    edge_service = KGEdgeService(next(get_db()))
                edges = edge_service.get_edges_by_nodes(src_id=disease, category='所属科室')
                if edges:
                    departments = [edge['dest_node']['name'] for edge in edges]
            # f-string: the original `str + name` raised TypeError when "name" was missing.
            disease_label = f"{disease}:{info.get('name', disease)}"
            for department in departments:
                total += 1
                entry = final_results.get(department)
                if entry is None:
                    final_results[department] = {
                        "diseases": [disease_label],
                        "checks": list(info.get("check", [])),
                        "drugs": list(info.get("drug", [])),
                        "count": 1,
                    }
                else:
                    entry["diseases"].append(disease_label)
                    entry["checks"].extend(info.get("check", []))
                    entry["drugs"].extend(info.get("drug", []))
                    entry["count"] += 1

        for entry in final_results.values():
            entry["score"] = entry["count"] / total
        print(f"STEP 3 finished")

        log_data = ["|department|disease|check|drug|count|score", "|--|--|--|--|--|--|"]
        for department, entry in final_results.items():
            diesease_data = entry.get("diseases", [])
            check_data = entry.get("checks", [])
            drug_data = entry.get("drugs", [])
            count_data = entry.get("count", 0)
            score_data = entry.get("score", 0)
            log_data.append(
                f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
        print("\n" + "\n".join(log_data))
        return final_results

    def step4(self, final_results):
        """Tally and sort each department's diseases, then sort departments.

        Args:
            final_results: department -> {"diseases", "checks", "drugs", "count", "score"}.
                Mutated in place: "name" is added and "diseases" becomes a top-10
                list of (label, {"count": n}).

        Returns:
            list of (department, info) sorted by department count, descending.
        """
        print(f"STEP 4 start")
        start_time = time.time()

        def sort_data(data, count=10):
            # most_common keeps first-seen order for ties, like a stable sort by count.
            return [(item, {"count": n}) for item, n in Counter(data).most_common(count)]

        for department, entry in final_results.items():
            entry['name'] = department
            entry["diseases"] = sort_data(entry["diseases"])

        sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
        print(f"STEP 4 finished")
        end_time = time.time()
        print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")

        log_data = ["|department|disease|check|drug|count|score", "|--|--|--|--|--|--|"]
        for department, entry in final_results.items():
            diesease_data = entry.get("diseases")
            check_data = entry.get("checks")
            drug_data = entry.get("drugs")
            count_data = entry.get("count", 0)
            score_data = entry.get("score", 0)
            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
        print("\n" + "\n".join(log_data))
        return sorted_final_results

    def step5(self, final_results, input, start_nodes, symptom_edge):
        """Merge per-department disease tallies into one ranked diagnosis list.

        Args:
            final_results: department -> info whose "diseases" is a list of
                ("<id>:<name>", {"count": n}).
            input: patient info; diseases under input.department.value get +1 count.
            start_nodes: start symptom names, used to mark matched symptoms.
            symptom_edge: relation types linking a disease to its symptoms.

        Returns:
            (list of (disease_name, info) sorted by score, total number of disease entries seen)
        """
        print(f"STEP 5 start")
        start_time = time.time()
        # Guard: a missing department used to raise AttributeError.
        preferred_department = input.department.value if input.department else None
        diags = {}
        total_diags = 0
        for department, entry in final_results.items():
            department_factor = 0.1 if department == 'DEFAULT' else entry["score"]
            bonus = 1 if preferred_department == department else 0
            for disease, data in entry["diseases"]:
                total_diags += 1
                increment = data["count"] + bonus
                if disease in diags:
                    diags[disease]["count"] += increment
                    diags[disease]["score"] += increment * 0.1 * department_factor
                else:
                    diags[disease] = {"count": increment, "score": increment * 0.1 * department_factor}

        top_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
        start_nodes_size = len(start_nodes)
        diags = {}
        for label, info in top_diags:
            # maxsplit=1 so disease names containing ':' are kept whole.
            disease_id, disease = label.split(":", 1)
            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
            if symptoms_data is None:
                continue
            symptoms = [{"name": symptom, "matched": symptom in start_nodes} for symptom in symptoms_data]
            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
            diags[disease] = {"old_score": info["score"], "count": info["count"],
                              "score": float(info["count"]) / start_nodes_size / 2 * 0.1,
                              "symptoms": symptoms}
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        print(f"STEP 5 finished")
        end_time = time.time()
        print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")

        log_data = ["|department|disease|count|score", "|--|--|--|--|"]
        for department, entry in final_results.items():
            diesease_data = entry.get("diseases")
            count_data = entry.get("count", 0)
            score_data = entry.get("score", 0)
            log_data.append(f"|{department}|{diesease_data}|{count_data}|{score_data}|")
        print("这里是经过排序的数据\n" + "\n".join(log_data))
        return sorted_score_diags, total_diags

    def get_symptoms_data(self, disease_id, symptom_edge):
        """Return the distinct symptom names linked from a disease, cached per disease.

        :param disease_id: disease node id (int or numeric string)
        :param symptom_edge: relation types that point to symptoms
        :return: list of symptom names in out-edge order
        """
        key = f'disease_{disease_id}_symptom'
        # Single .get(): `in` followed by `[]` can race with TTL expiry and raise KeyError.
        symptom_data = self.cache.get(key)
        if symptom_data is None:
            symptom_data = []
            for _src, dest, edge_data in self.graph.out_edges(int(disease_id), data=True):
                if edge_data["type"] not in symptom_edge:
                    continue
                row = self._first_entity_row(dest)
                if row is None:
                    continue
                if row['name'] not in symptom_data:
                    symptom_data.append(row['name'])
            self.cache[key] = symptom_data
        return symptom_data