# cdss_helper.py (12 KB)
  1. import os
  2. import sys
  3. current_path = os.getcwd()
  4. sys.path.append(current_path)
  5. from community.graph_helper import GraphHelper
  6. from typing import List
  7. from cdss.models.schemas import CDSSInput
  8. class CDSSHelper(GraphHelper):
  def check_sex_allowed(self, node, sex):
      """Return whether ``sex`` passes the sex restriction stored on ``node``.

      The restriction is read from the node attribute ``allowed_sex_list``.
      The original note describes it as a comma-separated string of codes
      such as ``"0,1,2"`` (0 = unknown, 1 = male, 2 = female). A list,
      tuple or set of codes is accepted as well.

      Codes are compared as stripped strings, so an integer ``sex`` such as
      ``1`` matches the stored ``"1"``, and ``"1, 2"`` is parsed like
      ``"1,2"``.

      Args:
          node: Key of the node in ``self.graph``.
          sex: Sex code of the patient.

      Returns:
          ``True`` when the node carries no (or an empty) restriction, or
          when ``sex`` is one of the allowed codes; ``False`` otherwise.
      """
      restriction = self.graph.nodes[node].get('allowed_sex_list')
      if not restriction:
          # No restriction configured: every sex is acceptable.
          return True
      if isinstance(restriction, (list, tuple, set, frozenset)):
          codes = restriction
      else:
          codes = str(restriction).split(',')
      # Normalise both sides to stripped strings so that int/str mismatches
      # and stray whitespace do not silently reject every patient.
      allowed = {str(code).strip() for code in codes}
      return str(sex).strip() in allowed
  def check_age_allowed(self, node, age):
      """Tell whether ``age`` lies inside the age window configured on ``node``.

      The window is read from the node attribute ``allowed_age_range``, a
      ``"min-max"`` string such as ``"6-88"``. The lower bound is inclusive
      and the upper bound is exclusive. A node without the attribute (or
      with an empty value) accepts every age.

      NOTE(review): the original notes suggest the bounds are expressed in
      months; confirm against the graph data.

      Args:
          node: Key of the node in ``self.graph``.
          age: Patient age, in the same unit as the stored range.

      Returns:
          ``True`` if no range is configured or ``min <= age < max``,
          ``False`` otherwise.
      """
      window = self.graph.nodes[node].get('allowed_age_range', None)
      if not window:
          # No range configured: the node is valid for any age.
          return True
      bounds = window.split('-')
      lower = int(bounds[0])
      upper = int(bounds[-1])
      return lower <= age < upper
  def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
      """Recommend departments, checks and drugs for a set of start nodes.

      The start nodes are assumed to be symptoms. The method

      1. walks the graph breadth-first from the start nodes and collects the
         disease nodes that pass the patient's sex and age filters, counting
         how often each disease is reached;
      2. walks from every such disease to collect its department, check and
         drug nodes;
      3. groups the findings per department, weighting each disease by the
         number of times it was reached;
      4. keeps, per department, the five most frequent diseases, checks and
         drugs;
      5. merges the per-department check/drug rankings into global rankings.

      Edges of every type are followed; no edge-type filter is applied.
      ``input`` is only read: the year-to-month age conversion is kept in a
      local variable so the caller's object is left untouched.

      Args:
          input: Patient data. ``input.pat_sex.value``,
              ``input.pat_age.value`` and ``input.pat_age.type`` are read.
              An age whose type is ``'year'`` is converted to months.
          start_nodes: Node keys to start from. Keys that are not present in
              the graph are ignored.
          max_hops: Depth limit of both traversals; a neighbour is only
              queued while ``depth + 1 < max_hops``.

      Returns:
          A dict with the keys

          * ``"details"``: list of ``(department, info)`` tuples sorted by
            ``info["count"]`` in descending order. ``info`` holds
            ``"diseases"``, ``"checks"`` and ``"drugs"`` (each a list of up
            to five ``(name, {"count": n})`` tuples), ``"count"``,
            ``"score"`` and ``"name"``.
          * ``"checks"`` / ``"drugs"``: global rankings as lists of
            ``(name, {"count": n})`` tuples, most frequent first.
          * ``"total_checks"`` / ``"total_drugs"``: number of per-department
            ranking entries that were merged into the global rankings.
      """
      # Values of the node attribute ``type`` for each category. Every
      # category may list several values.
      department_types = ('科室',)
      disease_types = ('疾病',)
      drug_types = ('药品',)
      check_types = ('检查',)

      # Ages in the graph are expressed in months, so convert years locally.
      age = input.pat_age.value
      if age > 0 and input.pat_age.type == 'year':
          age = age * 12

      # STEP 1: symptoms -> candidate diseases.
      diseases = self._cdss_find_diseases(
          start_nodes, input.pat_sex.value, age, max_hops, disease_types
      )

      # STEP 2 + 3: disease -> departments / checks / drugs, grouped per
      # department. A disease reached ``hits`` times contributes ``hits``
      # times to every department it belongs to.
      final_results = {}
      total = 0
      for disease, record in diseases.items():
          departments, checks, drugs = self._cdss_expand_disease(
              disease, max_hops, department_types, check_types, drug_types
          )
          if not departments:
              continue
          total += 1
          hits = record["count"]
          for department in departments:
              if department not in final_results:
                  final_results[department] = {
                      "diseases": [],
                      "checks": [],
                      "drugs": [],
                      "count": 0,
                  }
              entry = final_results[department]
              entry["diseases"].extend([disease] * hits)
              entry["checks"].extend(checks * hits)
              entry["drugs"].extend(drugs * hits)
              entry["count"] += hits

      # ``total`` is at least 1 whenever ``final_results`` is non-empty.
      # NOTE(review): ``count`` is weighted by hits while ``total`` counts
      # diseases, so ``score`` can exceed 1 - confirm this is intended.
      for entry in final_results.values():
          entry["score"] = entry["count"] / total

      # STEP 4: per-department top lists, then departments by weight.
      for department, entry in final_results.items():
          entry['name'] = department
          entry["diseases"] = self._cdss_top_counts(entry["diseases"])
          entry["checks"] = self._cdss_top_counts(entry["checks"])
          entry["drugs"] = self._cdss_top_counts(entry["drugs"])
      sorted_final_results = sorted(
          final_results.items(),
          key=lambda item: item[1]["count"],
          reverse=True,
      )

      # STEP 5: global check / drug rankings across all departments.
      sorted_checks, total_check = self._cdss_merge_counts(final_results, "checks")
      sorted_drugs, total_drug = self._cdss_merge_counts(final_results, "drugs")

      # STEP 6: assemble the response.
      return {"details": sorted_final_results,
              "checks": sorted_checks, "drugs": sorted_drugs,
              "total_checks": total_check, "total_drugs": total_drug}

  def _cdss_find_diseases(self, start_nodes, sex, age, max_hops, disease_types):
      """Breadth-first search from ``start_nodes`` for admissible diseases.

      Returns a dict mapping each disease node to
      ``{"type", "count", "path"}``, where ``count`` is the number of times
      the disease was dequeued and ``path`` is the route of its first hit.
      """
      from collections import deque

      nodes = self.graph.nodes
      pending = deque((node, 0, "/") for node in start_nodes)
      expanded = set()
      found = {}
      while pending:
          node, depth, path = pending.popleft()
          if node not in nodes:
              # Unknown start node: nothing to look up, skip it.
              continue
          node_type = nodes[node].get('type')
          if node_type in disease_types:
              # Diseases are never expanded, and are deliberately not added
              # to ``expanded``: each arrival counts as one more hit.
              if self.check_sex_allowed(node, sex) and self.check_age_allowed(node, age):
                  if node in found:
                      found[node]["count"] += 1
                  else:
                      found[node] = {"type": node_type, "count": 1, "path": path}
              continue
          if node in expanded or depth > max_hops:
              continue
          expanded.add(node)
          if depth + 1 >= max_hops:
              # Neighbours would exceed the depth limit.
              continue
          for src, dest, _ in self.graph.edges(node, data=True):
              if dest not in expanded:
                  pending.append((dest, depth + 1, f"{path}/{src}"))
      return found

  def _cdss_expand_disease(self, disease, max_hops, department_types,
                           check_types, drug_types):
      """Collect the department, check and drug nodes reachable from ``disease``.

      Returns three lists ``(departments, checks, drugs)`` in discovery
      order. Each node appears at most once because visited nodes are
      skipped. Department, check and drug nodes are not expanded further.
      """
      from collections import deque

      nodes = self.graph.nodes
      departments, checks, drugs = [], [], []
      pending = deque([(disease, 0)])
      seen = set()
      while pending:
          node, depth = pending.popleft()
          if node in seen or depth > max_hops:
              continue
          seen.add(node)
          node_type = nodes[node].get('type')
          if node_type in department_types:
              departments.append(node)
              continue
          if node_type in check_types:
              checks.append(node)
              continue
          if node_type in drug_types:
              drugs.append(node)
              continue
          if depth + 1 >= max_hops:
              continue
          for _, dest, _ in self.graph.edges(node, data=True):
              if dest not in seen:
                  pending.append((dest, depth + 1))
      return departments, checks, drugs

  @staticmethod
  def _cdss_top_counts(items, limit=5):
      """Return the ``limit`` most frequent items as ``(item, {"count": n})``.

      Ties keep the order of first occurrence because the sort is stable.
      """
      from collections import Counter

      ranked = sorted(Counter(items).items(), key=lambda pair: pair[1], reverse=True)
      return [(item, {"count": hits}) for item, hits in ranked[:limit]]

  @staticmethod
  def _cdss_merge_counts(final_results, field):
      """Merge the per-department rankings stored under ``field``.

      Returns ``(ranking, entries)``: ``ranking`` is a list of
      ``(item, {"count": n})`` tuples sorted by summed count in descending
      order, and ``entries`` is the number of ranking entries merged.
      """
      merged = {}
      entries = 0
      for entry in final_results.values():
          for item, stats in entry[field]:
              entries += 1
              merged[item] = merged.get(item, 0) + stats["count"]
      ranked = sorted(merged.items(), key=lambda pair: pair[1], reverse=True)
      return [(item, {"count": hits}) for item, hits in ranked], entries