# cdss_helper2.py
import copy
from hmac import new  # NOTE(review): unused import — looks accidental; confirm before removing
import os
import sys
import logging
import json
import time
from service.kg_edge_service import KGEdgeService
from db.session import get_db
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from cachetools import TTLCache
from cachetools.keys import hashkey
current_path = os.getcwd()
sys.path.append(current_path)
from community.graph_helper import GraphHelper
from typing import List
from agent.cdss.models.schemas import CDSSInput
from config.site import SiteConfig
import networkx as nx
import pandas as pd
logger = logging.getLogger(__name__)
# NOTE(review): current_path / sys.path.append are repeated from above — harmless but redundant.
current_path = os.getcwd()
sys.path.append(current_path)
# Graph-data cache path (generated by dump_graph_data.py)
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
class CDSSHelper(GraphHelper):

    def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
        """Search KG nodes via the title index, keeping only hits present in the in-memory graph.

        :param node_id: query term forwarded to the title-index search
        :param node_type: node type filter forwarded to the search
        :param filters: currently unused — accepted for interface compatibility
        :param limit: max number of hits requested from the index
        :param min_degree: currently unused — accepted for interface compatibility
        :return: list of dicts with 'id', 'score' and 'name'
        """
        kg_node_service = KGNodeService(next(get_db()))
        es_result = kg_node_service.search_title_index("", node_id, node_type, limit)
        results = []
        for item in es_result:
            # Keep only hits whose id exists as a node in the loaded graph.
            n = self.graph.nodes.get(item["id"])
            score = item["score"]
            if n:
                results.append({
                    'id': item["title"],  # NOTE(review): 'id' is filled with the title, not item["id"] — confirm intentional
                    'score': score,
                    "name": item["title"],
                })
        return results
  42. def _load_entity_data(self):
  43. config = SiteConfig()
  44. # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
  45. print("load entity data")
  46. # 这里设置了读取的属性
  47. data = {"id": [], "name": [], "type": [],"is_symptom": [], "sex": [], "age": [],"score": []}
  48. if not os.path.exists(os.path.join(CACHED_DATA_PATH, 'entities_med.json')):
  49. return []
  50. with open(os.path.join(CACHED_DATA_PATH, 'entities_med.json'), "r", encoding="utf-8") as f:
  51. entities = json.load(f)
  52. for item in entities:
  53. #如果id已经存在,则跳过
  54. # if item[0] in data["id"]:
  55. # print(f"skip {item[0]}")
  56. # continue
  57. data["id"].append(int(item[0]))
  58. data["name"].append(item[1]["name"])
  59. data["type"].append(item[1]["type"])
  60. self._append_entity_attribute(data, item, "sex")
  61. self._append_entity_attribute(data, item, "age")
  62. self._append_entity_attribute(data, item, "is_symptom")
  63. self._append_entity_attribute(data, item, "score")
  64. # item[1]["id"] = item[0]
  65. # item[1]["name"] = item[0]
  66. # attrs = item[1]
  67. # self.graph.add_node(item[0], **attrs)
  68. self.entity_data = pd.DataFrame(data)
  69. self.entity_data.set_index("id", inplace=True)
  70. print("load entity data finished")
  71. def _append_entity_attribute(self, data, item, attr_name):
  72. if attr_name in item[1]:
  73. value = item[1][attr_name].split(":")
  74. if len(value) < 2:
  75. data[attr_name].append(value[0])
  76. else:
  77. data[attr_name].append(value[1])
  78. else:
  79. data[attr_name].append("")
  80. def _load_relation_data(self):
  81. config = SiteConfig()
  82. # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
  83. print("load relationship data")
  84. for i in range(62):
  85. if not os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  86. continue
  87. if os.path.exists(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json")):
  88. print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
  89. with open(os.path.join(CACHED_DATA_PATH, f"relationship_med_{i}.json"), "r", encoding="utf-8") as f:
  90. data = {"src": [], "dest": [], "type": [], "weight": []}
  91. relations = json.load(f)
  92. for item in relations:
  93. data["src"].append(int(item[0]))
  94. data["dest"].append(int(item[2]))
  95. data["type"].append(item[4]["type"])
  96. if "order" in item[4]:
  97. order = item[4]["order"].split(":")
  98. if len(order) < 2:
  99. data["weight"].append(int(order[0]))
  100. else:
  101. data["weight"].append(int(order[1]))
  102. else:
  103. data["weight"].append(1)
  104. self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
    def build_graph(self):
        """Build the in-memory knowledge graph.

        Loads cached entities and relations (plus local update files), then
        constructs a directed networkx graph whose edges carry all relation
        attributes and whose nodes carry the entity attributes.
        """
        # NOTE(review): these columns are placeholders — _load_entity_data replaces
        # self.entity_data wholesale with its own column set ('sex'/'age', not
        # 'allowed_age_range'); confirm which schema downstream code expects.
        self.entity_data = pd.DataFrame(
            {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
        self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
        self._load_entity_data()
        self._load_relation_data()
        self._load_local_data()
        self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
                                             create_using=nx.DiGraph())
        nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
        # print(self.graph.in_edges('1257357',data=True))
  116. def _load_local_data(self):
  117. # 这里加载update数据和权重数据
  118. config = SiteConfig()
  119. self.update_data_path = config.get_config('UPDATE_DATA_PATH')
  120. self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
  121. print(f"load update data from {self.update_data_path}")
  122. for root, dirs, files in os.walk(self.update_data_path):
  123. for file in files:
  124. file_path = os.path.join(root, file)
  125. if file_path.endswith(".json") and file.startswith("ent"):
  126. self._load_update_entity_json(file_path)
  127. if file_path.endswith(".json") and file.startswith("rel"):
  128. self._load_update_relationship_json(file_path)
    def _load_update_entity_json(self, file):
        '''Apply one entity-update JSON file to self.entity_data.

        Records have the cached entity shape [id, {attribute: value, ...}].
        Only ids already present in self.entity_data are updated; attributes
        missing from a record keep their current value.
        '''
        print(f"load entity update data from {file}")
        with open(file, "r", encoding="utf-8") as f:
            entities = json.load(f)
        for item in entities:
            # NOTE(review): the index holds ints while item[0] may be a string here —
            # a string id would never match; confirm the update files' id type.
            original_data = self.entity_data[self.entity_data.index == item[0]]
            if original_data.empty:
                continue
            original_data = original_data.iloc[0]
            id = int(item[0])
            name = item[1]["name"] if "name" in item[1] else original_data['name']
            type = item[1]["type"] if "type" in item[1] else original_data['type']
            # NOTE(review): _load_entity_data builds 'sex'/'age' columns, not
            # 'allowed_sex_list'/'allowed_age_range' — these fallbacks may KeyError;
            # verify against the actual cached columns.
            allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
                'allowed_sex_list']
            allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
                'allowed_age_range']
            self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
                                                                                                   allowed_sex_list,
                                                                                                   allowed_age_range]
  150. def _load_update_relationship_json(self, file):
  151. '''load json data from file'''
  152. print(f"load relationship update data from {file}")
  153. with open(file, "r", encoding="utf-8") as f:
  154. relations = json.load(f)
  155. for item in relations:
  156. data = {}
  157. original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
  158. (self.relation_data['dest'] == data['dest']) &
  159. (self.relation_data['type'] == data['type'])]
  160. if original_data.empty:
  161. continue
  162. original_data = original_data.iloc[0]
  163. data["src"] = int(item[0])
  164. data["dest"] = int(item[2])
  165. data["type"] = item[4]["type"]
  166. data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
  167. self.relation_data.loc[(self.relation_data['src'] == data['src']) &
  168. (self.relation_data['dest'] == data['dest']) &
  169. (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
  170. def check_sex_allowed(self, node, sex):
  171. # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
  172. sex_allowed = self.graph.nodes[node].get('sex', None)
  173. #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
  174. #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
  175. #'prop_value'] != input.pat_sex.value:
  176. #continue
  177. if sex_allowed:
  178. if len(sex_allowed) == 0:
  179. # 如果性别列表为空,那么默认允许所有性别
  180. return True
  181. sex_allowed_list = sex_allowed.split(',')
  182. if sex not in sex_allowed_list:
  183. # 如果性别不匹配,跳过
  184. return False
  185. return True
  186. def check_age_allowed(self, node, age):
  187. # 年龄过滤,假设疾病节点有一个属性叫做allowed_age_range,值为“6-88”,代表年龄在0-88月之间是允许的
  188. # 如果说年龄小于6岁,那么我们就认为是儿童,所以儿童的年龄范围是0-6月
  189. age_allowed = self.graph.nodes[node].get('age', None)
  190. if age_allowed:
  191. if len(age_allowed) == 0:
  192. # 如果年龄范围为空,那么默认允许所有年龄
  193. return True
  194. age_allowed_list = age_allowed.split('-')
  195. age_min = int(age_allowed_list[0])
  196. age_max = int(age_allowed_list[-1])
  197. if age_max ==0:
  198. return True
  199. if age >= age_min and age < age_max:
  200. # 如果年龄范围正常,那么返回True
  201. return True
  202. else:
  203. # 如果没有设置年龄范围,那么默认返回True
  204. return True
  205. return False
  206. def check_diease_allowed(self, node):
  207. is_symptom = self.graph.nodes[node].get('is_symptom', None)
  208. if is_symptom == "是":
  209. return False
  210. return True
    # Shared KG property-service handle (class-level; created once at class-definition time).
    propService = KGPropService(next(get_db()))
    # Class-level TTL cache for per-symptom traversal results (30-day TTL).
    cache = TTLCache(maxsize=100000, ttl=60*60*24*30)
    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
        """
        Traverse the knowledge graph from the given symptom nodes and collect
        related diseases, departments, checks and drugs.

        Parameters:
            input: CDSSInput with the patient's basic info (age, sex, ...)
            start_nodes: list of symptom node NAMES used as traversal start points
            max_hops: maximum traversal depth (default 3)

        Returns a dict with:
            - details: top-5 per-department aggregation
            - score_diags: diseases sorted by relevance score
            - total_diags: total number of disease entries
            - departments: department name/count summary

        Main steps:
            1. initialise allowed node and edge types
            2. map symptom names to node ids
            3. traverse the graph for related diseases (STEP 1)
            4. find each disease's department/check/drug (STEP 2)
            5. aggregate per department (STEP 3)
            6. sort and summarise (STEP 4-6)
        """
        # Allowed node types (department/disease/drug/check/symptom), used for
        # node filtering and path finding below.
        DEPARTMENT = ['科室', 'Department']
        DIESEASE = ['疾病', 'Disease']
        DRUG = ['药品', 'Drug']
        CHECK = ['检查', 'Check']
        SYMPTOM = ['症状', 'Symptom']
        #allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
        # Allowed edge types (has_symptom / belongs_to and their Chinese labels).
        allowed_links = ['has_symptom','疾病相关症状','belongs_to','所属科室']
        # Map input symptom names to node ids. Names are not unique, so the
        # number of ids may exceed the number of input symptoms.
        node_ids = []
        node_id_names = {}
        # Deduplicate repeated symptoms in start_nodes.
        start_nodes = list(set(start_nodes))
        for node in start_nodes:
            #print(f"searching for node {node}")
            result = self.entity_data[self.entity_data['name'] == node]
            #print(f"searching for node {result}")
            for index, data in result.iterrows():
                if data["type"] in SYMPTOM or data["type"] in DIESEASE:
                    node_id_names[index] = data["name"]
                    node_ids = node_ids + [index]
        #print(f"start travel from {node_id_names}")
        # Keep only ids that actually exist in the graph.
        node_ids_filtered = []
        for node in node_ids:
            if self.graph.has_node(node):
                node_ids_filtered.append(node)
            else:
                logger.debug(f"node {node} not found")
        node_ids = node_ids_filtered
        results = self.step1(node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE)
        #self.validDisease(results, start_nodes)
        results = self.validDisease(results, start_nodes)
        # STEP 2: attach department/check/drug info to each disease.
        results = self.step2(results)
        # STEP 3: aggregate the results per department.
        final_results = self.step3(results)
        sorted_final_results = self.step4(final_results)
        sorted_final_results = sorted_final_results[:5]
        departments = []
        for temp in sorted_final_results:
            departments.append({"name": temp[0], "count": temp[1]["count"]})
        # STEP 5: count global occurrences of diseases/checks/drugs in
        # final_results and sort them in descending order.
        sorted_score_diags,total_diags = self.step5(final_results, input,start_nodes)
        # STEP 6: assemble the payload and return.
        # (commented-out merge logic kept for reference)
        # if "department" in item.keys():
        #     final_results["department"] = list(set(final_results["department"]+item["department"]))
        # if "diseases" in item.keys():
        #     final_results["diseases"] = list(set(final_results["diseases"]+item["diseases"]))
        # if "checks" in item.keys():
        #     final_results["checks"] = list(set(final_results["checks"]+item["checks"]))
        # if "drugs" in item.keys():
        #     final_results["drugs"] = list(set(final_results["drugs"]+item["drugs"]))
        # if "symptoms" in item.keys():
        #     final_results["symptoms"] = list(set(final_results["symptoms"]+item["symptoms"]))
        return {"details": sorted_final_results,
                "score_diags": sorted_score_diags,"total_diags": total_diags,
                "departments":departments
                # "checks":sorted_checks, "drugs":sorted_drugs,
                # "total_checks":total_check, "total_drugs":total_drug
                }
  303. def validDisease(self, results, start_nodes):
  304. """
  305. 输出有效的疾病信息为Markdown格式
  306. :param results: 疾病结果字典
  307. :param start_nodes: 起始症状节点列表
  308. :return: 格式化后的Markdown字符串
  309. """
  310. log_data = ["|疾病|症状|出现次数|是否相关"]
  311. log_data.append("|--|--|--|--|")
  312. filtered_results = {}
  313. for item in results:
  314. data = results[item]
  315. data['relevant'] = False
  316. if data["increase"] / len(start_nodes) > 0.5:
  317. #cache_key = f'disease_name_ref_id_{data['name']}'
  318. data['relevant'] = True
  319. filtered_results[item] = data
  320. # 初始化疾病的父类疾病
  321. # disease_name = data["name"]
  322. # key = 'disease_name_parent_' +disease_name
  323. # cached_value = self.cache.get(key)
  324. # if cached_value is None:
  325. # out_edges = self.graph.out_edges(item, data=True)
  326. #
  327. # for edge in out_edges:
  328. # src, dest, edge_data = edge
  329. # if edge_data["type"] != '疾病相关父类':
  330. # continue
  331. # dest_data = self.entity_data[self.entity_data.index == dest]
  332. # if dest_data.empty:
  333. # continue
  334. # dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
  335. # self.cache.set(key, dest_name)
  336. # break
  337. if data['relevant'] == False:
  338. continue
  339. log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
  340. content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
  341. print(f"\n{content}")
  342. return filtered_results
  343. def step1(self, node_ids,node_id_names, input, allowed_types, allowed_links,max_hops,DIESEASE):
  344. """
  345. 根据症状节点查找相关疾病
  346. :param node_ids: 症状节点ID列表
  347. :param input: 患者信息输入
  348. :param allowed_types: 允许的节点类型
  349. :param allowed_links: 允许的关系类型
  350. :return: 过滤后的疾病结果
  351. """
  352. start_time = time.time()
  353. results = {}
  354. for node in node_ids:
  355. visited = set()
  356. temp_results = {}
  357. cache_key = f"symptom_ref_disease_{str(node)}"
  358. cache_data = self.cache[cache_key] if cache_key in self.cache else None
  359. if cache_data:
  360. temp_results = copy.deepcopy(cache_data)
  361. print(cache_key+":"+node_id_names[node] +':'+ str(len(temp_results)))
  362. if results=={}:
  363. results = temp_results
  364. else:
  365. for disease_id in temp_results:
  366. path = temp_results[disease_id]["path"][0]
  367. if disease_id in results.keys():
  368. results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
  369. results[disease_id]["increase"] = results[disease_id]["increase"]+1
  370. results[disease_id]["path"].append(path)
  371. else:
  372. results[disease_id] = temp_results[disease_id]
  373. continue
  374. queue = [(node, 0, node_id_names[node],10, {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
  375. # 整理input的数据,这里主要是要检查输入数据是否正确,也需要做转换
  376. if input.pat_age and input.pat_age.value is not None and input.pat_age.value > 0 and input.pat_age.type == 'year':
  377. # 这里将年龄从年转换为月,因为我们的图里面的年龄都是以月为单位的
  378. input.pat_age.value = input.pat_age.value * 12
  379. input.pat_age.type = 'month'
  380. # STEP 1: 假设start_nodes里面都是症状,第一步我们先找到这些症状对应的疾病
  381. # TODO 由于这部分是按照症状逐一去寻找疾病,所以实际应用中可以缓存这些结果
  382. while queue:
  383. temp_node, depth, path, weight, data = queue.pop(0)
  384. temp_node = int(temp_node)
  385. # 这里是通过id去获取节点的name和type
  386. entity_data = self.entity_data[self.entity_data.index == temp_node]
  387. # 如果节点不存在,那么跳过
  388. if entity_data.empty:
  389. continue
  390. if self.graph.nodes.get(temp_node) is None:
  391. continue
  392. node_type = self.entity_data[self.entity_data.index == temp_node]['type'].tolist()[0]
  393. node_name = self.entity_data[self.entity_data.index == temp_node]['name'].tolist()[0]
  394. # print(f"node {node} type {node_type}")
  395. if node_type in DIESEASE:
  396. # print(f"node {node} type {node_type} is a disease")
  397. count = weight
  398. if self.check_diease_allowed(temp_node) == False:
  399. continue
  400. if temp_node in temp_results.keys():
  401. temp_results[temp_node]["count"] = temp_results[temp_node]["count"] + count
  402. results[disease_id]["increase"] = results[disease_id]["increase"] + 1
  403. temp_results[temp_node]["path"].append(path)
  404. else:
  405. temp_results[temp_node] = {"type": node_type, "count": count, "increase": 1, "name": node_name, 'path': [path]}
  406. continue
  407. if temp_node in visited or depth > max_hops:
  408. # print(f"{node} already visited or reach max hops")
  409. continue
  410. visited.add(temp_node)
  411. # print(f"check edges from {node}")
  412. if temp_node not in self.graph:
  413. # print(f"node {node} not found in graph")
  414. continue
  415. # todo 目前是取入边,出边是不是也有用?
  416. for edge in self.graph.in_edges(temp_node, data=True):
  417. src, dest, edge_data = edge
  418. if src not in visited and depth + 1 < max_hops:
  419. # print(f"put into queue travel from {src} to {dest}")
  420. weight = edge_data['weight']
  421. try :
  422. if weight:
  423. if weight < 10:
  424. weight = 10-weight
  425. else:
  426. weight = 1
  427. else:
  428. weight = 5
  429. if weight>10:
  430. weight = 10
  431. except Exception as e:
  432. print(f'Error processing file {weight}: {str(e)}')
  433. queue.append((src, depth + 1, path, int(weight), data))
  434. # else:
  435. # print(f"skip travel from {src} to {dest}")
  436. print(cache_key+":"+node_id_names[node]+':'+ str(len(temp_results)))
  437. #对temp_results进行深拷贝,然后再进行处理
  438. self.cache[cache_key] = copy.deepcopy(temp_results)
  439. if results == {}:
  440. results = temp_results
  441. else:
  442. for disease_id in temp_results:
  443. path = temp_results[disease_id]["path"][0]
  444. if disease_id in results.keys():
  445. results[disease_id]["count"] = results[disease_id]["count"] + temp_results[disease_id]["count"]
  446. results[disease_id]["increase"] = results[disease_id]["increase"] + 1
  447. results[disease_id]["path"].append(path)
  448. else:
  449. results[disease_id] = temp_results[disease_id]
  450. end_time = time.time()
  451. # 这里我们需要对结果进行过滤,过滤掉不满足条件的疾病
  452. new_results = {}
  453. for item in results:
  454. if input.pat_sex and input.pat_sex.value is not None and self.check_sex_allowed(item, input.pat_sex.value) == False:
  455. continue
  456. if input.pat_age and input.pat_age.value is not None and self.check_age_allowed(item, input.pat_age.value) == False:
  457. continue
  458. new_results[item] = results[item]
  459. results = new_results
  460. print('STEP 1 '+str(len(results)))
  461. print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
  462. print(f"STEP 1 遍历图谱查找相关疾病 finished")
  463. return results
  464. def step2(self, results):
  465. """
  466. 查找疾病对应的科室、检查和药品信息
  467. :param results: 包含疾病信息的字典
  468. :return: 更新后的results字典
  469. """
  470. start_time = time.time()
  471. print("STEP 2 查找疾病对应的科室、检查和药品 start")
  472. for disease in results.keys():
  473. # cache_key = f"disease_department_{disease}"
  474. # cached_data = self.cache.get(cache_key)
  475. # if cached_data:
  476. # results[disease]["department"] = cached_data
  477. # continue
  478. if results[disease]["relevant"] == False:
  479. continue
  480. department_data = []
  481. out_edges = self.graph.out_edges(disease, data=True)
  482. for edge in out_edges:
  483. src, dest, edge_data = edge
  484. if edge_data["type"] != 'belongs_to' and edge_data["type"] != '所属科室':
  485. continue
  486. dest_data = self.entity_data[self.entity_data.index == dest]
  487. if dest_data.empty:
  488. continue
  489. department_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
  490. department_data.extend([department_name] * results[disease]["count"])
  491. if department_data:
  492. results[disease]["department"] = department_data
  493. #self.cache.set(cache_key, department_data)
  494. print(f"STEP 2 finished")
  495. end_time = time.time()
  496. print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")
  497. # 输出日志
  498. log_data = ["|disease|count|department|check|drug|"]
  499. log_data.append("|--|--|--|--|--|")
  500. for item in results.keys():
  501. department_data = results[item].get("department", [])
  502. count_data = results[item].get("count")
  503. check_data = results[item].get("check", [])
  504. drug_data = results[item].get("drug", [])
  505. log_data.append(
  506. f"|{results[item].get("name", item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
  507. print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
  508. return results
  509. def step3(self, results):
  510. print(f"STEP 3 对于结果按照科室维度进行汇总 start")
  511. final_results = {}
  512. total = 0
  513. for disease in results.keys():
  514. disease = int(disease)
  515. # 由于存在有些疾病没有科室的情况,所以这里需要做一下处理
  516. departments = ['DEFAULT']
  517. if 'department' in results[disease].keys():
  518. departments = results[disease]["department"]
  519. else:
  520. edges = KGEdgeService(next(get_db())).get_edges_by_nodes(src_id=disease, category='所属科室')
  521. #edges有可能为空,这里需要做一下处理
  522. if len(edges) == 0:
  523. continue
  524. departments = [edge['dest_node']['name'] for edge in edges]
  525. # 处理查询结果
  526. for department in departments:
  527. total += 1
  528. if not department in final_results.keys():
  529. final_results[department] = {
  530. "diseases": [str(disease)+":"+results[disease].get("name", disease)],
  531. "checks": results[disease].get("check", []),
  532. "drugs": results[disease].get("drug", []),
  533. "count": 1
  534. }
  535. else:
  536. final_results[department]["diseases"] = final_results[department]["diseases"] + [str(disease)+":"+
  537. results[disease].get("name", disease)]
  538. final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get(
  539. "check", [])
  540. final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get(
  541. "drug", [])
  542. final_results[department]["count"] += 1
  543. # 这里是统计科室出现的分布
  544. for department in final_results.keys():
  545. final_results[department]["score"] = final_results[department]["count"] / total
  546. print(f"STEP 3 finished")
  547. # 这里输出日志
  548. log_data = ["|department|disease|check|drug|count|score"]
  549. log_data.append("|--|--|--|--|--|--|")
  550. for department in final_results.keys():
  551. diesease_data = final_results[department].get("diseases", [])
  552. check_data = final_results[department].get("checks", [])
  553. drug_data = final_results[department].get("drugs", [])
  554. count_data = final_results[department].get("count", 0)
  555. score_data = final_results[department].get("score", 0)
  556. log_data.append(
  557. f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
  558. print("\n" + "\n".join(log_data))
  559. return final_results
  560. def step4(self, final_results):
  561. """
  562. 对final_results中的疾病、检查和药品进行统计和排序
  563. 参数:
  564. final_results: 包含科室、疾病、检查和药品的字典
  565. 返回值:
  566. 排序后的final_results
  567. """
  568. print(f"STEP 4 start")
  569. start_time = time.time()
  570. def sort_data(data, count=10):
  571. tmp = {}
  572. for item in data:
  573. if item in tmp.keys():
  574. tmp[item]["count"] += 1
  575. else:
  576. tmp[item] = {"count": 1}
  577. sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
  578. return sorted_data[:count]
  579. for department in final_results.keys():
  580. final_results[department]['name'] = department
  581. final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
  582. #final_results[department]["checks"] = sort_data(final_results[department]["checks"])
  583. #final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
  584. # 这里把科室做一个排序,按照出现的次数降序排序
  585. sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
  586. print(f"STEP 4 finished")
  587. end_time = time.time()
  588. print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")
  589. # 这里输出markdown日志
  590. log_data = ["|department|disease|check|drug|count|score"]
  591. log_data.append("|--|--|--|--|--|--|")
  592. for department in final_results.keys():
  593. diesease_data = final_results[department].get("diseases")
  594. check_data = final_results[department].get("checks")
  595. drug_data = final_results[department].get("drugs")
  596. count_data = final_results[department].get("count", 0)
  597. score_data = final_results[department].get("score", 0)
  598. log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
  599. print("\n" + "\n".join(log_data))
  600. return sorted_final_results
  601. def step5(self, final_results, input,start_nodes):
  602. """
  603. 按科室汇总结果并排序
  604. 参数:
  605. final_results: 各科室的初步结果
  606. input: 患者输入信息
  607. 返回值:
  608. 返回排序后的诊断结果
  609. """
  610. print(f"STEP 5 start")
  611. start_time = time.time()
  612. diags = {}
  613. total_diags = 0
  614. for department in final_results.keys():
  615. department_factor = 0.1 if department == 'DEFAULT' else final_results[department]["score"]
  616. count = 0
  617. #当前科室权重增加0.1
  618. if input.department.value == department:
  619. count = 1
  620. for disease, data in final_results[department]["diseases"]:
  621. total_diags += 1
  622. #key = 'disease_name_parent_' + disease
  623. # cached_data = self.cache.get(key)
  624. # if cached_data:
  625. # disease = cached_data
  626. disease = disease.split(":")[1]
  627. if disease in diags.keys():
  628. diags[disease]["count"] += data["count"]+count
  629. diags[disease]["score"] += (data["count"]+count)*0.1 * department_factor
  630. else:
  631. diags[disease] = {"count": data["count"]+count, "score": (data["count"]+count)*0.1 * department_factor}
  632. #sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)[:10]
  633. sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["count"], reverse=True)[:10]
  634. diags = {}
  635. for item in sorted_score_diags:
  636. new_item = {"old_score": item[1]["score"],"count": item[1]["count"], "score": float(item[1]["count"])*0.1}
  637. diags[item[0]] = new_item
  638. sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
  639. #循环sorted_score_diags,把disease现在是id:name替换name
  640. # diags = {}
  641. # for item in sorted_score_diags:
  642. # diseaseinfos = item[0].split(":")
  643. # diseaseId = diseaseinfos[0]
  644. # diseaseName = diseaseinfos[1]
  645. # diseaseScore = self.graph.nodes[int(diseaseId)].get('score', None)
  646. # if diseaseScore:
  647. # try:
  648. # # 创建新字典替换原元组
  649. # new_item = {"count": item[1]["count"], "score": float(diseaseScore)*0.1}
  650. # diags[diseaseName] = new_item
  651. # except (ValueError, TypeError):
  652. # # 如果转换失败,直接使用原元组
  653. # diags[diseaseName] = item[1]
  654. # else:
  655. # # 如果没有找到对应的score,直接使用原元组
  656. # diags[diseaseName] = item[1]
  657. # sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
  658. print(f"STEP 5 finished")
  659. end_time = time.time()
  660. print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
  661. log_data = ["|department|disease|count|score"]
  662. log_data.append("|--|--|--|--|")
  663. for department in final_results.keys():
  664. diesease_data = final_results[department].get("diseases")
  665. count_data = final_results[department].get("count", 0)
  666. score_data = final_results[department].get("score", 0)
  667. log_data.append(f"|{department}|{diesease_data}|{count_data}|{score_data}|")
  668. print("这里是经过排序的数据\n" + "\n".join(log_data))
  669. return sorted_score_diags, total_diags