cdss_helper.py

import os
import sys
import logging
import json

current_path = os.getcwd()
sys.path.append(current_path)

from libs.graph_helper import GraphHelper
from typing import List
from cdss.models.schemas import CDSSInput
from config.site import SiteConfig
import networkx as nx
import pandas as pd

logger = logging.getLogger(__name__)
class CDSSHelper(GraphHelper):
    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search graph entities by name and return scored matches."""
        es_result = self.es.search_title_index("graph_entity_index", node_id, limit)
        results = []
        for item in es_result:
            score = item["score"]
            results.append({
                'id': item["title"],
                'score': score,
                "name": item["title"],
            })
        return results
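    # Illustrative usage (hypothetical call; the ids returned here are the entity
    # titles as indexed in "graph_entity_index"):
    #   helper.node_search(node_id="头痛", limit=3)
    #   -> [{"id": "头痛", "score": ..., "name": "头痛"}, ...]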
    def _load_entity_data(self):
        config = SiteConfig()
        CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
        logger.info("load entity data")
        # These are the attributes read for every entity.
        data = {"id": [], "name": [], "type": [], "allowed_sex_list": [], "allowed_age_range": []}
        with open(os.path.join(CACHED_DATA_PATH, "entities_med.json"), "r", encoding="utf-8") as f:
            entities = json.load(f)
            for item in entities:
                data["id"].append(int(item[0]))
                data["name"].append(item[1]["name"])
                data["type"].append(item[1]["type"])
                data["allowed_sex_list"].append(item[1].get("allowed_sex_list", ""))
                data["allowed_age_range"].append(item[1].get("allowed_age_range", ""))
                # item[1]["id"] = item[0]
                # item[1]["name"] = item[0]
                # attrs = item[1]
                # self.graph.add_node(item[0], **attrs)
        self.entity_data = pd.DataFrame(data)
        self.entity_data.set_index("id", inplace=True)
        logger.info("load entity data finished")
    def _load_relation_data(self):
        config = SiteConfig()
        CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
        logger.info("load relationship data")
        for i in range(99):
            file_path = os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")
            if os.path.exists(file_path):
                logger.info(f"load relationship data {file_path}")
                with open(file_path, "r", encoding="utf-8") as f:
                    data = {"src": [], "dest": [], "type": [], "weight": []}
                    relations = json.load(f)
                    for item in relations:
                        data["src"].append(int(item[0]))
                        data["dest"].append(int(item[2]))
                        if data["src"][-1] == 2969539 or data["dest"][-1] == 2969539:
                            # Debug trace for one specific node id.
                            print(">>>>>>>> FOUND 2969539")
                        data["type"].append(item[4]["type"])
                        data["weight"].append(item[4].get("weight", 1))
                    self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
    def build_graph(self):
        self.entity_data = pd.DataFrame({"id": [], "name": [], "type": [], "allowed_sex_list": [], "allowed_age_range": []})
        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
        self._load_entity_data()
        self._load_relation_data()
        self._load_local_data()
        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True, create_using=nx.DiGraph())
        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
        # print(self.graph.in_edges('1257357', data=True))
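    # After build_graph() the directed graph carries the entity attributes, so a
    # node's metadata can be read directly, e.g. (hypothetical id):
    #   helper.graph.nodes[some_id].get("type")  # -> "疾病", "症状", ...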
    def _load_local_data(self):
        # Load the update data and the weight (factor) data.
        config = SiteConfig()
        self.update_data_path = config.get_config('UPDATE_DATA_PATH')
        self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
        logger.info(f"load update data from {self.update_data_path}")
        for root, dirs, files in os.walk(self.update_data_path):
            for file in files:
                file_path = os.path.join(root, file)
                if file_path.endswith(".json") and file.startswith("ent"):
                    self._load_update_entity_json(file_path)
                if file_path.endswith(".json") and file.startswith("rel"):
                    self._load_update_relationship_json(file_path)
    def _load_update_entity_json(self, file):
        '''Load entity update data from a json file.'''
        logger.info(f"load entity update data from {file}")
        # The update data is a json file with the same format as the cached data.
        with open(file, "r", encoding="utf-8") as f:
            entities = json.load(f)
            for item in entities:
                original_data = self.entity_data[self.entity_data.index == int(item[0])]
                if original_data.empty:
                    continue
                original_data = original_data.iloc[0]
                id = int(item[0])
                name = item[1]["name"] if "name" in item[1] else original_data['name']
                type = item[1]["type"] if "type" in item[1] else original_data['type']
                allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data['allowed_sex_list']
                allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data['allowed_age_range']
                self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type, allowed_sex_list, allowed_age_range]
    def _load_update_relationship_json(self, file):
        '''Load relationship update data from a json file.'''
        logger.info(f"load relationship update data from {file}")
        with open(file, "r", encoding="utf-8") as f:
            relations = json.load(f)
            for item in relations:
                # Parse the update record first, then look up the matching edge.
                data = {"src": int(item[0]), "dest": int(item[2]), "type": item[4]["type"]}
                original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
                                                   (self.relation_data['dest'] == data['dest']) &
                                                   (self.relation_data['type'] == data['type'])]
                if original_data.empty:
                    continue
                original_data = original_data.iloc[0]
                data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
                self.relation_data.loc[(self.relation_data['src'] == data['src']) &
                                       (self.relation_data['dest'] == data['dest']) &
                                       (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
    def check_sex_allowed(self, node, sex):
        # Sex filter: a disease node may carry an attribute named allowed_sex_list
        # with a value like "0,1,2", standing for unknown, male, female.
        sex_allowed = self.graph.nodes[node].get('allowed_sex_list', None)
        if sex_allowed:
            if len(sex_allowed) == 0:
                # An empty list means every sex is allowed.
                return True
            sex_allowed_list = sex_allowed.split(',')
            if sex not in sex_allowed_list:
                # Skip the node if the sex does not match.
                return False
        return True
    def check_age_allowed(self, node, age):
        # Age filter: a disease node may carry an attribute named allowed_age_range
        # with a value like "6-88", meaning ages of 6 to 88 months are allowed.
        # Ages in the graph are measured in months; a patient under 6 years is
        # treated as a child.
        age_allowed = self.graph.nodes[node].get('allowed_age_range', None)
        if age_allowed:
            if len(age_allowed) == 0:
                # An empty range means every age is allowed.
                return True
            age_allowed_list = age_allowed.split('-')
            age_min = int(age_allowed_list[0])
            age_max = int(age_allowed_list[-1])
            if age_min <= age < age_max:
                # The age falls inside the allowed range.
                return True
        else:
            # No age range set: allow every age by default.
            return True
        return False
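    # Illustrative behaviour of the two filters above (hypothetical node
    # attributes; a sketch, not output from the module):
    #   allowed_sex_list == "1,2"            -> check_sex_allowed(node, "0") is False
    #   allowed_age_range == "6-88" (months) -> check_age_allowed(node, 24) is True,
    #                                           check_age_allowed(node, 100) is False
    #   nodes without either attribute pass both checks.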
    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
        # Allowed node types; adjust as needed. Each type may carry several labels.
        DEPARTMENT = ['科室', 'Department']
        DISEASE = ['疾病', 'Disease']
        DRUG = ['药品', 'Drug']
        CHECK = ['检查', 'Check']
        SYMPTOM = ['症状', 'Symptom']
        allowed_types = DEPARTMENT + DISEASE + DRUG + CHECK + SYMPTOM
        # Allowed edge types; adjust as needed. The code below does not filter on
        # edge type yet, so this is reserved for future extension.
        allowed_links = ['has_symptom', 'need_check', 'recommend_drug', 'belongs_to']
        # Convert the user-supplied names into node ids. Because several nodes can
        # share a name, node_ids may end up larger than start_nodes.
        node_ids = []
        node_id_names = {}
        for node in start_nodes:
            logger.debug(f"searching for node {node}")
            result = self.entity_data[self.entity_data['name'] == node]
            for index, data in result.iterrows():
                node_id_names[index] = data["name"]
                node_ids = node_ids + [index]
        logger.info(f"start travel from {node_id_names}")
        # Drop start nodes that do not exist in the graph.
        node_ids_filtered = []
        for node in node_ids:
            if self.graph.has_node(node):
                node_ids_filtered.append(node)
            else:
                logger.debug(f"node {node} not found")
        node_ids = node_ids_filtered
        # A queue of symptom nodes waiting to be traversed.
        queue = [(node, 0, node_id_names[node], {'allowed_types': allowed_types, 'allowed_links': allowed_links}) for node in node_ids]
        visited = set()
        results = {}
        # Normalize the input: validate it and convert units where necessary.
        if input.pat_age.value > 0 and input.pat_age.type == 'year':
            # Convert the age from years to months, since ages in the graph are in months.
            input.pat_age.value = input.pat_age.value * 12
            input.pat_age.type = 'month'
        # STEP 1: assume the start nodes are symptoms and find the diseases linked to them.
        # TODO: diseases are looked up symptom by symptom, so in practice these results could be cached.
        while queue:
            node, depth, path, data = queue.pop(0)
            # Fetch the node's name and type by id.
            node_type = self.entity_data[self.entity_data.index == node]['type'].tolist()[0]
            node_name = self.entity_data[self.entity_data.index == node]['name'].tolist()[0]
            logger.debug(f"node {node} type {node_type}")
            if node_type in DISEASE:
                logger.debug(f"node {node} type {node_type} is a disease")
                if self.check_sex_allowed(node, input.pat_sex.value) == False:
                    continue
                if self.check_age_allowed(node, input.pat_age.value) == False:
                    continue
                if node in results.keys():
                    results[node]["count"] = results[node]["count"] + 1
                    results[node]["path"].append(path)
                else:
                    results[node] = {"type": node_type, "count": 1, "name": node_name, 'path': [path]}
                continue
            if node in visited or depth > max_hops:
                logger.debug(f"{node} already visited or reach max hops")
                continue
            visited.add(node)
            logger.debug(f"check edges from {node}")
            for edge in self.graph.in_edges(node, data=True):
                src, dest, edge_data = edge
                if src not in visited and depth + 1 < max_hops:
                    logger.debug(f"put into queue travel from {src} to {dest}")
                    queue.append((src, depth + 1, path, data))
                else:
                    logger.debug(f"skip travel from {src} to {dest}")
                # print("-" * (indent+4), f"start travel from {src} to {dest}")
        logger.info("STEP 1 finished")
        # Emit the log as a markdown table.
        log_data = ["|disease|symptoms|count|relevant|"]
        log_data.append("|--|--|--|--|")
        for item in results:
            data = results[item]
            data['relevant'] = False
            if data["count"] / len(start_nodes) > 0.5:
                # A disease counts as relevant only if more than 50% of the start
                # symptoms point to it.
                data['relevant'] = True
            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
        content = "disease-symptom relevance table\n" + "\n".join(log_data)
        logger.debug(f"\n{content}")
        # STEP 2: find the departments, checks and drugs linked to these diseases.
        # This also works disease by disease, so in practice the results could be cached.
        logger.info("STEP 2 start")
        for disease in results.keys():
            # TODO: caching the departments, checks and drugs per disease would give a large performance gain.
            if results[disease]["relevant"] == False:
                continue
            logger.debug(f"search data for {disease}:{results[disease]['name']}")
            queue = []
            queue.append((disease, 0, disease, {'allowed_types': DEPARTMENT, 'allowed_links': ['belongs_to']}))
            # Moving visited outside the for-disease loop was tried, but it caused
            # problems and the performance gain was small, so it stays inside the loop.
            visited = set()
            while queue:
                node, depth, disease, data = queue.pop(0)
                if node in visited or depth > max_hops:
                    continue
                visited.add(node)
                node_type = self.entity_data[self.entity_data.index == node]['type'].tolist()[0]
                node_name = self.entity_data[self.entity_data.index == node]['name'].tolist()[0]
                logger.debug(f"node {results[disease].get('name', disease)} {node_name} type {node_type}")
                # node_type = self.graph.nodes[node].get('type')
                if node_type in DEPARTMENT:
                    # Expand the department once per disease occurrence to simplify
                    # the statistics later on.
                    department_data = [node_name] * results[disease]["count"]
                    if 'department' in results[disease].keys():
                        results[disease]["department"] = results[disease]["department"] + department_data
                    else:
                        results[disease]["department"] = department_data
                    continue
                if node_type in CHECK:
                    if 'check' in results[disease].keys():
                        results[disease]["check"] = list(set(results[disease]["check"] + [node_name]))
                    else:
                        results[disease]["check"] = [node_name]
                    continue
                if node_type in DRUG:
                    if 'drug' in results[disease].keys():
                        results[disease]["drug"] = list(set(results[disease]["drug"] + [node_name]))
                    else:
                        results[disease]["drug"] = [node_name]
                    continue
                for edge in self.graph.out_edges(node, data=True):
                    src, dest, edge_data = edge
                    src_name = self.entity_data[self.entity_data.index == src]['name'].tolist()[0]
                    dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
                    dest_type = self.entity_data[self.entity_data.index == dest]['type'].tolist()[0]
                    if dest_type in allowed_types:
                        if dest not in visited and depth + 1 < max_hops:
                            logger.debug(f"put travel request in queue from {src}:{src_name} to {dest}:{dest_name}")
                            queue.append((dest, depth + 1, disease, data))
        # TODO: the department, check and drug data per disease in results could be
        # cached here for later reuse.
        # for item in results.keys():
        #     department_data = results[item].get("department", [])
        #     count_data = results[item].get("count")
        #     check_data = results[item].get("check", [])
        #     drug_data = results[item].get("drug", [])
        #     # caching code goes here
        logger.info("STEP 2 finished")
        # Log the results.
        log_data = ["|disease|count|department|check|drug|"]
        log_data.append("|--|--|--|--|--|")
        for item in results.keys():
            department_data = results[item].get("department", [])
            count_data = results[item].get("count")
            check_data = results[item].get("check", [])
            drug_data = results[item].get("drug", [])
            log_data.append(f"|{results[item].get('name', item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
        logger.debug("disease/department/check/drug statistics\n" + "\n".join(log_data))
        # End of log output.
        # STEP 3: aggregate the results by department.
        logger.info("STEP 3 start")
        final_results = {}
        total = 0
        for disease in results.keys():
            # Some diseases have no department, so fall back to a default bucket.
            departments = ['DEFAULT']
            if 'department' in results[disease].keys():
                departments = results[disease]["department"]
            for department in departments:
                total += 1
                if not department in final_results.keys():
                    final_results[department] = {
                        "diseases": [results[disease].get("name", disease)],
                        "checks": results[disease].get("check", []),
                        "drugs": results[disease].get("drug", []),
                        "count": 1
                    }
                else:
                    final_results[department]["diseases"] = final_results[department]["diseases"] + [results[disease].get("name", disease)]
                    final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get("check", [])
                    final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get("drug", [])
                    final_results[department]["count"] += 1
        # Compute the distribution of department occurrences.
        for department in final_results.keys():
            final_results[department]["score"] = final_results[department]["count"] / total
        logger.info("STEP 3 finished")
        # Log the results.
        log_data = ["|department|disease|check|drug|count|score|"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            disease_data = final_results[department].get("diseases", [])
            check_data = final_results[department].get("checks", [])
            drug_data = final_results[department].get("drugs", [])
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{','.join(disease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
        logger.debug("\n" + "\n".join(log_data))
        # STEP 4: count the occurrences of the diseases, checks and drugs in
        # final_results and sort them in descending order.
        logger.info("STEP 4 start")

        def sort_data(data, count=5):
            tmp = {}
            for item in data:
                if item in tmp.keys():
                    tmp[item]["count"] += 1
                else:
                    tmp[item] = {"count": 1}
            sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
            return sorted_data[:count]
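        # For example (hypothetical input), sort_data(["a", "a", "b"]) returns
        # [("a", {"count": 2}), ("b", {"count": 1})]: (item, stats) tuples sorted
        # by count and truncated to the top `count` entries.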
        for department in final_results.keys():
            final_results[department]['name'] = department
            final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
            final_results[department]["checks"] = sort_data(final_results[department]["checks"])
            final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
        # Sort the departments by occurrence count in descending order.
        sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
        logger.info("STEP 4 finished")
        # Log a markdown table.
        log_data = ["|department|disease|check|drug|count|score|"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            disease_data = final_results[department].get("diseases")
            check_data = final_results[department].get("checks")
            drug_data = final_results[department].get("drugs")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{disease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
        logger.debug("\n" + "\n".join(log_data))
        # STEP 5: count the global occurrences of the diseases, checks and drugs in
        # final_results and sort them in descending order.
        logger.info("STEP 5 start")
        checks = {}
        drugs = {}
        diags = {}
        total_check = 0
        total_drug = 0
        total_diags = 0
        for department in final_results.keys():
            # Use the department's occurrence probability as a factor; departments
            # that fell back to DEFAULT get 0.1. A disease's score is its count
            # within the department times that factor.
            department_factor = 0.1 if department == 'DEFAULT' else final_results[department]["score"]
            for disease, data in final_results[department]["diseases"]:
                total_diags += 1
                if disease in diags.keys():
                    diags[disease]["count"] += data["count"]
                    diags[disease]["score"] += data["count"] * department_factor
                else:
                    diags[disease] = {"count": data["count"], "score": data["count"] * department_factor}
            # Checks and drugs simply accumulate their occurrence counts.
            for check, data in final_results[department]["checks"]:
                total_check += 1
                if check in checks.keys():
                    checks[check]["count"] += data["count"]
                else:
                    checks[check] = {"count": data["count"]}
            for drug, data in final_results[department]["drugs"]:
                total_drug += 1
                if drug in drugs.keys():
                    drugs[drug]["count"] += data["count"]
                else:
                    drugs[drug] = {"count": data["count"]}
        sorted_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        sorted_checks = sorted(checks.items(), key=lambda x: x[1]["count"], reverse=True)
        sorted_drugs = sorted(drugs.items(), key=lambda x: x[1]["count"], reverse=True)
        logger.info("STEP 5 finished")
        # Log a markdown table.
        log_data = ["|department|disease|check|drug|count|score|"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            disease_data = final_results[department].get("diseases")
            check_data = final_results[department].get("checks")
            drug_data = final_results[department].get("drugs")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{disease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
        logger.debug("sorted data\n" + "\n".join(log_data))
        # STEP 6: assemble the data and return it.
        # if "department" in item.keys():
        #     final_results["department"] = list(set(final_results["department"] + item["department"]))
        # if "diseases" in item.keys():
        #     final_results["diseases"] = list(set(final_results["diseases"] + item["diseases"]))
        # if "checks" in item.keys():
        #     final_results["checks"] = list(set(final_results["checks"] + item["checks"]))
        # if "drugs" in item.keys():
        #     final_results["drugs"] = list(set(final_results["drugs"] + item["drugs"]))
        # if "symptoms" in item.keys():
        #     final_results["symptoms"] = list(set(final_results["symptoms"] + item["symptoms"]))
        return {"details": sorted_final_results,
                "diags": sorted_diags, "total_diags": total_diags,
                "checks": sorted_checks, "drugs": sorted_drugs,
                "total_checks": total_check, "total_drugs": total_drug}