# build_graph_index.py — build, cache, and import embedding indexes for the knowledge graph.
# Bootstrap: make the project root importable when the script is run directly.
import sys,os
current_path = os.getcwd()  # repo root; all cached-data paths below are built from this
sys.path.append(current_path)
import json
from libs.embed_helper import EmbedHelper
  6. def embed_test():
  7. embed_helper = EmbedHelper()
  8. result = embed_helper.embed_text("你好")
  9. print(f"result length: {len(result)}")
  10. print(result)
  11. def search_test():
  12. from utils.es import ElasticsearchOperations
  13. es = ElasticsearchOperations()
  14. result = es.search("graph_entity_index", "上呼吸道感染", 10)
  15. for item in result:
  16. print(item)
  17. def load_entities():
  18. print("load entity data")
  19. with open(f"{current_path}\\web\\cached_data\\entities_med.json", "r", encoding="utf-8") as f:
  20. entities = json.load(f)
  21. return entities
  22. def load_relationships():
  23. print("load relationship data")
  24. with open(f"{current_path}\\web\\cached_data\\relationship_med.json", "r", encoding="utf-8") as f:
  25. relationships = json.load(f)
  26. return relationships
  27. def write_data_file(file_name, data):
  28. if len(data) == 0:
  29. return
  30. print("write data file", file_name)
  31. with open(file_name, "w", encoding="utf-8") as f:
  32. f.write(json.dumps(data, ensure_ascii=False,indent=4))
  33. def import_index():
  34. from utils.es import ElasticsearchOperations
  35. es = ElasticsearchOperations()
  36. es.delete_index("graph_entity_index")
  37. for i in range(999):
  38. if os.path.exists(f"{current_path}\\web\\cached_data\\embed\\word_index_{i}.json"):
  39. print("load embed data", f"{current_path}\\web\\cached_data\\embed\\word_index_{i}.json")
  40. with open(f"{current_path}\\web\\cached_data\\embed\\word_index_{i}.json", "r", encoding="utf-8") as f:
  41. records = json.load(f)
  42. for item in records:
  43. node_id = item[0]
  44. embed = item[1]
  45. doc = { "title": node_id,
  46. "text": node_id,
  47. "embedding": embed}
  48. es.add_document("graph_entity_index", es.get_doc_id(node_id), doc)
  49. print("index added of ", node_id, "embed length: ", len(embed))
  50. #attr_embed_list = item[2]
  51. def import_community_report_index():
  52. from utils.es import ElasticsearchOperations
  53. embed_helper = EmbedHelper()
  54. es = ElasticsearchOperations()
  55. es.delete_index("graph_community_report_index")
  56. for filename in os.listdir(f"{current_path}\\web\\cached_data\\report"):
  57. if filename.endswith(".md"):
  58. file_path = os.path.join(f"{current_path}\\web\\cached_data\\report", filename)
  59. with open(file_path, "r", encoding="utf-8") as f:
  60. content = f.read()
  61. jsonstr = []
  62. found_json = False
  63. for line in content.splitlines():
  64. if line.startswith("```json"):
  65. jsonstr = []
  66. found_json = True
  67. continue
  68. if line.startswith("```"):
  69. found_json = False
  70. continue
  71. if found_json:
  72. jsonstr.append(line)
  73. doc = { "title": "",
  74. "text": content,
  75. "embedding": []}
  76. jsondata = json.loads("\n".join(jsonstr))
  77. title_list = []
  78. for item in jsondata:
  79. title_list.append(item["name"])
  80. doc["title"] = " ".join(title_list)
  81. doc["embedding"] = embed_helper.embed_text(doc["title"])
  82. es.add_document("graph_community_report_index", es.get_doc_id(doc["title"]), doc)
  83. print("index added of ", doc["title"], "embed length: ", len(doc["embedding"]))
  84. def build_index():
  85. print("build index")
  86. embed_helper = EmbedHelper()
  87. entities = load_entities()
  88. count = 0
  89. records = []
  90. index = 0
  91. for item in entities:
  92. node_id = item[0]
  93. print("process node: ",count, node_id)
  94. attrs = item[1]
  95. embed = embed_helper.embed_text(node_id)
  96. attr_embed_list = []
  97. for attr in attrs:
  98. if len(attrs[attr])>3 and attr not in ["type", "description"]:
  99. attr_embed = embed_helper.embed_text(attrs[attr])
  100. attr_embed_list.append([attrs[attr],attr_embed])
  101. else:
  102. print("skip", attr)
  103. records.append([node_id, embed, attr_embed_list])
  104. count += 1
  105. if count % 100 == 0:
  106. write_data_file(f"{current_path}\\web\\cached_data\\embed\\word_index_{index}.json", records)
  107. index = index + 1
  108. records = []
  109. write_data_file(f"{current_path}\\web\\cached_data\\embed\\word_index_{index}.json", records)
  110. # 使用示例
  111. if __name__ == "__main__":
  112. param_count = len(sys.argv)
  113. if param_count == 2:
  114. action = sys.argv[1]
  115. if action== "test":
  116. embed_test()
  117. search_test()
  118. if action == "build":
  119. build_index()
  120. if action == "import":
  121. import_index()
  122. if action == "import_com":
  123. import_community_report_index()
  124. if action == "chunc":
  125. pass
  126. #build_index()