12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807 |
- import copy
- from hmac import new
- import os
- import sys
- import logging
- import json
- import time
- from sqlalchemy import false
- from service.kg_edge_service import KGEdgeService
- from db.session import get_db
- from service.kg_node_service import KGNodeService
- from service.kg_prop_service import KGPropService
- from cachetools import TTLCache
- from cachetools.keys import hashkey
- current_path = os.getcwd()
- sys.path.append(current_path)
- from community.graph_helper import GraphHelper
- from typing import List
- from agent.cdss.models.schemas import CDSSInput
- from config.site import SiteConfig
- import networkx as nx
- import pandas as pd
- logger = logging.getLogger(__name__)
- current_path = os.getcwd()
- sys.path.append(current_path)
- # 图谱数据缓存路径(由dump_graph_data.py生成)
- CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
- class CDSSHelper(GraphHelper):
- def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
- kg_node_service = KGNodeService(next(get_db()))
- es_result = kg_node_service.search_title_index("", node_id,node_type, limit)
- results = []
- for item in es_result:
- n = self.graph.nodes.get(item["id"])
- score = item["score"]
- if n:
- results.append({
- 'id': item["title"],
- 'score': score,
- "name": item["title"],
- })
- return results
- def _load_entity_data(self):
- config = SiteConfig()
- # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
- print("load entity data")
- # 这里设置了读取的属性
- data = {"id": [], "name": [], "type": [],"is_symptom": [], "sex": [], "age": [],"score": [],"TCM": []}
- if not os.path.exists(os.path.join(CACHED_DATA_PATH, 'entities_med.json')):
- return []
- with open(os.path.join(CACHED_DATA_PATH, 'entities_med.json'), "r", encoding="utf-8") as f:
- entities = json.load(f)
- for item in entities:
- #如果id已经存在,则跳过
- # if item[0] in data["id"]:
- # print(f"skip {item[0]}")
- # continue
- data["id"].append(int(item[0]))
- data["name"].append(item[1]["name"])
- data["type"].append(item[1]["type"])
- self._append_entity_attribute(data, item, "sex")
- self._append_entity_attribute(data, item, "age")
- self._append_entity_attribute(data, item, "is_symptom")
- self._append_entity_attribute(data, item, "score")
- self._append_entity_attribute(data, item, "TCM")
- # item[1]["id"] = item[0]
- # item[1]["name"] = item[0]
- # attrs = item[1]
- # self.graph.add_node(item[0], **attrs)
- self.entity_data = pd.DataFrame(data)
- self.entity_data.set_index("id", inplace=True)
- print("load entity data finished")
- def get_entity_data(self):
- return self.entity_data
- def _append_entity_attribute(self, data, item, attr_name):
- if attr_name in item[1]:
- value = item[1][attr_name].split(":")
- if len(value) < 2:
- data[attr_name].append(value[0])
- else:
- data[attr_name].append(value[1])
- else:
- data[attr_name].append("")
- def _load_relation_data(self):
- config = SiteConfig()
- # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
- print("load relationship data")
- for i in range(500):
- if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
- break
- if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
- print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
- with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
- data = {"src": [], "dest": [], "type": [], "weight": []}
- relations = json.load(f)
- for item in relations:
- data["src"].append(int(item[0]))
- data["dest"].append(int(item[2]))
- data["type"].append(item[4]["type"])
- if "order" in item[4]:
- order = item[4]["order"].split(":")
- if len(order) < 2:
- data["weight"].append(int(order[0]))
- else:
- data["weight"].append(int(order[1]))
- else:
- data["weight"].append(1)
- self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
- def build_graph(self):
- self.entity_data = pd.DataFrame(
- {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
- self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
- self._load_entity_data()
- self._load_relation_data()
- self._load_local_data()
- self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
- create_using=nx.DiGraph())
- nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
- # print(self.graph.in_edges('1257357',data=True))
- def _load_local_data(self):
- # 这里加载update数据和权重数据
- config = SiteConfig()
- self.update_data_path = config.get_config('UPDATE_DATA_PATH')
- self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
- print(f"load update data from {self.update_data_path}")
- for root, dirs, files in os.walk(self.update_data_path):
- for file in files:
- file_path = os.path.join(root, file)
- if file_path.endswith(".json") and file.startswith("ent"):
- self._load_update_entity_json(file_path)
- if file_path.endswith(".json") and file.startswith("rel"):
- self._load_update_relationship_json(file_path)
    def _load_update_entity_json(self, file):
        """Merge one entity-update JSON file into self.entity_data.

        Each record is [id, attrs]; fields present in attrs override the
        cached values, absent fields keep the current value. Records whose id
        is not already in entity_data are skipped.
        """
        print(f"load entity update data from {file}")
        # Update files use the same record layout as the cached entity dump.
        with open(file, "r", encoding="utf-8") as f:
            entities = json.load(f)
        for item in entities:
            # NOTE(review): the index holds ints but item[0] is only int()-converted
            # below — confirm update files store ids as ints, otherwise this
            # lookup never matches and every record is skipped.
            original_data = self.entity_data[self.entity_data.index == item[0]]
            if original_data.empty:
                continue
            original_data = original_data.iloc[0]
            id = int(item[0])
            name = item[1]["name"] if "name" in item[1] else original_data['name']
            type = item[1]["type"] if "type" in item[1] else original_data['type']
            # NOTE(review): _load_entity_data builds columns sex/age/score/TCM —
            # the 'allowed_sex_list'/'allowed_age_range' fallbacks below assume
            # columns that only exist in build_graph's initial empty frame; verify.
            allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
                'allowed_sex_list']
            allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
                'allowed_age_range']
            self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
                                                                                                   allowed_sex_list,
                                                                                                   allowed_age_range]
- def _load_update_relationship_json(self, file):
- '''load json data from file'''
- print(f"load relationship update data from {file}")
- with open(file, "r", encoding="utf-8") as f:
- relations = json.load(f)
- for item in relations:
- data = {}
- original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
- (self.relation_data['dest'] == data['dest']) &
- (self.relation_data['type'] == data['type'])]
- if original_data.empty:
- continue
- original_data = original_data.iloc[0]
- data["src"] = int(item[0])
- data["dest"] = int(item[2])
- data["type"] = item[4]["type"]
- data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
- self.relation_data.loc[(self.relation_data['src'] == data['src']) &
- (self.relation_data['dest'] == data['dest']) &
- (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
- def check_sex_allowed(self, node, sex):
- # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
- sex_allowed = self.graph.nodes[node].get('sex', None)
- #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
- #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
- #'prop_value'] != input.pat_sex.value:
- #continue
- if sex_allowed:
- if len(sex_allowed) == 0:
- # 如果性别列表为空,那么默认允许所有性别
- return True
- sex_allowed_list = sex_allowed.split(',')
- if sex not in sex_allowed_list:
- # 如果性别不匹配,跳过
- return False
- return True
- def check_age_allowed(self, node, age):
- # 年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
- # 如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
- age_allowed = self.graph.nodes[node].get('age', None)
- if age_allowed:
- if len(age_allowed) == 0:
- # 如果年龄范围为空,那么默认允许所有年龄
- return True
- age_allowed_list = age_allowed.split('-')
- age_min = int(age_allowed_list[0])
- age_max = int(age_allowed_list[-1])
- if age_max ==0:
- return True
- if age >= age_min and age < age_max:
- # 如果年龄范围正常,那么返回True
- return True
- else:
- # 如果没有设置年龄范围,那么默认返回True
- return True
- return False
- def check_diease_allowed(self, node):
- if (node ==1489628 or node ==1685808 or node ==1272212 or node ==1303510 or node ==1303336 or node ==1299948
- or node ==1916891 or node ==1993268 or node ==1613483 or node ==1783840 or node==1353546 or node==1491526 or node==132923
- or node==1287322 or node==1921248 or node==1303271 or node==1811668):
- return False
- is_symptom = self.graph.nodes[node].get('is_symptom', None)
- if is_symptom == "是":
- return False
- is_tcm = self.graph.nodes[node].get('TCM', None)
- if is_tcm == "是":
- return False
- return True
    # Shared KG property-service client (class-level: one DB session reused by
    # every CDSSHelper instance — NOTE(review): confirm this session remains
    # valid for the whole process lifetime).
    propService = KGPropService(next(get_db()))
    # Shared in-process result cache: up to 100k entries with a 30-day TTL.
    cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
        """Traverse the knowledge graph from symptom nodes and rank related diseases.

        Parameters:
            input: CDSSInput with basic patient info (age, sex, department).
            start_nodes: list of symptom names used as traversal start points.
            max_hops: maximum traversal depth (default 3).

        Returns:
            dict with key "score_diags": (name, data) pairs sorted by score,
            each data dict carrying count, score and a matched/unmatched
            symptom list.

        Steps:
            1. Define the allowed node and edge types.
            2. Map symptom names to node ids (names are not unique, so the
               id list can be longer than the input).
            3. Traverse the graph for candidate diseases (step1).
            4. Keep only relevant candidates (validDisease).
            5. Score, sort and attach per-disease symptom details.
        """
        # Allowed node types; both Chinese and English labels occur in the graph.
        DEPARTMENT = ['科室', 'Department']
        DIESEASE = ['疾病', 'Disease']
        DRUG = ['药品', 'Drug']
        CHECK = ['检查', 'Check']
        SYMPTOM = ['症状', 'Symptom']
        #allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
        # Allowed edge types for the disease lookup (symptom and synonym links).
        symptom_edge = ['has_symptom', '疾病相关症状']
        symptom_same_edge = ['症状同义词', '症状同义词2.0']
        department_edge = ['belongs_to','所属科室']
        allowed_links = symptom_edge+symptom_same_edge
        # allowed_links = symptom_edge + department_edge
        # Map the input symptom names to node ids.
        node_ids = []
        node_id_names = {}
        # Deduplicate repeated symptoms in start_nodes.
        start_nodes = list(set(start_nodes))
        for node in start_nodes:
            #print(f"searching for node {node}")
            result = self.entity_data[self.entity_data['name'] == node]
            #print(f"searching for node {result}")
            for index, data in result.iterrows():
                # Only symptom and disease nodes may act as start points.
                if data["type"] in SYMPTOM or data["type"] in DIESEASE:
                    node_id_names[index] = data["name"]
                    node_ids = node_ids + [index]
        #print(f"start travel from {node_id_names}")
        # Keep only ids that actually exist in the graph.
        node_ids_filtered = []
        for node in node_ids:
            if self.graph.has_node(node):
                node_ids_filtered.append(node)
            else:
                logger.debug(f"node {node} not found")
        node_ids = node_ids_filtered
        results = self.step1(node_ids,node_id_names, input, allowed_types, symptom_same_edge,allowed_links,max_hops,DIESEASE)
        results = self.validDisease(results, start_nodes)
        # Top 10 candidates by raw hit count.
        sorted_count_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
        diags = {}
        target_symptom_names = []
        for symptom_id in node_ids:
            target_symptom_name = self.entity_data[self.entity_data.index == symptom_id]['name'].tolist()[0]
            target_symptom_names.append(target_symptom_name)
            same_symptoms = self.get_in_edges(symptom_id, symptom_same_edge)
            # Add every synonym's name to target_symptom_names as well.
            for same_symptom in same_symptoms:
                target_symptom_names.append(same_symptom['name'].tolist()[0])
        for item in sorted_count_diags:
            disease_id = item[0]
            disease_name = self.entity_data[self.entity_data.index == disease_id]['name'].tolist()[0]
            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
            if symptoms_data is None:
                continue
            symptoms = []
            for symptom in symptoms_data:
                # '发烧' (fever) is excluded from the displayed symptom list —
                # presumably because it is too generic; confirm.
                if symptom =='发烧':
                    continue
                matched = False
                if symptom in target_symptom_names:
                    matched = True
                symptoms.append({"name": symptom, "matched": matched})
            # Matched symptoms first, unmatched after.
            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
            start_nodes_size = len(start_nodes)
            # if start_nodes_size > 1:
            #     start_nodes_size = start_nodes_size*0.5
            new_item = {"count": item[1]["count"],
                        "score": float(item[1]["count"]) / start_nodes_size * 0.1, "symptoms": symptoms}
            diags[disease_name] = new_item
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        return {
            "score_diags": sorted_score_diags
            # "checks":sorted_checks, "drugs":sorted_drugs,
            # "total_checks":total_check, "total_drugs":total_drug
        }
- def validDisease(self, results, start_nodes):
- """
- 输出有效的疾病信息为Markdown格式
- :param results: 疾病结果字典
- :param start_nodes: 起始症状节点列表
- :return: 格式化后的Markdown字符串
- """
- log_data = ["|疾病|症状|出现次数|是否相关"]
- log_data.append("|--|--|--|--|")
- filtered_results = {}
- for item in results:
- data = results[item]
- data['relevant'] = False
- if data["increase"] / len(start_nodes) > 0.5:
- #cache_key = f'disease_name_ref_id_{data['name']}'
- data['relevant'] = True
- filtered_results[item] = data
- # 初始化疾病的父类疾病
- # disease_name = data["name"]
- # key = 'disease_name_parent_' +disease_name
- # cached_value = self.cache.get(key)
- # if cached_value is None:
- # out_edges = self.graph.out_edges(item, data=True)
- #
- # for edge in out_edges:
- # src, dest, edge_data = edge
- # if edge_data["type"] != '疾病相关父类':
- # continue
- # dest_data = self.entity_data[self.entity_data.index == dest]
- # if dest_data.empty:
- # continue
- # dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
- # self.cache.set(key, dest_name)
- # break
- if data['relevant'] == False:
- continue
- log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
- content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
- print(f"\n{content}")
- return filtered_results
-
    def step1(self, node_ids,node_id_names, input, allowed_types, symptom_same_edge,allowed_links,max_hops,DIESEASE):
        """Find candidate diseases by a reverse BFS from each symptom node.

        :param node_ids: symptom node ids (traversal start points)
        :param node_id_names: id -> display name for the start nodes
        :param input: patient info (CDSSInput), used for age/sex filtering
        :param allowed_types: allowed node type labels (carried in queue data)
        :param symptom_same_edge: synonym edge labels (unused here)
        :param allowed_links: edge type labels the traversal may follow
        :param max_hops: maximum BFS depth
        :param DIESEASE: disease type labels that terminate a path
        :return: disease-id -> {type, count, increase, name, path} after
                 filtering by the patient's sex and age
        """
        start_time = time.time()
        results = {}
        for node in node_ids:
            visited = set()
            temp_results = {}
            cache_key = f"symptom_ref_disease_{str(node)}"
            cache_data = self.cache[cache_key] if cache_key in self.cache else None
            if cache_data:
                # Cache hit: reuse the per-symptom result. Deep-copied because
                # the merge below mutates the entries.
                temp_results = copy.deepcopy(cache_data)
                print(cache_key+":"+node_id_names[node] +':'+ str(len(temp_results)))
                if results=={}:
                    results = temp_results
                else:
                    # Merge this symptom's diseases into the running totals.
                    for disease_id in temp_results:
                        path = temp_results[disease_id]["path"][0]
                        if disease_id in results.keys():
                            results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
                            results[disease_id]["increase"] = results[disease_id]["increase"]+1
                            results[disease_id]["path"].append(path)
                        else:
                            results[disease_id] = temp_results[disease_id]
                continue
            # BFS queue entries: (node id, depth, start-symptom name, weight, filter data).
            queue = [(node, 0, node_id_names[node],10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
            # Normalize the input: graph ages are in months, so convert years.
            if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
                input.pat_age.value = input.pat_age.value * 12
                input.pat_age.type = 'month'
            # STEP 1: start_nodes are assumed to be symptoms; find their diseases.
            # TODO: lookups are done per symptom, so results are cached above.
            while queue:
                temp_node, depth, path, weight, data = queue.pop(0)
                temp_node = int(temp_node)
                # Resolve name and type by id; skip nodes unknown to either
                # the entity table or the graph.
                entity_data = self.entity_data[self.entity_data.index == temp_node]
                if entity_data.empty:
                    continue
                if self.graph.nodes.get(temp_node) is None:
                    continue
                node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
                node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
                # print(f"node {node} type {node_type}")
                if node_type in DIESEASE:
                    # Disease reached: record it once (first path wins) and
                    # stop expanding this branch.
                    count = weight
                    if self.check_diease_allowed(temp_node) == False:
                        continue
                    if temp_node not in temp_results.keys():
                        temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
                    continue
                if temp_node in visited or depth > max_hops:
                    # Already visited or depth limit reached.
                    continue
                visited.add(temp_node)
                if temp_node not in self.graph:
                    continue
                # Expand over incoming edges of the allowed types.
                for edge in self.graph.in_edges(temp_node, data=True):
                    src, dest, edge_data = edge
                    if src not in visited and depth + 1 <= max_hops and edge_data['type'] in allowed_links:
                        weight = edge_data['weight']
                        # NOTE(review): edge 'order' values below 10 are inverted
                        # (10-order, so smaller order => larger weight), others
                        # collapse to 1; missing/zero weights default to 5 and
                        # everything is capped at 10 — confirm intended mapping.
                        try :
                            if weight:
                                if weight < 10:
                                    weight = 10-weight
                                else:
                                    weight = 1
                            else:
                                weight = 5
                            if weight>10:
                                weight = 10
                        except Exception as e:
                            print(f'Error processing file {weight}: {str(e)}')
                        queue.append((src, depth + 1, path, int(weight), data))
                    # else:
                    #     print(f"skip travel from {src} to {dest}")
            print(cache_key+":"+node_id_names[node]+':'+ str(len(temp_results)))
            # Cache a deep copy so later merges cannot corrupt the cached value.
            self.cache[cache_key] = copy.deepcopy(temp_results)
            if results == {}:
                results = temp_results
            else:
                for disease_id in temp_results:
                    path = temp_results[disease_id]["path"][0]
                    if disease_id in results.keys():
                        results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
                        results[disease_id]["increase"] = results[disease_id]["increase"] + 1
                        results[disease_id]["path"].append(path)
                    else:
                        results[disease_id] = temp_results[disease_id]
        end_time = time.time()
        # Filter the merged candidates by the patient's sex and age.
        new_results = {}
        for item in results:
            if input.pat_sex and input.pat_sex.value is not None and self.check_sex_allowed(item, input.pat_sex.value) == False:
                continue
            if input.pat_age and input.pat_age.value is not None and self.check_age_allowed(item, input.pat_age.value) == False:
                continue
            new_results[item] = results[item]
        results = new_results
        print('STEP 1 '+str(len(results)))
        print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
        print(f"STEP 1 遍历图谱查找相关疾病 finished")
        return results
- def get_in_edges(self, node_id, allowed_links):
- results = []
- for edge in self.graph.in_edges(node_id, data=True):
- src, dest, edge_data = edge
- if edge_data['type'] in allowed_links:
- results.append(self.entity_data[self.entity_data.index == src])
- return results
- def step2(self, results,department_edge):
- """
- 查找疾病对应的科室、检查和药品信息
- :param results: 包含疾病信息的字典
- :return: 更新后的results字典
- """
- start_time = time.time()
- print("STEP 2 查找疾病对应的科室、检查和药品 start")
- for disease in results.keys():
- # cache_key = f"disease_department_{disease}"
- # cached_data = self.cache.get(cache_key)
- # if cached_data:
- # results[disease]["department"] = cached_data
- # continue
- if results[disease]["relevant"] == False:
- continue
- department_data = []
- out_edges = self.graph.out_edges(disease, data=True)
- for edge in out_edges:
- src, dest, edge_data = edge
- if edge_data["type"] not in department_edge:
- continue
- dest_data = self.entity_data[self.entity_data.index == dest]
- if dest_data.empty:
- continue
- department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
- department_data.extend([department_name] * results[disease]["count"])
- if department_data:
- results[disease]["department"] = department_data
- #self.cache.set(cache_key, department_data)
- print(f"STEP 2 finished")
- end_time = time.time()
- print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")
- # 输出日志
- log_data = ["|disease|count|department|check|drug|"]
- log_data.append("|--|--|--|--|--|")
- for item in results.keys():
- department_data = results[item].get("department", [])
- count_data = results[item].get("count")
- check_data = results[item].get("check", [])
- drug_data = results[item].get("drug", [])
- log_data.append(
- f"|{results[item].get("name", item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
- print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
- return results
    def step3(self, results):
        """Aggregate the disease results per department (STEP 3).

        Every disease contributes to each of its departments (falling back to
        a DB edge lookup, then to the 'DEFAULT' bucket). Each department
        accumulates diseases, checks, drugs, an occurrence count and finally
        a score = count / total (share of all disease-department pairs).

        :param results: disease-id -> stats dict (optionally with 'department')
        :return: department name -> {diseases, checks, drugs, count, score}
        """
        print(f"STEP 3 对于结果按照科室维度进行汇总 start")
        final_results = {}
        total = 0
        for disease in results.keys():
            disease = int(disease)
            # Some diseases carry no department: fall back to a DB edge
            # lookup, and finally to the 'DEFAULT' bucket.
            departments = ['DEFAULT']
            if 'department' in results[disease].keys():
                departments = results[disease]["department"]
            else:
                edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='所属科室')
                # The edge lookup may return nothing.
                if len(edges) > 0:
                    departments = [edge['dest_node']['name'] for edge in edges]
            # Accumulate this disease into each of its departments.
            for department in departments:
                total += 1
                if not department in final_results.keys():
                    final_results[department] = {
                        "diseases": [str(disease)+":"+results[disease].get("name", disease)],
                        "checks": results[disease].get("check", []),
                        "drugs": results[disease].get("drug", []),
                        "count": 1
                    }
                else:
                    final_results[department]["diseases"] = final_results[department]["diseases"] + [str(disease)+":"+
                                                                                                     results[disease].get("name", disease)]
                    final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get(
                        "check", [])
                    final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get(
                        "drug", [])
                    final_results[department]["count"] += 1
        # Department distribution: share of all (disease, department) pairs.
        for department in final_results.keys():
            final_results[department]["score"] = final_results[department]["count"] / total
        print(f"STEP 3 finished")
        # Markdown summary log.
        log_data = ["|department|disease|check|drug|count|score"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases", [])
            check_data = final_results[department].get("checks", [])
            drug_data = final_results[department].get("drugs", [])
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(
                f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
        print("\n" + "\n".join(log_data))
        return final_results
- def step4(self, final_results):
- """
- 对final_results中的疾病、检查和药品进行统计和排序
-
- 参数:
- final_results: 包含科室、疾病、检查和药品的字典
-
- 返回值:
- 排序后的final_results
- """
- print(f"STEP 4 start")
- start_time = time.time()
- def sort_data(data, count=10):
- tmp = {}
- for item in data:
- if item in tmp.keys():
- tmp[item]["count"] += 1
- else:
- tmp[item] = {"count": 1}
- sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
- return sorted_data[:count]
- for department in final_results.keys():
- final_results[department]['name'] = department
- final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
- #final_results[department]["checks"] = sort_data(final_results[department]["checks"])
- #final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
- # 这里把科室做一个排序,按照出现的次数降序排序
- sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
-
- print(f"STEP 4 finished")
- end_time = time.time()
- print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")
- # 这里输出markdown日志
- log_data = ["|department|disease|check|drug|count|score"]
- log_data.append("|--|--|--|--|--|--|")
- for department in final_results.keys():
- diesease_data = final_results[department].get("diseases")
- check_data = final_results[department].get("checks")
- drug_data = final_results[department].get("drugs")
- count_data = final_results[department].get("count", 0)
- score_data = final_results[department].get("score", 0)
- log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
- print("\n" + "\n".join(log_data))
- return sorted_final_results
    def step5(self, final_results, input, start_nodes, symptom_edge):
        """Merge per-department rankings into a final scored disease list (STEP 5).

        Parameters:
            final_results: department -> aggregated data (as produced by step4)
            input: patient info; the patient's own department gets a boost
            start_nodes: input symptom names, used to mark matched symptoms
            symptom_edge: edge type labels for disease->symptom links
        Returns:
            (sorted_score_diags, total_diags): diseases sorted by score, and
            the total number of (department, disease) pairs seen.
        """
        print(f"STEP 5 start")
        start_time = time.time()
        diags = {}
        total_diags = 0
        for department in final_results.keys():
            # Departments without a meaningful score ('DEFAULT') get a fixed low factor.
            department_factor = 0.1 if department == 'DEFAULT' else final_results[department]["score"]
            count = 0
            # The patient's own department gets a +1 count boost per disease.
            if input.department.value == department:
                count = 1
            for disease, data in final_results[department]["diseases"]:
                total_diags += 1
                if disease in diags.keys():
                    diags[disease]["count"] += data["count"]+count
                    diags[disease]["score"] += (data["count"]+count)*0.1 * department_factor
                else:
                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor}

        #sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)[:10]
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
        diags = {}
        for item in sorted_score_diags:
            # Disease keys look like "<id>:<name>" (built in step3).
            disease_info = item[0].split(":");
            disease_id = disease_info[0]
            disease = disease_info[1]
            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
            if symptoms_data is None:
                continue
            symptoms = []
            for symptom in symptoms_data:
                matched = False
                if symptom in start_nodes:
                    matched = True
                symptoms.append({"name": symptom, "matched": matched})
            # Matched symptoms first, unmatched after.
            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
            start_nodes_size = len(start_nodes)
            # if start_nodes_size > 1:
            #     start_nodes_size = start_nodes_size*0.5
            new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])/start_nodes_size/2*0.1,"symptoms":symptoms}
            diags[disease] = new_item
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        print(f"STEP 5 finished")
        end_time = time.time()
        print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
        log_data = ["|department|disease|count|score"]
        log_data.append("|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{diesease_data}|{count_data}|{score_data}|")
        print("这里是经过排序的数据\n" + "\n".join(log_data))
        return sorted_score_diags, total_diags
-
- def get_symptoms_data(self, disease_id, symptom_edge):
- """
- 获取疾病相关的症状数据
- :param disease_id: 疾病节点ID
- :param symptom_edge: 症状关系类型列表
- :return: 症状数据列表
- """
- key = f'disease_{disease_id}_symptom'
- symptom_data = self.cache[key] if key in self.cache else None
- if symptom_data is None:
- out_edges = self.graph.out_edges(int(disease_id), data=True)
- symptom_data = []
- for edge in out_edges:
- src, dest, edge_data = edge
- if edge_data["type"] not in symptom_edge:
- continue
- dest_data = self.entity_data[self.entity_data.index == dest]
- if dest_data.empty:
- continue
- dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
- if dest_name not in symptom_data:
- symptom_data.append(dest_name)
- self.cache[key]=symptom_data
- return symptom_data
|