123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- import os
- import sys
- current_path = os.getcwd()
- sys.path.append(current_path)
- from community.graph_helper import GraphHelper
- from typing import List
- from cdss.models.schemas import CDSSInput
- class CDSSHelper(GraphHelper):
- def check_sex_allowed(self, node, sex):
- #性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
- sex_allowed = self.graph.nodes[node].get('allowed_sex_list', None)
- if sex_allowed:
- sex_allowed_list = sex_allowed.split(',')
- if sex not in sex_allowed_list:
- #如果性别不匹配,跳过
- return False
- return True
- def check_age_allowed(self, node, age):
- #年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
- #如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
- age_allowed = self.graph.nodes[node].get('allowed_age_range', None)
- if age_allowed:
- age_allowed_list = age_allowed.split('-')
- age_min = int(age_allowed_list[0])
- age_max = int(age_allowed_list[-1])
- if age >= age_min and age < age_max:
- #如果年龄范围正常,那么返回True
- return True
- else:
- #如果没有设置年龄范围,那么默认返回True
- return True
- return False
-
- def cdss_travel(self, input:CDSSInput, start_nodes:List, max_hops=3):
- #这里设置了节点的type取值范围,可以根据实际情况进行修改,允许出现多个类型
- DEPARTMENT=['科室']
- DIESEASE=['疾病']
- DRUG=['药品']
- CHECK=['检查']
- SYMPTOM=['症状']
- #allowed_types = ['科室', '疾病', '药品', '检查', '症状']
- allowed_types = DEPARTMENT + DIESEASE+ DRUG + CHECK + SYMPTOM
- #这里设置了边的type取值范围,可以根据实际情况进行修改,允许出现多个类型
- #不过后面的代码里面没有对边的type进行过滤,所以这里是留做以后扩展的
- allowed_links = ['has_symptom', 'need_check', 'recommend_drug', 'belongs_to']
- #详细解释下面一行代码
- #queue是一个队列,里面存放了待遍历的节点,每个节点都有一个depth,表示当前节点的深度,
- #一个path,表示当前节点的路径,一个data,表示当前节点的一些额外信息,比如allowed_types和allowed_links
- #allowed_types表示当前节点的类型,allowed_links表示当前节点的边的类型
- #这里的start_nodes是一个列表,里面存放了起始节点,每个起始节点都有一个depth为0,一个path为"/",一个data为{'allowed_types': allowed_types, 'allowed_links':allowed_links}
- #这里的"/"表示根节点,因为根节点没有父节点,所以路径为"/"
- #这里的data是一个字典,里面存放了allowed_types和allowed_links,这两个值都是列表,里面存放了允许的类型
- queue = [(node, 0, "/", {'allowed_types': allowed_types, 'allowed_links':allowed_links}) for node in start_nodes]
- visited = set()
- results = {}
- #整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
- if input.pat_age.value > 0 and input.pat_age.type == 'year':
- #这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
- input.pat_age.value = input.pat_age.value * 12
- input.pat_age.type = 'month'
-
- #STEP 1: 假设start_nodes里面都是症状,第一步我们先找到这些症状对应的疾病
- #由于这部分是按照症状逐一去寻找疾病,所以实际应用中可以缓存这些结果
- while queue:
- node, depth, path, data = queue.pop(0)
- #allowed_types = data['allowed_types']
- #allowed_links = data['allowed_links']
- indent = depth * 4
- node_type = self.graph.nodes[node].get('type')
- if node_type in DIESEASE:
- if self.check_sex_allowed(node, input.pat_sex.value) == False:
- continue
- if self.check_age_allowed(node, input.pat_age.value) == False:
- continue
- if node in results.keys():
- results[node]["count"] = results[node]["count"] + 1
- #print("疾病", node, "出现的次数", results[node]["count"])
- else:
- results[node] = {"type": node_type, "count":1, 'path':path}
- continue
-
- if node in visited or depth > max_hops:
- #print(">>> already visited or reach max hops")
- continue
-
- visited.add(node)
- for edge in self.graph.edges(node, data=True):
- src, dest, edge_data = edge
- #if edge_data.get('type') not in allowed_links:
- # continue
- if edge[1] not in visited and depth + 1 < max_hops:
- queue.append((edge[1], depth + 1, path+"/"+src, data))
- #print("-" * (indent+4), f"start travel from {src} to {dest}")
-
- #STEP 2: 找到这些疾病对应的科室,检查和药品
- #由于这部分是按照疾病逐一去寻找,所以实际应用中可以缓存这些结果
- for disease in results.keys():
- queue = [(disease, 0, {'allowed_types': DEPARTMENT, 'allowed_links':['belongs_to']})]
- visited = set()
- while queue:
- node, depth, data = queue.pop(0)
- indent = depth * 4
- if node in visited or depth > max_hops:
- #print(">>> already visited or reach max hops")
- continue
-
- visited.add(node)
- node_type = self.graph.nodes[node].get('type')
- if node_type in DEPARTMENT:
- #展开科室,重复次数为疾病出现的次数,为了方便后续统计
- department_data = [node] * results[disease]["count"]
- # if results[disease]["count"] > 1:
- # print("展开了科室", node, "次数", results[disease]["count"], "次")
- if 'department' in results[disease].keys():
- results[disease]["department"] = results[disease]["department"] + department_data
- else:
- results[disease]["department"] = department_data
- continue
- if node_type in CHECK:
- if 'check' in results[disease].keys():
- results[disease]["check"] = list(set(results[disease]["check"]+[node]))
- else:
- results[disease]["check"] = [node]
- continue
- if node_type in DRUG:
- if 'drug' in results[disease].keys():
- results[disease]["drug"] = list(set(results[disease]["drug"]+[node]))
- else:
- results[disease]["drug"] = [node]
- continue
- for edge in self.graph.edges(node, data=True):
- src, dest, edge_data = edge
- #if edge_data.get('type') not in allowed_links:
- # continue
- if edge[1] not in visited and depth + 1 < max_hops:
- queue.append((edge[1], depth + 1, data))
- #print("-" * (indent+4), f"start travel from {src} to {dest}")
- #STEP 3: 对于结果按照科室维度进行汇总
- final_results = {}
- total = 0
- for disease in results.keys():
- if 'department' in results[disease].keys():
- total += 1
- for department in results[disease]["department"]:
- if not department in final_results.keys():
- final_results[department] = {
- "diseases": [disease],
- "checks": results[disease].get("check",[]),
- "drugs": results[disease].get("drug",[]),
- "count": 1
- }
- else:
- final_results[department]["diseases"] = final_results[department]["diseases"]+[disease]
- final_results[department]["checks"] = final_results[department]["checks"]+results[disease].get("check",[])
- final_results[department]["drugs"] = final_results[department]["drugs"]+results[disease].get("drug",[])
- final_results[department]["count"] += 1
-
- for department in final_results.keys():
- final_results[department]["score"] = final_results[department]["count"] / total
- #STEP 4: 对于final_results里面的disease,checks和durgs统计出现的次数并且按照次数降序排序
- def sort_data(data, count=5):
- tmp = {}
- for item in data:
- if item in tmp.keys():
- tmp[item]["count"] +=1
- else:
- tmp[item] = {"count":1}
- sorted_data = sorted(tmp.items(), key=lambda x:x[1]["count"],reverse=True)
- return sorted_data[:count]
-
- for department in final_results.keys():
- final_results[department]['name'] = department
- final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
- final_results[department]["checks"] = sort_data(final_results[department]["checks"])
- final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
-
- sorted_final_results = sorted(final_results.items(), key=lambda x:x[1]["count"],reverse=True)
-
- #STEP 5: 对于final_results里面的checks和durgs统计全局出现的次数并且按照次数降序排序
- checks = {}
- drugs ={}
- total_check = 0
- total_drug = 0
- for department in final_results.keys():
- for check, data in final_results[department]["checks"]:
- total_check += 1
- if check in checks.keys():
- checks[check]["count"] += data["count"]
- else:
- checks[check] = {"count":data["count"]}
- for drug, data in final_results[department]["drugs"]:
- total_drug += 1
- if drug in drugs.keys():
- drugs[drug]["count"] += data["count"]
- else:
- drugs[drug] = {"count":data["count"]}
-
- sorted_checks = sorted(checks.items(), key=lambda x:x[1]["count"],reverse=True)
- sorted_drugs = sorted(drugs.items(), key=lambda x:x[1]["count"],reverse=True)
- #STEP 6: 整合数据并返回
- # if "department" in item.keys():
- # final_results["department"] = list(set(final_results["department"]+item["department"]))
- # if "diseases" in item.keys():
- # final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
- # if "checks" in item.keys():
- # final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
- # if "drugs" in item.keys():
- # final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
- # if "symptoms" in item.keys():
- # final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
- return {"details":sorted_final_results,
- "checks":sorted_checks, "drugs":sorted_drugs,
- "total_checks":total_check, "total_drugs":total_drug}
|