cdss_helper2.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. import copy
  2. from hmac import new
  3. import os
  4. import sys
  5. import logging
  6. import json
  7. import time
  8. from sqlalchemy import false
  9. from service.kg_edge_service import KGEdgeService
  10. from db.session import get_db
  11. from service.kg_node_service import KGNodeService
  12. from service.kg_prop_service import KGPropService
  13. from cachetools import TTLCache
  14. from cachetools.keys import hashkey
  15. current_path = os.getcwd()
  16. sys.path.append(current_path)
  17. from community.graph_helper import GraphHelper
  18. from typing import List
  19. from agent.cdss.models.schemas import CDSSInput
  20. from config.site import SiteConfig
  21. import networkx as nx
  22. import pandas as pd
  23. logger = logging.getLogger(__name__)
  24. current_path = os.getcwd()
  25. sys.path.append(current_path)
  26. # 图谱数据缓存路径(由dump_graph_data.py生成)
  27. CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
class CDSSHelper(GraphHelper):

    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search graph nodes through the ES title index, keeping only hits present in the graph.

        :param node_id: identifier forwarded to the title-index search
        :param node_type: node type forwarded to the title-index search
        :param filters: unused here; kept for interface compatibility
        :param limit: maximum number of index hits requested
        :param min_degree: unused here; kept for interface compatibility
        :return: list of dicts with ``id``, ``score`` and ``name`` keys
        """
        kg_node_service = KGNodeService(next(get_db()))
        es_result = kg_node_service.search_title_index("", node_id, node_type, limit)
        results = []
        for item in es_result:
            # Only keep hits whose id exists as a node in the in-memory graph.
            n = self.graph.nodes.get(item["id"])
            score = item["score"]
            if n:
                # NOTE(review): 'id' is filled with the title, not the node id —
                # looks deliberate for display purposes, but confirm with callers.
                results.append({
                    'id': item["title"],
                    'score': score,
                    "name": item["title"],
                })
        return results
    def _load_entity_data(self):
        """Load entity records from the cached JSON dump into ``self.entity_data``.

        The resulting DataFrame is indexed by integer node id and carries the
        name/type columns plus the optional attributes appended below (missing
        attributes become empty strings).  Returns ``[]`` early when the cache
        file ``entities_med.json`` is absent.
        """
        config = SiteConfig()
        # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
        print("load entity data")
        # Attributes extracted for every entity row.
        data = {"id": [], "name": [], "type": [], "is_symptom": [], "sex": [], "age": [], "score": [], "TCM": []}
        if not os.path.exists(os.path.join(CACHED_DATA_PATH, 'entities_med.json')):
            return []
        with open(os.path.join(CACHED_DATA_PATH, 'entities_med.json'), "r", encoding="utf-8") as f:
            entities = json.load(f)
            for item in entities:
                # Duplicate ids are currently kept; the de-duplication below was disabled.
                # if item[0] in data["id"]:
                #     print(f"skip {item[0]}")
                #     continue
                data["id"].append(int(item[0]))
                data["name"].append(item[1]["name"])
                data["type"].append(item[1]["type"])
                self._append_entity_attribute(data, item, "sex")
                self._append_entity_attribute(data, item, "age")
                self._append_entity_attribute(data, item, "is_symptom")
                self._append_entity_attribute(data, item, "score")
                self._append_entity_attribute(data, item, "TCM")
                # item[1]["id"] = item[0]
                # item[1]["name"] = item[0]
                # attrs = item[1]
                # self.graph.add_node(item[0], **attrs)
        self.entity_data = pd.DataFrame(data)
        self.entity_data.set_index("id", inplace=True)
        print("load entity data finished")
    def get_entity_data(self):
        """Return the entity DataFrame built by :meth:`_load_entity_data`."""
        return self.entity_data
  75. def _append_entity_attribute(self, data, item, attr_name):
  76. if attr_name in item[1]:
  77. value = item[1][attr_name].split(":")
  78. if len(value) < 2:
  79. data[attr_name].append(value[0])
  80. else:
  81. data[attr_name].append(value[1])
  82. else:
  83. data[attr_name].append("")
  84. def _load_relation_data(self):
  85. config = SiteConfig()
  86. # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
  87. print("load relationship data")
  88. for i in range(500):
  89. if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  90. break
  91. if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  92. print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
  93. with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
  94. data = {"src": [], "dest": [], "type": [], "weight": []}
  95. relations = json.load(f)
  96. for item in relations:
  97. data["src"].append(int(item[0]))
  98. data["dest"].append(int(item[2]))
  99. data["type"].append(item[4]["type"])
  100. if "order" in item[4]:
  101. order = item[4]["order"].split(":")
  102. if len(order) < 2:
  103. data["weight"].append(int(order[0]))
  104. else:
  105. data["weight"].append(int(order[1]))
  106. else:
  107. data["weight"].append(1)
  108. self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
    def build_graph(self):
        """Build the in-memory knowledge graph.

        Loads cached entities and relations plus local update files, then
        creates a directed graph from the relation edge list and attaches the
        entity attributes to its nodes.
        """
        # NOTE(review): these seed frames are replaced wholesale by
        # _load_entity_data / _load_relation_data below; only the relation
        # columns are actually reused by pd.concat.
        self.entity_data = pd.DataFrame(
            {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
        self._load_entity_data()
        self._load_relation_data()
        self._load_local_data()
        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
                                             create_using=nx.DiGraph())
        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
        # print(self.graph.in_edges('1257357', data=True))
  120. def _load_local_data(self):
  121. # 这里加载update数据和权重数据
  122. config = SiteConfig()
  123. self.update_data_path = config.get_config('UPDATE_DATA_PATH')
  124. self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
  125. print(f"load update data from {self.update_data_path}")
  126. for root, dirs, files in os.walk(self.update_data_path):
  127. for file in files:
  128. file_path = os.path.join(root, file)
  129. if file_path.endswith(".json") and file.startswith("ent"):
  130. self._load_update_entity_json(file_path)
  131. if file_path.endswith(".json") and file.startswith("rel"):
  132. self._load_update_relationship_json(file_path)
    def _load_update_entity_json(self, file):
        '''Apply one entity patch file to ``self.entity_data``.

        Each record is ``[id, {attrs}]`` (same layout as the cached entity
        dump); attributes present in the record override the stored row, absent
        ones keep their current value.  Records for unknown ids are skipped.
        '''
        print(f"load entity update data from {file}")
        with open(file, "r", encoding="utf-8") as f:
            entities = json.load(f)
            for item in entities:
                original_data = self.entity_data[self.entity_data.index == item[0]]
                if original_data.empty:
                    continue
                original_data = original_data.iloc[0]
                id = int(item[0])
                name = item[1]["name"] if "name" in item[1] else original_data['name']
                type = item[1]["type"] if "type" in item[1] else original_data['type']
                # NOTE(review): entity_data as built by _load_entity_data has no
                # 'allowed_sex_list'/'allowed_age_range' columns, so the fallback
                # reads below would raise KeyError when the record omits these
                # keys — confirm the intended schema.
                allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
                    'allowed_sex_list']
                allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
                    'allowed_age_range']
                self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
                                                                                                       allowed_sex_list,
                                                                                                       allowed_age_range]
  154. def _load_update_relationship_json(self, file):
  155. '''load json data from file'''
  156. print(f"load relationship update data from {file}")
  157. with open(file, "r", encoding="utf-8") as f:
  158. relations = json.load(f)
  159. for item in relations:
  160. data = {}
  161. original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
  162. (self.relation_data['dest'] == data['dest']) &
  163. (self.relation_data['type'] == data['type'])]
  164. if original_data.empty:
  165. continue
  166. original_data = original_data.iloc[0]
  167. data["src"] = int(item[0])
  168. data["dest"] = int(item[2])
  169. data["type"] = item[4]["type"]
  170. data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
  171. self.relation_data.loc[(self.relation_data['src'] == data['src']) &
  172. (self.relation_data['dest'] == data['dest']) &
  173. (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
  174. def check_sex_allowed(self, node, sex):
  175. # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
  176. sex_allowed = self.graph.nodes[node].get('sex', None)
  177. #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
  178. #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
  179. #'prop_value'] != input.pat_sex.value:
  180. #continue
  181. if sex_allowed:
  182. if len(sex_allowed) == 0:
  183. # 如果性别列表为空,那么默认允许所有性别
  184. return True
  185. sex_allowed_list = sex_allowed.split(',')
  186. if sex not in sex_allowed_list:
  187. # 如果性别不匹配,跳过
  188. return False
  189. return True
    def check_age_allowed(self, node, age):
        """Return True when the node's 'age' range admits the given age.

        The attribute is ``"min-max"`` in months (e.g. ``"6-88"``).  A missing
        or empty attribute, or a max of 0, allows every age.  The check is
        half-open: ``min <= age < max``.
        """
        age_allowed = self.graph.nodes[node].get('age', None)
        if age_allowed:
            if len(age_allowed) == 0:
                # Empty range means no restriction.
                return True
            age_allowed_list = age_allowed.split('-')
            age_min = int(age_allowed_list[0])
            age_max = int(age_allowed_list[-1])
            if age_max == 0:
                # Zero upper bound is treated as "no upper limit".
                return True
            if age >= age_min and age < age_max:
                # Age falls inside [min, max).
                return True
        else:
            # No age restriction recorded on the node.
            return True
        return False
  210. def check_diease_allowed(self, node):
  211. if (node ==1489628 or node ==1685808 or node ==1272212 or node ==1303510 or node ==1303336 or node ==1299948
  212. or node ==1916891 or node ==1993268 or node ==1613483 or node ==1783840 or node==1353546 or node==1491526 or node==132923
  213. or node==1287322 or node==1921248 or node==1303271 or node==1811668):
  214. return False
  215. is_symptom = self.graph.nodes[node].get('is_symptom', None)
  216. if is_symptom == "是":
  217. return False
  218. is_tcm = self.graph.nodes[node].get('TCM', None)
  219. if is_tcm == "是":
  220. return False
  221. return True
    # Shared property service (one DB session bound at class-creation time) and
    # a 30-day TTL cache; the cache stores per-symptom traversal results
    # (step1) and per-disease symptom lists (get_symptoms_data).
    propService = KGPropService(next(get_db()))
    cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
        """Traverse the knowledge graph from symptom nodes and rank related diseases.

        :param input: CDSSInput with patient basics (age, sex, department, ...)
        :param start_nodes: symptom names used as traversal start points
        :param max_hops: maximum traversal depth (default 3)
        :return: dict with key ``score_diags`` — (disease_name, info) pairs
                 sorted by score; each info dict carries ``count``, ``score``
                 and a matched/unmatched ``symptoms`` list

        Main steps: resolve symptom names to node ids, run :meth:`step1` to
        collect candidate diseases, keep the relevant ones via
        :meth:`validDisease`, then score the top 10 by how many of the
        patient's symptoms (including synonyms) each disease explains.
        """
        # Node-type vocabularies (Chinese and English labels).
        DEPARTMENT = ['科室', 'Department']
        DIESEASE = ['疾病', 'Disease']
        DRUG = ['药品', 'Drug']
        CHECK = ['检查', 'Check']
        SYMPTOM = ['症状', 'Symptom']
        # allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
        # Edge-type vocabularies followed during traversal.
        symptom_edge = ['has_symptom', '疾病相关症状']
        symptom_same_edge = ['症状同义词', '症状同义词2.0']
        department_edge = ['belongs_to', '所属科室']
        allowed_links = symptom_edge + symptom_same_edge
        # allowed_links = symptom_edge + department_edge
        # Map symptom names to node ids.  Several nodes may share a name, so
        # the id list can be longer than the input name list.
        node_ids = []
        node_id_names = {}
        # De-duplicate the input symptoms.
        start_nodes = list(set(start_nodes))
        for node in start_nodes:
            # print(f"searching for node {node}")
            result = self.entity_data[self.entity_data['name'] == node]
            # print(f"searching for node {result}")
            for index, data in result.iterrows():
                if data["type"] in SYMPTOM or data["type"] in DIESEASE:
                    node_id_names[index] = data["name"]
                    node_ids = node_ids + [index]
        # print(f"start travel from {node_id_names}")
        # Keep only ids that actually exist in the graph.
        node_ids_filtered = []
        for node in node_ids:
            if self.graph.has_node(node):
                node_ids_filtered.append(node)
            else:
                logger.debug(f"node {node} not found")
        node_ids = node_ids_filtered
        results = self.step1(node_ids, node_id_names, input, allowed_types, symptom_same_edge, allowed_links, max_hops, DIESEASE)
        results = self.validDisease(results, start_nodes)
        # Top 10 diseases by accumulated count.
        sorted_count_diags = sorted(results.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
        diags = {}
        # Collect the patient's symptom names plus all their synonyms.
        target_symptom_names = []
        for symptom_id in node_ids:
            target_symptom_name = self.entity_data[self.entity_data.index == symptom_id]['name'].tolist()[0]
            target_symptom_names.append(target_symptom_name)
            same_symptoms = self.get_in_edges(symptom_id, symptom_same_edge)
            # Append the name of every synonym symptom as well.
            for same_symptom in same_symptoms:
                target_symptom_names.append(same_symptom['name'].tolist()[0])
        for item in sorted_count_diags:
            disease_id = item[0]
            disease_name = self.entity_data[self.entity_data.index == disease_id]['name'].tolist()[0]
            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
            if symptoms_data is None:
                continue
            symptoms = []
            for symptom in symptoms_data:
                # '发烧' (fever) is skipped as too generic to be informative.
                if symptom == '发烧':
                    continue
                matched = False
                if symptom in target_symptom_names:
                    matched = True
                symptoms.append({"name": symptom, "matched": matched})
            # Matched symptoms first, unmatched after.
            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
            start_nodes_size = len(start_nodes)
            # if start_nodes_size > 1:
            #     start_nodes_size = start_nodes_size*0.5
            new_item = {"count": item[1]["count"],
                        "score": float(item[1]["count"]) / start_nodes_size * 0.1, "symptoms": symptoms}
            diags[disease_name] = new_item
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        return {
            "score_diags": sorted_score_diags
            # "checks":sorted_checks, "drugs":sorted_drugs,
            # "total_checks":total_check, "total_drugs":total_drug
        }
    def validDisease(self, results, start_nodes):
        """
        Keep diseases explained by more than half of the input symptoms and
        print a Markdown relevance table.

        :param results: disease result dict from step1
        :param start_nodes: starting symptom node names
        :return: dict containing only the diseases marked relevant
        """
        log_data = ["|疾病|症状|出现次数|是否相关"]
        log_data.append("|--|--|--|--|")
        filtered_results = {}
        for item in results:
            data = results[item]
            data['relevant'] = False
            # "increase" counts how many distinct input symptoms reached this
            # disease; require strictly more than half of them.
            if data["increase"] / len(start_nodes) > 0.5:
                # cache_key = f'disease_name_ref_id_{data["name"]}'
                data['relevant'] = True
                filtered_results[item] = data
            # (Disabled) initialization of the disease's parent disease:
            # disease_name = data["name"]
            # key = 'disease_name_parent_' + disease_name
            # cached_value = self.cache.get(key)
            # if cached_value is None:
            #     out_edges = self.graph.out_edges(item, data=True)
            #     for edge in out_edges:
            #         src, dest, edge_data = edge
            #         if edge_data["type"] != '疾病相关父类':
            #             continue
            #         dest_data = self.entity_data[self.entity_data.index == dest]
            #         if dest_data.empty:
            #             continue
            #         dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
            #         self.cache.set(key, dest_name)
            #         break
            if data['relevant'] == False:
                continue
            # Only relevant diseases make it into the log table.
            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
        content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
        print(f"\n{content}")
        return filtered_results
    def step1(self, node_ids, node_id_names, input, allowed_types, symptom_same_edge, allowed_links, max_hops, DIESEASE):
        """
        Find diseases related to the given symptoms via a BFS over incoming
        edges, with per-symptom caching and patient sex/age filtering.

        :param node_ids: symptom node ids (traversal start points)
        :param node_id_names: id -> display name for the start nodes
        :param input: patient info; ages are normalized from years to months
        :param allowed_types: node types the traversal may visit
        :param symptom_same_edge: synonym edge types (unused here; kept for the signature)
        :param allowed_links: edge types the traversal may follow
        :param max_hops: maximum BFS depth
        :param DIESEASE: disease type labels
        :return: dict disease_id -> {type, count, increase, name, path}
        """
        start_time = time.time()
        results = {}
        for node in node_ids:
            visited = set()
            temp_results = {}
            cache_key = f"symptom_ref_disease_{str(node)}"
            cache_data = self.cache[cache_key] if cache_key in self.cache else None
            if cache_data:
                # Cache hit: merge the cached per-symptom result into `results`
                # (deep copy so the merge never mutates the cached object).
                temp_results = copy.deepcopy(cache_data)
                print(cache_key+":"+node_id_names[node] + ':' + str(len(temp_results)))
                if results == {}:
                    results = temp_results
                else:
                    for disease_id in temp_results:
                        path = temp_results[disease_id]["path"][0]
                        if disease_id in results.keys():
                            # Same disease reached from another symptom: add
                            # counts, bump the distinct-symptom counter,
                            # remember the extra path.
                            results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
                            results[disease_id]["increase"] = results[disease_id]["increase"]+1
                            results[disease_id]["path"].append(path)
                        else:
                            results[disease_id] = temp_results[disease_id]
                continue
            # BFS frontier: (node, depth, source path label, weight, config).
            # NOTE(review): list.pop(0) is O(n); a collections.deque would be
            # cheaper if frontiers grow large.
            queue = [(node, 0, node_id_names[node], 10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
            # Normalize the patient input: the graph stores ages in months.
            if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
                input.pat_age.value = input.pat_age.value * 12
                input.pat_age.type = 'month'
            # STEP 1: treat every start node as a symptom and walk incoming
            # edges to the diseases that reference it.
            # TODO: results are cached per symptom (see self.cache above).
            while queue:
                temp_node, depth, path, weight, data = queue.pop(0)
                temp_node = int(temp_node)
                # Resolve name/type through the entity table; skip unknown nodes.
                entity_data = self.entity_data[self.entity_data.index == temp_node]
                if entity_data.empty:
                    continue
                if self.graph.nodes.get(temp_node) is None:
                    continue
                node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
                node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
                # print(f"node {node} type {node_type}")
                if node_type in DIESEASE:
                    # Reached a disease: record the first hit only, do not
                    # expand further from disease nodes.
                    count = weight
                    if self.check_diease_allowed(temp_node) == False:
                        continue
                    if temp_node not in temp_results.keys():
                        # temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
                        # temp_results[temp_node]["increase"] = temp_results[temp_node]["increase"] + 1
                        # temp_results[temp_node]["path"].append(path)
                        # else:
                        temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
                    continue
                if temp_node in visited or depth > max_hops:
                    # print(f"{node} already visited or reach max hops")
                    continue
                visited.add(temp_node)
                # print(f"check edges from {node}")
                if temp_node not in self.graph:
                    # print(f"node {node} not found in graph")
                    continue
                for edge in self.graph.in_edges(temp_node, data=True):
                    src, dest, edge_data = edge
                    if src not in visited and depth + 1 <= max_hops and edge_data['type'] in allowed_links:
                        weight = edge_data['weight']
                        # Invert small weights so a low "order" ranks higher,
                        # default missing weights to 5, clamp to [1, 10].
                        try:
                            if weight:
                                if weight < 10:
                                    weight = 10-weight
                                else:
                                    weight = 1
                            else:
                                weight = 5
                            if weight > 10:
                                weight = 10
                        except Exception as e:
                            print(f'Error processing file {weight}: {str(e)}')
                        queue.append((src, depth + 1, path, int(weight), data))
                        # else:
                        #     print(f"skip travel from {src} to {dest}")
            print(cache_key+":"+node_id_names[node]+':' + str(len(temp_results)))
            # Deep-copy before caching so later merges cannot mutate the cache.
            self.cache[cache_key] = copy.deepcopy(temp_results)
            if results == {}:
                results = temp_results
            else:
                for disease_id in temp_results:
                    path = temp_results[disease_id]["path"][0]
                    if disease_id in results.keys():
                        results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
                        results[disease_id]["increase"] = results[disease_id]["increase"] + 1
                        results[disease_id]["path"].append(path)
                    else:
                        results[disease_id] = temp_results[disease_id]
        end_time = time.time()
        # Drop diseases that fail the patient's sex/age restrictions.
        new_results = {}
        for item in results:
            if input.pat_sex and input.pat_sex.value is not None and self.check_sex_allowed(item, input.pat_sex.value) == False:
                continue
            if input.pat_age and input.pat_age.value is not None and self.check_age_allowed(item, input.pat_age.value) == False:
                continue
            new_results[item] = results[item]
        results = new_results
        print('STEP 1 '+str(len(results)))
        print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
        print(f"STEP 1 遍历图谱查找相关疾病 finished")
        return results
  486. def get_in_edges(self, node_id, allowed_links):
  487. results = []
  488. for edge in self.graph.in_edges(node_id, data=True):
  489. src, dest, edge_data = edge
  490. if edge_data['type'] in allowed_links:
  491. results.append(self.entity_data[self.entity_data.index == src])
  492. return results
  493. def step2(self, results,department_edge):
  494. """
  495. 查找疾病对应的科室、检查和药品信息
  496. :param results: 包含疾病信息的字典
  497. :return: 更新后的results字典
  498. """
  499. start_time = time.time()
  500. print("STEP 2 查找疾病对应的科室、检查和药品 start")
  501. for disease in results.keys():
  502. # cache_key = f"disease_department_{disease}"
  503. # cached_data = self.cache.get(cache_key)
  504. # if cached_data:
  505. # results[disease]["department"] = cached_data
  506. # continue
  507. if results[disease]["relevant"] == False:
  508. continue
  509. department_data = []
  510. out_edges = self.graph.out_edges(disease, data=True)
  511. for edge in out_edges:
  512. src, dest, edge_data = edge
  513. if edge_data["type"] not in department_edge:
  514. continue
  515. dest_data = self.entity_data[self.entity_data.index == dest]
  516. if dest_data.empty:
  517. continue
  518. department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
  519. department_data.extend([department_name] * results[disease]["count"])
  520. if department_data:
  521. results[disease]["department"] = department_data
  522. #self.cache.set(cache_key, department_data)
  523. print(f"STEP 2 finished")
  524. end_time = time.time()
  525. print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")
  526. # 输出日志
  527. log_data = ["|disease|count|department|check|drug|"]
  528. log_data.append("|--|--|--|--|--|")
  529. for item in results.keys():
  530. department_data = results[item].get("department", [])
  531. count_data = results[item].get("count")
  532. check_data = results[item].get("check", [])
  533. drug_data = results[item].get("drug", [])
  534. log_data.append(
  535. f"|{results[item].get("name", item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
  536. print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
  537. return results
    def step3(self, results):
        """
        Aggregate the disease results by department.

        :param results: disease dict (id -> info) from step2
        :return: dict department -> {diseases, checks, drugs, count, score}

        Diseases without an attached department fall back to a DB edge lookup
        and, failing that, to the 'DEFAULT' bucket.  ``score`` is the
        department's share of all (disease, department) pairs.
        """
        print(f"STEP 3 对于结果按照科室维度进行汇总 start")
        final_results = {}
        total = 0
        for disease in results.keys():
            disease = int(disease)
            # Some diseases carry no department — handled by the fallbacks below.
            departments = ['DEFAULT']
            if 'department' in results[disease].keys():
                departments = results[disease]["department"]
            else:
                edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='所属科室')
                # The edge lookup may be empty; keep the DEFAULT bucket then.
                if len(edges) > 0:
                    departments = [edge['dest_node']['name'] for edge in edges]
            # Accumulate per-department tallies.
            for department in departments:
                total += 1
                if not department in final_results.keys():
                    final_results[department] = {
                        "diseases": [str(disease)+":"+results[disease].get("name", disease)],
                        "checks": results[disease].get("check", []),
                        "drugs": results[disease].get("drug", []),
                        "count": 1
                    }
                else:
                    final_results[department]["diseases"] = final_results[department]["diseases"] + [str(disease)+":"+
                        results[disease].get("name", disease)]
                    final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get(
                        "check", [])
                    final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get(
                        "drug", [])
                    final_results[department]["count"] += 1
        # Department frequency distribution over all pairs.
        for department in final_results.keys():
            final_results[department]["score"] = final_results[department]["count"] / total
        print(f"STEP 3 finished")
        # Markdown summary log.
        log_data = ["|department|disease|check|drug|count|score"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases", [])
            check_data = final_results[department].get("checks", [])
            drug_data = final_results[department].get("drugs", [])
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(
                f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
        print("\n" + "\n".join(log_data))
        return final_results
    def step4(self, final_results):
        """
        Rank the items inside each department bucket, then sort the
        departments themselves by occurrence count (descending).

        :param final_results: dict department -> aggregated data from step3
        :return: list of (department, data) tuples sorted by count desc
        """
        print(f"STEP 4 start")
        start_time = time.time()

        def sort_data(data, count=10):
            # Frequency-count the items and keep the `count` most frequent,
            # each as (item, {"count": n}).
            tmp = {}
            for item in data:
                if item in tmp.keys():
                    tmp[item]["count"] += 1
                else:
                    tmp[item] = {"count": 1}
            sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
            return sorted_data[:count]

        for department in final_results.keys():
            final_results[department]['name'] = department
            final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
            # final_results[department]["checks"] = sort_data(final_results[department]["checks"])
            # final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
        # Sort departments by how often they occurred, descending.
        sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
        print(f"STEP 4 finished")
        end_time = time.time()
        print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")
        # Markdown summary log.
        log_data = ["|department|disease|check|drug|count|score"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases")
            check_data = final_results[department].get("checks")
            drug_data = final_results[department].get("drugs")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
        print("\n" + "\n".join(log_data))
        return sorted_final_results
    def step5(self, final_results, input, start_nodes, symptom_edge):
        """
        Collapse the per-department disease lists into one ranked diagnosis
        list, with a bonus for the patient's own department.

        :param final_results: department results (NOTE(review): iterated with
               ``.keys()``, so this expects the dict form from step3, not the
               sorted list step4 returns — confirm the caller)
        :param input: patient input; ``input.department`` grants a +1 bonus
        :param start_nodes: original symptom names, used for symptom matching
        :param symptom_edge: edge types meaning "has symptom"
        :return: (diagnosis list sorted by score desc, total diagnoses seen)
        """
        print(f"STEP 5 start")
        start_time = time.time()
        diags = {}
        total_diags = 0
        for department in final_results.keys():
            # The DEFAULT bucket gets a low fixed factor.
            department_factor = 0.1 if department == 'DEFAULT' else final_results[department]["score"]
            count = 0
            # Diseases in the patient's own department score slightly higher.
            if input.department.value == department:
                count = 1
            for disease, data in final_results[department]["diseases"]:
                total_diags += 1
                if disease in diags.keys():
                    diags[disease]["count"] += data["count"]+count
                    diags[disease]["score"] += (data["count"]+count)*0.1 * department_factor
                else:
                    diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor}
        # sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)[:10]
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
        diags = {}
        for item in sorted_score_diags:
            # Keys look like "<id>:<name>".
            disease_info = item[0].split(":");
            disease_id = disease_info[0]
            disease = disease_info[1]
            symptoms_data = self.get_symptoms_data(disease_id, symptom_edge)
            if symptoms_data is None:
                continue
            symptoms = []
            for symptom in symptoms_data:
                matched = False
                if symptom in start_nodes:
                    matched = True
                symptoms.append({"name": symptom, "matched": matched})
            # Matched symptoms first, unmatched after.
            symptoms = sorted(symptoms, key=lambda x: x["matched"], reverse=True)
            start_nodes_size = len(start_nodes)
            # if start_nodes_size > 1:
            #     start_nodes_size = start_nodes_size*0.5
            new_item = {"old_score": item[1]["score"], "count": item[1]["count"], "score": float(item[1]["count"])/start_nodes_size/2*0.1, "symptoms": symptoms}
            diags[disease] = new_item
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        print(f"STEP 5 finished")
        end_time = time.time()
        print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
        log_data = ["|department|disease|count|score"]
        log_data.append("|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{diesease_data}|{count_data}|{score_data}|")
        print("这里是经过排序的数据\n" + "\n".join(log_data))
        return sorted_score_diags, total_diags
  691. def get_symptoms_data(self, disease_id, symptom_edge):
  692. """
  693. 获取疾病相关的症状数据
  694. :param disease_id: 疾病节点ID
  695. :param symptom_edge: 症状关系类型列表
  696. :return: 症状数据列表
  697. """
  698. key = f'disease_{disease_id}_symptom'
  699. symptom_data = self.cache[key] if key in self.cache else None
  700. if symptom_data is None:
  701. out_edges = self.graph.out_edges(int(disease_id), data=True)
  702. symptom_data = []
  703. for edge in out_edges:
  704. src, dest, edge_data = edge
  705. if edge_data["type"] not in symptom_edge:
  706. continue
  707. dest_data = self.entity_data[self.entity_data.index == dest]
  708. if dest_data.empty:
  709. continue
  710. dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
  711. if dest_name not in symptom_data:
  712. symptom_data.append(dest_name)
  713. self.cache[key]=symptom_data
  714. return symptom_data