# dump_graph_data.py (4.1 KB)

  1. '''
  2. 这个脚本是用来从postgre数据库中导出图谱数据到json文件的。
  3. '''
  4. import sys,os
  5. current_path = os.getcwd()
  6. sys.path.append(current_path)
  7. from sqlalchemy import text
  8. from sqlalchemy.orm import Session
  9. import json
  10. #这个是数据库的连接
  11. from db.session import SessionLocal
  12. #两个会话,分别是读取节点和属性的
  13. db = SessionLocal()
  14. prop = SessionLocal()
  15. #图谱id
  16. version = 'ER'
  17. def get_props(ref_id):
  18. props = {}
  19. sql = """select prop_name, prop_value,prop_title from kg_props where ref_id=:ref_id"""
  20. result = prop.execute(text(sql), {'ref_id':ref_id})
  21. for record in result:
  22. prop_name, prop_value,prop_title = record
  23. props[prop_name] = prop_title + ":" +prop_value
  24. return props
  25. def get_entities():
  26. #COUNT_SQL = "select count(*) from kg_nodes where version=:version"
  27. COUNT_SQL = "select count(*) from kg_nodes"
  28. result = db.execute(text(COUNT_SQL), {'version': version})
  29. count = result.scalar()
  30. print("total nodes: ", count)
  31. entities = []
  32. batch = 100
  33. start = 1
  34. while start < count:
  35. #sql = """select id,name,category from kg_nodes where version=:version order by id limit :batch OFFSET :start"""
  36. sql = """select id,name,category from kg_nodes order by id limit :batch OFFSET :start"""
  37. result = db.execute(text(sql), {'start':start, 'batch':batch, 'version': version})
  38. #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
  39. row_count = 0
  40. for row in result:
  41. id,name,category = row
  42. props = get_props(id)
  43. #description = props.get('description', props.get('standard_description', ''))
  44. entities.append([name,{'type':category, **props}])
  45. row_count += 1
  46. if row_count == 0:
  47. break
  48. start = start + row_count
  49. print("start: ", start, "row_count: ", row_count)
  50. with open(current_path+"\\web\\cached_data\\entities_med.json", "w", encoding="utf-8") as f:
  51. f.write(json.dumps(entities, ensure_ascii=False,indent=4))
  52. def get_names(src_id, dest_id):
  53. sql = """select name from kg_nodes where id = :src_id"""
  54. result = db.execute(text(sql), {'src_id':src_id})
  55. src_name = result.scalar()
  56. result = db.execute(text(sql), {'src_id':dest_id})
  57. dest_name = result.scalar()
  58. return (src_name, dest_name)
  59. def get_relationships():
  60. #COUNT_SQL = "select count(*) from kg_edges where version=:version"
  61. COUNT_SQL = "select count(*) from kg_edges"
  62. result = db.execute(text(COUNT_SQL), {'version': version})
  63. count = result.scalar()
  64. print("total edges: ", count)
  65. edges = []
  66. batch = 1000
  67. start = 1
  68. file_index = 1
  69. while start < count:
  70. #sql = """select id,name,category,src_id,dest_id from kg_edges where version=:version order by id limit :batch OFFSET :start"""
  71. sql = """select id,name,category,src_id,dest_id from kg_edges order by id limit :batch OFFSET :start"""
  72. result = db.execute(text(sql), {'start':start, 'batch':batch, 'version': version})
  73. #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
  74. row_count = 0
  75. for row in result:
  76. id,name,category,src_id,dest_id = row
  77. props = get_props(id)
  78. src_name, dest_name = get_names(src_id, dest_id)
  79. edges.append([src_name, dest_name,{'type':category, **props}])
  80. row_count += 1
  81. if row_count == 0:
  82. break
  83. start = start + row_count
  84. print("start: ", start, "row_count: ", row_count)
  85. if len(edges) > 10000:
  86. with open(current_path+f"\\web\\cached_data\\relationship_med_{file_index}.json", "w", encoding="utf-8") as f:
  87. f.write(json.dumps(edges, ensure_ascii=False,indent=4))
  88. edges = []
  89. file_index += 1
  90. with open(current_path+"\\web\\cached_data\\relationship_med.json", "w", encoding="utf-8") as f:
  91. f.write(json.dumps(edges, ensure_ascii=False,indent=4))
  92. #导出节点数据
  93. get_entities()
  94. #导出关系数据
  95. get_relationships()