cdss_helper.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
# --- Module prelude: imports, sys.path setup, shared constants ---
from hmac import new  # NOTE(review): unused import — candidate for removal
import os
import sys
import logging
import json
import time
from service.kg_edge_service import KGEdgeService
from db.session import get_db
from service.kg_node_service import KGNodeService
from service.kg_prop_service import KGPropService
from utils.cache import Cache

# Make the project root importable before pulling in local packages.
current_path = os.getcwd()
sys.path.append(current_path)
from community.graph_helper import GraphHelper
from typing import List
from agent.cdss.models.schemas import CDSSInput
from config.site import SiteConfig
import networkx as nx
import pandas as pd

logger = logging.getLogger(__name__)
# NOTE(review): duplicated path setup — current_path / sys.path.append
# were already executed above.
current_path = os.getcwd()
sys.path.append(current_path)
# Graph data cache path (generated by dump_graph_data.py).
CACHED_DATA_PATH = os.path.join(current_path, 'community', 'web', 'cached_data')
  25. class CDSSHelper(GraphHelper):
  26. def node_search(self, node_id=None, node_type=None, filters=None, limit=1, min_degree=None):
  27. kg_node_service = KGNodeService(next(get_db()))
  28. es_result = kg_node_service.search_title_index("graph_entity_index", node_id, limit)
  29. results = []
  30. for item in es_result:
  31. n = self.graph.nodes.get(item["id"])
  32. score = item["score"]
  33. if n:
  34. results.append({
  35. 'id': item["title"],
  36. 'score': score,
  37. "name": item["title"],
  38. })
  39. return results
  40. def _load_entity_data(self):
  41. config = SiteConfig()
  42. # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
  43. print("load entity data")
  44. # 这里设置了读取的属性
  45. data = {"id": [], "name": [], "type": [],"is_symptom": [], "sex": [], "age": []}
  46. with open(f"{CACHED_DATA_PATH}\\entities_med.json", "r", encoding="utf-8") as f:
  47. entities = json.load(f)
  48. for item in entities:
  49. #如果id已经存在,则跳过
  50. # if item[0] in data["id"]:
  51. # print(f"skip {item[0]}")
  52. # continue
  53. data["id"].append(int(item[0]))
  54. data["name"].append(item[1]["name"])
  55. data["type"].append(item[1]["type"])
  56. self._append_entity_attribute(data, item, "sex")
  57. self._append_entity_attribute(data, item, "age")
  58. self._append_entity_attribute(data, item, "is_symptom")
  59. # item[1]["id"] = item[0]
  60. # item[1]["name"] = item[0]
  61. # attrs = item[1]
  62. # self.graph.add_node(item[0], **attrs)
  63. self.entity_data = pd.DataFrame(data)
  64. self.entity_data.set_index("id", inplace=True)
  65. print("load entity data finished")
  66. def _append_entity_attribute(self, data, item, attr_name):
  67. if attr_name in item[1]:
  68. value = item[1][attr_name].split(":")
  69. if len(value) < 2:
  70. data[attr_name].append(value[0])
  71. else:
  72. data[attr_name].append(value[1])
  73. else:
  74. data[attr_name].append("")
  75. def _load_relation_data(self):
  76. config = SiteConfig()
  77. # CACHED_DATA_PATH = config.get_config("CACHED_DATA_PATH")
  78. print("load relationship data")
  79. for i in range(46):
  80. if os.path.exists(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json"):
  81. print(f"load entity data {CACHED_DATA_PATH}\\relationship_med_{i}.json")
  82. with open(f"{CACHED_DATA_PATH}\\relationship_med_{i}.json", "r", encoding="utf-8") as f:
  83. data = {"src": [], "dest": [], "type": [], "weight": []}
  84. relations = json.load(f)
  85. for item in relations:
  86. data["src"].append(int(item[0]))
  87. data["dest"].append(int(item[2]))
  88. data["type"].append(item[4]["type"])
  89. if "order" in item[4]:
  90. order = item[4]["order"].split(":")
  91. if len(order) < 2:
  92. data["weight"].append(order[0])
  93. else:
  94. data["weight"].append(order[1])
  95. else:
  96. data["weight"].append(1)
  97. self.relation_data = pd.concat([self.relation_data, pd.DataFrame(data)], ignore_index=True)
  98. def build_graph(self):
  99. self.entity_data = pd.DataFrame(
  100. {"id": [], "name": [], "type": [], "sex": [], "allowed_age_range": []})
  101. self.relation_data = pd.DataFrame({"src": [], "dest": [], "type": [], "weight": []})
  102. self._load_entity_data()
  103. self._load_relation_data()
  104. self._load_local_data()
  105. self.graph = nx.from_pandas_edgelist(self.relation_data, "src", "dest", edge_attr=True,
  106. create_using=nx.DiGraph())
  107. nx.set_node_attributes(self.graph, self.entity_data.to_dict(orient="index"))
  108. # print(self.graph.in_edges('1257357',data=True))
  109. def _load_local_data(self):
  110. # 这里加载update数据和权重数据
  111. config = SiteConfig()
  112. self.update_data_path = config.get_config('UPDATE_DATA_PATH')
  113. self.factor_data_path = config.get_config('FACTOR_DATA_PATH')
  114. print(f"load update data from {self.update_data_path}")
  115. for root, dirs, files in os.walk(self.update_data_path):
  116. for file in files:
  117. file_path = os.path.join(root, file)
  118. if file_path.endswith(".json") and file.startswith("ent"):
  119. self._load_update_entity_json(file_path)
  120. if file_path.endswith(".json") and file.startswith("rel"):
  121. self._load_update_relationship_json(file_path)
  122. def _load_update_entity_json(self, file):
  123. '''load json data from file'''
  124. print(f"load entity update data from {file}")
  125. # 这里加载update数据,update数据是一个json文件,格式同cached data如下:
  126. with open(file, "r", encoding="utf-8") as f:
  127. entities = json.load(f)
  128. for item in entities:
  129. original_data = self.entity_data[self.entity_data.index == item[0]]
  130. if original_data.empty:
  131. continue
  132. original_data = original_data.iloc[0]
  133. id = int(item[0])
  134. name = item[1]["name"] if "name" in item[1] else original_data['name']
  135. type = item[1]["type"] if "type" in item[1] else original_data['type']
  136. allowed_sex_list = item[1]["allowed_sex_list"] if "allowed_sex_list" in item[1] else original_data[
  137. 'allowed_sex_list']
  138. allowed_age_range = item[1]["allowed_age_range"] if "allowed_age_range" in item[1] else original_data[
  139. 'allowed_age_range']
  140. self.entity_data.loc[id, ["name", "type", "allowed_sex_list", "allowed_age_range"]] = [name, type,
  141. allowed_sex_list,
  142. allowed_age_range]
  143. def _load_update_relationship_json(self, file):
  144. '''load json data from file'''
  145. print(f"load relationship update data from {file}")
  146. with open(file, "r", encoding="utf-8") as f:
  147. relations = json.load(f)
  148. for item in relations:
  149. data = {}
  150. original_data = self.relation_data[(self.relation_data['src'] == data['src']) &
  151. (self.relation_data['dest'] == data['dest']) &
  152. (self.relation_data['type'] == data['type'])]
  153. if original_data.empty:
  154. continue
  155. original_data = original_data.iloc[0]
  156. data["src"] = int(item[0])
  157. data["dest"] = int(item[2])
  158. data["type"] = item[4]["type"]
  159. data["weight"] = item[4]["weight"] if "weight" in item[4] else original_data['weight']
  160. self.relation_data.loc[(self.relation_data['src'] == data['src']) &
  161. (self.relation_data['dest'] == data['dest']) &
  162. (self.relation_data['type'] == data['type']), 'weight'] = data["weight"]
  163. def check_sex_allowed(self, node, sex):
  164. # 性别过滤,假设疾病节点有一个属性叫做allowed_sex_type,值为“0,1,2”,分别代表未知,男,女
  165. sex_allowed = self.graph.nodes[node].get('sex', None)
  166. #sexProps = self.propService.get_props_by_ref_id(node, 'sex')
  167. #if len(sexProps) > 0 and sexProps[0]['prop_value'] is not None and sexProps[0][
  168. #'prop_value'] != input.pat_sex.value:
  169. #continue
  170. if sex_allowed:
  171. if len(sex_allowed) == 0:
  172. # 如果性别列表为空,那么默认允许所有性别
  173. return True
  174. sex_allowed_list = sex_allowed.split(',')
  175. if sex not in sex_allowed_list:
  176. # 如果性别不匹配,跳过
  177. return False
  178. return True
    def check_age_allowed(self, node, age):
        # Age filter: a disease node may carry an 'age' attribute such as
        # "6-88", meaning ages 6..88 (in months) are allowed.
        # Children under 6 years are modelled with month ranges like 0-6.
        age_allowed = self.graph.nodes[node].get('age', None)
        if age_allowed:
            if len(age_allowed) == 0:
                # Empty range string: allow all ages.
                # NOTE(review): unreachable — an empty string is falsy and
                # never enters this branch.
                return True
            age_allowed_list = age_allowed.split('-')
            age_min = int(age_allowed_list[0])
            age_max = int(age_allowed_list[-1])
            if age_max == 0:
                # Upper bound 0 means "no effective limit": allow.
                return True
            if age >= age_min and age < age_max:
                # Age falls inside the configured range.
                return True
        else:
            # No age range configured: allow by default.
            return True
        # Configured range exists and the age falls outside it.
        return False
  199. def check_diease_allowed(self, node):
  200. is_symptom = self.graph.nodes[node].get('is_symptom', None)
  201. if is_symptom == "是":
  202. return False
  203. return True
    # Class-level shared instances, created once at class-definition time.
    # NOTE(review): the DB session behind propService and the cache are
    # shared by every CDSSHelper instance — confirm that is intended.
    propService = KGPropService(next(get_db()))
    cache = Cache()
    def cdss_travel(self, input: CDSSInput, start_nodes: List, max_hops=3):
        """Traverse the knowledge graph from the given symptom nodes and
        collect related diseases, departments (and, when enabled, checks
        and drugs).

        Args:
            input: CDSSInput with the patient context (age, sex, department).
            start_nodes: list of symptom names used as traversal roots.
            max_hops: maximum traversal depth (default 3).

        Returns:
            dict with:
            - details: results aggregated and sorted per department
            - score_diags: diseases sorted by relevance score
            - total_diags: number of disease entries scored
            (check/drug aggregation is currently commented out)

        Steps:
            1. resolve symptom names to graph node ids
            2. STEP 1: walk the graph (in-edges) to find candidate diseases
            3. STEP 2: find each relevant disease's department(s)
            4. STEP 3: aggregate the results per department
            5. STEP 4-6: sort, score and assemble the response
        """
        start_time = time.time()
        # Node-type vocabularies (Chinese + English labels) used for
        # filtering and expansion.
        DEPARTMENT = ['科室', 'Department']
        DIESEASE = ['疾病', 'Disease']
        DRUG = ['药品', 'Drug']
        CHECK = ['检查', 'Check']
        SYMPTOM = ['症状', 'Symptom']
        # allowed_types = DEPARTMENT + DIESEASE + DRUG + CHECK + SYMPTOM
        allowed_types = DEPARTMENT + DIESEASE + SYMPTOM
        # Relation types considered during traversal.
        allowed_links = ['has_symptom', 'need_check', 'recommend_drug', 'belongs_to']
        # Map symptom names to node ids; several nodes may share a name, so
        # there can be more ids than input names.
        node_ids = []
        node_id_names = {}
        # De-duplicate repeated symptoms in start_nodes.
        start_nodes = list(set(start_nodes))
        for node in start_nodes:
            result = self.entity_data[self.entity_data['name'] == node]
            for index, data in result.iterrows():
                node_id_names[index] = data["name"]
                node_ids = node_ids + [index]
        # Keep only ids that actually exist in the graph.
        node_ids_filtered = []
        for node in node_ids:
            if self.graph.has_node(node):
                node_ids_filtered.append(node)
            else:
                logger.debug(f"node {node} not found")
        node_ids = node_ids_filtered
        end_time = time.time()
        print(f"init执行完成,耗时:{end_time - start_time:.2f}秒")
        start_time = time.time()
        results = {}
        for node in node_ids:
            visited = set()
            temp_results = {}
            # Per-symptom traversal results are cached under symptom_<id>.
            cache_key = f"symptom_{node}"
            cache_data = self.cache.get(cache_key)
            if cache_data:
                logger.debug(f"cache hit for {cache_key}")
                temp_results = cache_data
                if results=={}:
                    results = temp_results
                else:
                    for disease_id in temp_results:
                        path = temp_results[disease_id]["path"][0]
                        if disease_id in results.keys():
                            # NOTE(review): merges as cached-count + 1 rather
                            # than summing both counts — confirm intended.
                            results[disease_id]["count"] = temp_results[disease_id]["count"] + 1
                            results[disease_id]["path"].append(path)
                        else:
                            results[disease_id] = temp_results[disease_id]
                continue
            # BFS queue entries: (node id, depth, path label, traversal config).
            queue = [(node, 0, node_id_names[node], {'allowed_types': allowed_types, 'allowed_links': allowed_links})]
            # Normalise the input: graph ages are stored in months, so
            # convert a year-based age once.
            if input.pat_age.value > 0 and input.pat_age.type == 'year':
                input.pat_age.value = input.pat_age.value * 12
                input.pat_age.type = 'month'
            # STEP 1: start_nodes are assumed to be symptoms; find the
            # diseases reachable from them.
            # TODO: since this searches disease per symptom, the results can
            # be cached in production use.
            while queue:
                node, depth, path, data = queue.pop(0)
                node = int(node)
                # Resolve name/type by id.
                entity_data = self.entity_data[self.entity_data.index == node]
                # Skip nodes unknown to the entity table.
                if entity_data.empty:
                    continue
                if self.graph.nodes.get(node) is None:
                    continue
                node_type = self.entity_data[self.entity_data.index == node]['type'].tolist()[0]
                node_name = self.entity_data[self.entity_data.index == node]['name'].tolist()[0]
                if node_type in DIESEASE:
                    # Diseases are terminal: record the hit, do not expand.
                    if self.check_diease_allowed(node) == False:
                        continue
                    if node in temp_results.keys():
                        temp_results[node]["count"] = temp_results[node]["count"] + 1
                        temp_results[node]["path"].append(path)
                    else:
                        temp_results[node] = {"type": node_type, "count": 1, "name": node_name, 'path': [path]}
                    continue
                if node in visited or depth > max_hops:
                    continue
                visited.add(node)
                if node not in self.graph:
                    continue
                # TODO: currently only in-edges are followed — would
                # out-edges be useful here as well?
                for edge in self.graph.in_edges(node, data=True):
                    src, dest, edge_data = edge
                    if src not in visited and depth + 1 < max_hops:
                        queue.append((src, depth + 1, path, data))
            self.cache.set(cache_key, temp_results)
            # Merge this symptom's results into the global result set.
            if results == {}:
                results = temp_results
            else:
                for disease_id in temp_results:
                    path = temp_results[disease_id]["path"][0]
                    if disease_id in results.keys():
                        # NOTE(review): same count-merge concern as above.
                        results[disease_id]["count"] = temp_results[disease_id]["count"] + 1
                        results[disease_id]["path"].append(path)
                    else:
                        results[disease_id] = temp_results[disease_id]
        end_time = time.time()
        # Filter out diseases that do not match the patient's sex/age.
        new_results = {}
        print(len(results))
        for item in results:
            # NOTE(review): `node` here is the leftover loop variable from
            # the traversal above, not `item` — every disease is checked
            # against the same node. Looks like it should be
            # check_*_allowed(item, ...); confirm before changing.
            if self.check_sex_allowed(node, input.pat_sex.value) == False:
                continue
            if self.check_age_allowed(node, input.pat_age.value) == False:
                continue
            new_results[item] = results[item]
        results = new_results
        print(len(results))
        print(f"STEP 1 执行完成,耗时:{end_time - start_time:.2f}秒")
        print(f"STEP 1 遍历图谱查找相关疾病 finished")
        # Emit a markdown table for logging/debugging.
        log_data = ["|疾病|症状|出现次数|是否相关"]
        log_data.append("|--|--|--|--|")
        for item in results:
            data = results[item]
            data['relevant'] = False
            if data["count"] / len(start_nodes) > 0.5:
                # A disease is "relevant" when it matched more than 50% of
                # the input symptoms.
                data['relevant'] = True
            disease_name = data["name"]
            key = 'disease_name_parent_' +disease_name
            cached_value = self.cache.get(key)
            if cached_value is None:
                # Cache the disease's parent-class name, found via the
                # '疾病相关父类' out-edge.
                out_edges = self.graph.out_edges(item, data=True)
                for edge in out_edges:
                    src, dest, edge_data = edge
                    if edge_data["type"] != '疾病相关父类':
                        continue
                    dest_data = self.entity_data[self.entity_data.index == dest]
                    if dest_data.empty:
                        continue
                    dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
                    self.cache.set(key, dest_name)
                    break
            # Only relevant diseases are logged.
            if data['relevant'] == False:
                continue
            log_data.append(f"|{data['name']}|{','.join(data['path'])}|{data['count']}|{data['relevant']}|")
        content = "疾病和症状相关性统计表格\n" + "\n".join(log_data)
        # print(f"\n{content}")
        # STEP 2: find the departments (checks/drugs are currently disabled)
        # for each relevant disease. Could be cached per disease.
        start_time = time.time()
        print("STEP 2 查找疾病对应的科室、检查和药品 start")
        for disease in results.keys():
            # TODO: caching the department/check/drug lookup per disease
            # would improve performance considerably.
            if results[disease]["relevant"] == False:
                continue
            print(f"search data for {disease}:{results[disease]['name']}")
            queue = []
            queue.append((disease, 0, disease, {'allowed_types': DEPARTMENT, 'allowed_links': ['belongs_to']}))
            # `visited` stays inside the per-disease loop: hoisting it out
            # caused issues and showed no measurable speedup.
            visited = set()
            while queue:
                node, depth, disease, data = queue.pop(0)
                if node in visited or depth > max_hops:
                    continue
                visited.add(node)
                entity_data = self.entity_data[self.entity_data.index == node]
                # Skip nodes unknown to the entity table.
                if entity_data.empty:
                    continue
                node_type = self.entity_data[self.entity_data.index == node]['type'].tolist()[0]
                node_name = self.entity_data[self.entity_data.index == node]['name'].tolist()[0]
                if node_type in DEPARTMENT:
                    # Repeat the department once per disease occurrence so
                    # the later frequency statistics are a plain count.
                    department_data = [node_name] * results[disease]["count"]
                    if 'department' in results[disease].keys():
                        results[disease]["department"] = results[disease]["department"] + department_data
                    else:
                        results[disease]["department"] = department_data
                    continue
                # (CHECK and DRUG collection existed here but is commented
                # out in the current version.)
                out_edges = self.graph.out_edges(node, data=True)
                for edge in out_edges:
                    src, dest, edge_data = edge
                    src_data = self.entity_data[self.entity_data.index == src]
                    if src_data.empty:
                        continue
                    dest_data = self.entity_data[self.entity_data.index == dest]
                    if dest_data.empty:
                        continue
                    src_name = self.entity_data[self.entity_data.index == src]['name'].tolist()[0]
                    dest_name = self.entity_data[self.entity_data.index == dest]['name'].tolist()[0]
                    dest_type = self.entity_data[self.entity_data.index == dest]['type'].tolist()[0]
                    if dest_type in allowed_types:
                        if dest not in visited and depth + 1 < max_hops:
                            queue.append((edge[1], depth + 1, disease, data))
        # TODO: each disease's department/check/drug lists could be cached
        # here for later reuse.
        print(f"STEP 2 finished")
        end_time = time.time()
        print(f"STEP 2 执行完成,耗时:{end_time - start_time:.2f}秒")
        # Log the per-disease aggregation as a markdown table.
        log_data = ["|disease|count|department|check|drug|"]
        log_data.append("|--|--|--|--|--|")
        for item in results.keys():
            department_data = results[item].get("department", [])
            count_data = results[item].get("count")
            check_data = results[item].get("check", [])
            drug_data = results[item].get("drug", [])
            # NOTE(review): nested double quotes inside an f-string require
            # Python 3.12+.
            log_data.append(
                f"|{results[item].get("name", item)}|{count_data}|{','.join(department_data)}|{','.join(check_data)}|{','.join(drug_data)}|")
        # print("疾病科室检查药品相关统计\n" + "\n".join(log_data))
        # STEP 3: aggregate the results per department.
        start_time = time.time()
        print(f"STEP 3 对于结果按照科室维度进行汇总 start")
        final_results = {}
        total = 0
        for disease in results.keys():
            disease = int(disease)
            # Some diseases carry no department; bucket those under DEFAULT.
            departments = ['DEFAULT']
            if 'department' in results[disease].keys():
                departments = results[disease]["department"]
            for department in departments:
                total += 1
                if not department in final_results.keys():
                    final_results[department] = {
                        "diseases": [results[disease].get("name", disease)],
                        "checks": results[disease].get("check", []),
                        "drugs": results[disease].get("drug", []),
                        "count": 1
                    }
                else:
                    final_results[department]["diseases"] = final_results[department]["diseases"] + [
                        results[disease].get("name", disease)]
                    final_results[department]["checks"] = final_results[department]["checks"] + results[disease].get(
                        "check", [])
                    final_results[department]["drugs"] = final_results[department]["drugs"] + results[disease].get(
                        "drug", [])
                    final_results[department]["count"] += 1
        # Department distribution: share of all department occurrences.
        for department in final_results.keys():
            final_results[department]["score"] = final_results[department]["count"] / total
        print(f"STEP 3 finished")
        end_time = time.time()
        print(f"STEP 3 执行完成,耗时:{end_time - start_time:.2f}秒")
        # Log the per-department aggregation.
        log_data = ["|department|disease|check|drug|count|score"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases", [])
            check_data = final_results[department].get("checks", [])
            drug_data = final_results[department].get("drugs", [])
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(
                f"|{department}|{','.join(diesease_data)}|{','.join(check_data)}|{','.join(drug_data)}|{count_data}|{score_data}|")
        # print("\n" + "\n".join(log_data))
        # STEP 4: count disease/check/drug occurrences per department, keep
        # the most frequent entries, sorted descending.
        print(f"STEP 4 start")
        start_time = time.time()

        def sort_data(data, count=5):
            # Frequency-count the items and return the top `count` entries
            # as (item, {"count": n}) tuples sorted by count descending.
            tmp = {}
            for item in data:
                if item in tmp.keys():
                    tmp[item]["count"] += 1
                else:
                    tmp[item] = {"count": 1}
            sorted_data = sorted(tmp.items(), key=lambda x: x[1]["count"], reverse=True)
            return sorted_data[:count]

        for department in final_results.keys():
            final_results[department]['name'] = department
            final_results[department]["diseases"] = sort_data(final_results[department]["diseases"])
            final_results[department]["checks"] = sort_data(final_results[department]["checks"])
            final_results[department]["drugs"] = sort_data(final_results[department]["drugs"])
        # Order departments by how often they appeared.
        sorted_final_results = sorted(final_results.items(), key=lambda x: x[1]["count"], reverse=True)
        print(f"STEP 4 finished")
        end_time = time.time()
        print(f"STEP 4 执行完成,耗时:{end_time - start_time:.2f}秒")
        # Markdown log of the sorted per-department data.
        log_data = ["|department|disease|check|drug|count|score"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases")
            check_data = final_results[department].get("checks")
            drug_data = final_results[department].get("drugs")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
        # print("\n" + "\n".join(log_data))
        # STEP 5: global disease scoring across all departments.
        print(f"STEP 5 start")
        start_time = time.time()
        checks = {}
        drugs = {}
        diags = {}
        total_check = 0
        total_drug = 0
        total_diags = 0
        for department in final_results.keys():
            # Disease score = (count within department) * (department
            # probability); the DEFAULT bucket gets a fixed 0.1 factor.
            if department == 'DEFAULT':
                department_factor = 0.1
            else:
                department_factor = final_results[department]["score"]
            # Boost the patient's requested department by 10%.
            # TODO: department synonyms should be boosted as well.
            if input.department.value == department:
                department_factor = department_factor*1.1
            for disease, data in final_results[department]["diseases"]:
                total_diags += 1
                # Roll a disease up to its cached parent-class name when one
                # was recorded in STEP 1.
                key = 'disease_name_parent_' + disease
                cached_value = self.cache.get(key)
                if cached_value is not None:
                    disease = cached_value
                if disease in diags.keys():
                    diags[disease]["count"] += data["count"]
                    diags[disease]["score"] += data["count"] * department_factor
                else:
                    diags[disease] = {"count": data["count"], "score": data["count"] * department_factor}
            # (check/drug accumulation existed here but is commented out.)
        sorted_score_diags = sorted(diags.items(), key=lambda x: x[1]["score"], reverse=True)
        print(f"STEP 5 finished")
        end_time = time.time()
        print(f"STEP 5 执行完成,耗时:{end_time - start_time:.2f}秒")
        # Markdown log of the scored data.
        log_data = ["|department|disease|check|drug|count|score"]
        log_data.append("|--|--|--|--|--|--|")
        for department in final_results.keys():
            diesease_data = final_results[department].get("diseases")
            count_data = final_results[department].get("count", 0)
            score_data = final_results[department].get("score", 0)
            # NOTE(review): check_data / drug_data are stale values left over
            # from the STEP 4 logging loop — their re-assignment here is
            # commented out. Confirm whether this is intentional.
            log_data.append(f"|{department}|{diesease_data}|{check_data}|{drug_data}|{count_data}|{score_data}|")
        # print("这里是经过排序的数据\n" + "\n".join(log_data))
        # STEP 6: assemble and return the response.
        return {"details": sorted_final_results,
                "score_diags": sorted_score_diags,"total_diags": total_diags,
                # "checks":sorted_checks, "drugs":sorted_drugs,
                # "total_checks":total_check, "total_drugs":total_drug
                }