dump_graph_data.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. '''
  2. 这个脚本是用来从postgre数据库中导出图谱数据到json文件的。
  3. '''
  4. import sys,os
  5. current_path = os.getcwd()
  6. sys.path.append(current_path)
  7. current_path ="D:\\work\\03\\cached_data\\new"
  8. from sqlalchemy import text
  9. from sqlalchemy.orm import Session
  10. import json
  11. #这个是数据库的连接
  12. from web.db.database import SessionLocal
  13. #两个会话,分别是读取节点和属性的
  14. db = SessionLocal()
  15. prop = SessionLocal()
  16. #图谱id
  17. GRAPH_ID = 2
  18. def get_props(ref_id):
  19. props = {}
  20. sql = """select prop_name, prop_value,prop_title from kg_props where ref_id=:ref_id"""
  21. result = prop.execute(text(sql), {'ref_id':ref_id})
  22. for record in result:
  23. prop_name, prop_value,prop_title = record
  24. props[prop_name] = prop_title + ":" +prop_value
  25. return props
  26. def get_entities():
  27. COUNT_SQL = f"select count(*) from kg_nodes where graph_id={GRAPH_ID}"
  28. result = db.execute(text(COUNT_SQL))
  29. count = result.scalar()
  30. print("total nodes: ", count)
  31. entities = []
  32. batch = 100
  33. start = 1
  34. while start < count:
  35. sql = f"""select id,name,category from kg_nodes where graph_id={GRAPH_ID} order by id limit :batch OFFSET :start"""
  36. result = db.execute(text(sql), {'start':start, 'batch':batch})
  37. #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
  38. row_count = 0
  39. for row in result:
  40. id,name,category = row
  41. props = get_props(id)
  42. entities.append([id,{"name":name, 'type':category,'description':'', **props}])
  43. row_count += 1
  44. if row_count == 0:
  45. break
  46. start = start + row_count
  47. print("start: ", start, "row_count: ", row_count)
  48. with open(current_path+"\\entities_med.json", "w", encoding="utf-8") as f:
  49. f.write(json.dumps(entities, ensure_ascii=False,indent=4))
  50. def get_names(src_id, dest_id):
  51. sql = """select id,name,category from kg_nodes where id = :src_id"""
  52. result = db.execute(text(sql), {'src_id':src_id}).first()
  53. id,src_name,src_category = result
  54. result = db.execute(text(sql), {'src_id':dest_id}).first()
  55. id,dest_name,dest_category = result
  56. return (src_id, src_name, src_category, dest_id, dest_name, dest_category)
  57. def get_relationships():
  58. COUNT_SQL = f"select count(*) from kg_edges where graph_id={GRAPH_ID}"
  59. result = db.execute(text(COUNT_SQL))
  60. count = result.scalar()
  61. print("total edges: ", count)
  62. edges = []
  63. batch = 1000
  64. start = 1
  65. file_index = 1
  66. while start < count:
  67. sql = f"""select id,name,category,src_id,dest_id from kg_edges where graph_id={GRAPH_ID} order by id limit :batch OFFSET :start"""
  68. result = db.execute(text(sql), {'start':start, 'batch':batch})
  69. #["发热",{"type":"症状","description":"发热是诊断的主要目的,用于明确发热病因。"}]
  70. row_count = 0
  71. for row in result:
  72. id,name,category,src_id,dest_id = row
  73. props = get_props(id)
  74. src_id, src_name, src_category, dest_id, dest_name, dest_category = get_names(src_id, dest_id)
  75. edges.append([src_id, {"id":src_id, "name":src_name, "type":src_category}, dest_id,{"id":dest_id,"name":dest_name,"type":dest_category}, {'type':category,'name':name, **props}])
  76. row_count += 1
  77. if row_count == 0:
  78. break
  79. start = start + row_count
  80. print("start: ", start, "row_count: ", row_count)
  81. if len(edges) > 10000:
  82. with open(current_path+f"\\relationship_med_{file_index}.json", "w", encoding="utf-8") as f:
  83. f.write(json.dumps(edges, ensure_ascii=False,indent=4))
  84. edges = []
  85. file_index += 1
  86. with open(current_path+"\\relationship_med_0.json", "w", encoding="utf-8") as f:
  87. f.write(json.dumps(edges, ensure_ascii=False,indent=4))
  88. #导出节点数据
  89. get_entities()
  90. #导出关系数据
  91. get_relationships()