extract_disease_doc.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import sys,os
  2. current_path = os.getcwd()
  3. sys.path.append(current_path)
  4. import json
  5. from libs.embed_helper import EmbedHelper
  6. def embed_test():
  7. embed_helper = EmbedHelper()
  8. result = embed_helper.embed_text("你好")
  9. print(f"result length: {len(result)}")
  10. print(result)
  11. def search_test():
  12. from utils.es import ElasticsearchOperations
  13. es = ElasticsearchOperations()
  14. result = es.search("graph_entity_index", "上呼吸道感染", 10)
  15. for item in result:
  16. print(item)
  17. def load_entities():
  18. print("load entity data")
  19. with open(f"{current_path}\\web\\cached_data\\entities_med.json", "r", encoding="utf-8") as f:
  20. entities = json.load(f)
  21. return entities
  22. def load_relationships():
  23. print("load relationship data")
  24. with open(f"{current_path}\\web\\cached_data\\relationship_med.json", "r", encoding="utf-8") as f:
  25. relationships = json.load(f)
  26. return relationships
  27. def write_data_file(file_name, data):
  28. if len(data) == 0:
  29. return
  30. print("write data file", file_name)
  31. with open(file_name, "w", encoding="utf-8") as f:
  32. f.write(json.dumps(data, ensure_ascii=False,indent=4))
  33. def build_index():
  34. print("build index")
  35. embed_helper = EmbedHelper()
  36. entities = load_entities()
  37. count = 0
  38. index = 0
  39. for item in entities:
  40. node_id = item[0]
  41. print("process node: ",count, node_id)
  42. texts = []
  43. attrs = item[1]
  44. attr_embed_list = []
  45. if attrs["type"] == "Disease":
  46. for attr in attrs:
  47. if len(attrs[attr])>3 and attr not in ["type", "description"]:
  48. texts.append(attrs[attr])
  49. attr_embed = embed_helper.embed_text(node_id+"-"+attr+"-"+attrs[attr])
  50. attr_embed_list.append(
  51. {
  52. "title": node_id+"-"+attr,
  53. "text": attrs[attr],
  54. "embedding": attr_embed}
  55. )
  56. else:
  57. print("skip", attr)
  58. doc = { "title": node_id,
  59. "text": "\n".join(texts),
  60. "embedding": attr_embed_list} # 初始化doc对象,确保它在循环外部定义
  61. count += 1
  62. if count % 1 == 0:
  63. write_data_file(f"{current_path}\\web\\cached_data\\diseases\\{index}.json", doc)
  64. index = index + 1
  65. #write_data_file(f"{current_path}\\web\\cached_data\\diseases\\{index}.json", records)
  66. # 使用示例
  67. if __name__ == "__main__":
  68. param_count = len(sys.argv)
  69. if param_count == 2:
  70. action = sys.argv[1]
  71. if action== "test":
  72. embed_test()
  73. search_test()
  74. if action == "build":
  75. build_index()
  76. #build_index()