# dump_graph_data.py
  1. '''
  2. 这个脚本是用来从postgre数据库中导出图谱数据到json文件的。
  3. '''
  4. import sys,os
  5. current_path = os.getcwd()
  6. sys.path.append(current_path)
  7. from sqlalchemy import text
  8. from sqlalchemy.orm import Session
  9. import json
  10. #这个是数据库的连接
  11. from db.session import SessionLocal
  12. #两个会话,分别是读取节点和属性的
  13. db = SessionLocal()
  14. prop = SessionLocal()
  15. def get_props(ref_id):
  16. props = {}
  17. sql = """select prop_name, prop_value,prop_title from kg_props where ref_id=:ref_id"""
  18. result = prop.execute(text(sql), {'ref_id':ref_id})
  19. for record in result:
  20. prop_name, prop_value,prop_title = record
  21. #如果prop_title或者prop_value为空则为空字符串
  22. if prop_title is None:
  23. prop_title = ""
  24. if prop_value is None:
  25. prop_value = ""
  26. props[prop_name] = prop_title + ":" +prop_value
  27. return props
  28. def get_entities():
  29. #COUNT_SQL = "select count(*) from kg_nodes where version=:version"
  30. COUNT_SQL = "select count(*) from kg_nodes where status=0"
  31. result = db.execute(text(COUNT_SQL))
  32. count = result.scalar()
  33. print("total nodes: ", count)
  34. entities = []
  35. batch = 100
  36. start = 1
  37. while start < count:
  38. #sql = """select id,name,category from kg_nodes where version=:version order by id limit :batch OFFSET :start"""
  39. sql = """select id,name,category from kg_nodes where status=0 order by id limit :batch OFFSET :start"""
  40. result = db.execute(text(sql), {'start':start, 'batch':batch})
  41. #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
  42. row_count = 0
  43. for row in result:
  44. id,name,category = row
  45. props = get_props(id)
  46. entities.append([id,{"name":name, 'type':category,'description':'', **props}])
  47. row_count += 1
  48. if row_count == 0:
  49. break
  50. start = start + row_count
  51. print("start: ", start, "row_count: ", row_count)
  52. with open(current_path+"\\entities_med.json", "w", encoding="utf-8") as f:
  53. f.write(json.dumps(entities, ensure_ascii=False,indent=4))
  54. def get_names(src_id, dest_id):
  55. sql = """select id,name,category from kg_nodes where id = :src_id"""
  56. result = db.execute(text(sql), {'src_id':src_id}).first()
  57. print(result)
  58. if result is None:
  59. #返回空
  60. return (src_id, "", "", dest_id, "", "")
  61. id,src_name,src_category = result
  62. result = db.execute(text(sql), {'src_id':dest_id}).first()
  63. id,dest_name,dest_category = result
  64. return (src_id, src_name, src_category, dest_id, dest_name, dest_category)
  65. def get_relationships():
  66. #COUNT_SQL = "select count(*) from kg_edges where version=:version"
  67. COUNT_SQL = "select count(*) from kg_edges where status=0"
  68. result = db.execute(text(COUNT_SQL))
  69. count = result.scalar()
  70. print("total edges: ", count)
  71. edges = []
  72. batch = 1000
  73. start = 1
  74. file_index = 1
  75. while start < count:
  76. #sql = """select id,name,category,src_id,dest_id from kg_edges where version=:version order by id limit :batch OFFSET :start"""
  77. sql = """select id,name,category,src_id,dest_id from kg_edges where status=0 order by id limit :batch OFFSET :start"""
  78. result = db.execute(text(sql), {'start':start, 'batch':batch})
  79. #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
  80. row_count = 0
  81. for row in result:
  82. id,name,category,src_id,dest_id = row
  83. props = get_props(id)
  84. #如果get_names异常,跳过
  85. try:
  86. src_id, src_name, src_category, dest_id, dest_name, dest_category = get_names(src_id, dest_id)
  87. except Exception as e:
  88. print(e)
  89. print("src_id: ", src_id, "dest_id: ", dest_id)
  90. continue
  91. #src_name或dest_name为空,说明节点不存在,跳过
  92. if src_name == "" or dest_name == "":
  93. continue
  94. edges.append([src_id, {"id":src_id, "name":src_name, "type":src_category}, dest_id,{"id":dest_id,"name":dest_name,"type":dest_category}, {'type':category,'name':name, **props}])
  95. row_count += 1
  96. if row_count == 0:
  97. break
  98. start = start + row_count
  99. print("start: ", start, "row_count: ", row_count)
  100. if len(edges) > 10000:
  101. with open(current_path+f"\\relationship_med_{file_index}.json", "w", encoding="utf-8") as f:
  102. f.write(json.dumps(edges, ensure_ascii=False,indent=4))
  103. edges = []
  104. file_index += 1
  105. with open(current_path+"\\relationship_med_0.json", "w", encoding="utf-8") as f:
  106. f.write(json.dumps(edges, ensure_ascii=False,indent=4))
  107. if __name__ == "__main__":
  108. #导出节点数据
  109. get_entities()
  110. #导出关系数据
  111. get_relationships()