123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131 |
- import sys,os
- current_path = os.getcwd()
- sys.path.append(current_path)
- import json
- from libs.embed_helper import EmbedHelper
def embed_test():
    """Smoke-test the embedding helper with a short sample string.

    Prints the length of the value returned by ``EmbedHelper.embed_text``
    and then the value itself.
    """
    vector = EmbedHelper().embed_text("你好")
    print(f"result length: {len(vector)}")
    print(vector)
def search_test():
    """Run a sample query against ``graph_entity_index`` and print each hit."""
    # Local import: the Elasticsearch wrapper is only loaded when this
    # action actually runs.
    from utils.es import ElasticsearchOperations

    client = ElasticsearchOperations()
    hits = client.search("graph_entity_index", "上呼吸道感染", 10)
    for hit in hits:
        print(hit)
def load_entities():
    """Load the cached entity data from ``web/cached_data/entities_med.json``.

    The file is resolved relative to the module-level ``current_path``. The
    path is assembled with ``os.path.join`` so that it resolves on POSIX
    systems as well as on Windows (hard-coded backslashes only work on
    Windows).

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If the cache file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    print("load entity data")
    entities_file = os.path.join(
        current_path, "web", "cached_data", "entities_med.json"
    )
    with open(entities_file, "r", encoding="utf-8") as f:
        return json.load(f)
-
def load_relationships():
    """Load the cached relationship data from
    ``web/cached_data/relationship_med.json``.

    The file is resolved relative to the module-level ``current_path``. The
    path is assembled with ``os.path.join`` so that it resolves on POSIX
    systems as well as on Windows (hard-coded backslashes only work on
    Windows).

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If the cache file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    print("load relationship data")
    relationships_file = os.path.join(
        current_path, "web", "cached_data", "relationship_med.json"
    )
    with open(relationships_file, "r", encoding="utf-8") as f:
        return json.load(f)
def write_data_file(file_name, data):
    """Write *data* to *file_name* as pretty-printed UTF-8 JSON.

    Nothing is written (and no file is created) when *data* is empty or
    ``None``. The parent directory of *file_name* is created if it does not
    exist yet, so a first run does not fail on a missing output folder.

    Args:
        file_name: Destination path of the JSON file.
        data: JSON-serialisable object to store.

    Raises:
        TypeError: If *data* contains values that are not JSON-serialisable.
        OSError: If the directory or file cannot be created.
    """
    if not data:
        return
    print("write data file", file_name)
    parent_dir = os.path.dirname(file_name)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(file_name, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text readable in the file;
        # json.dump streams the output instead of building one big string.
        json.dump(data, f, ensure_ascii=False, indent=4)
def import_index():
    """Rebuild ``graph_entity_index`` from the cached embedding chunk files.

    Deletes the existing index, then looks for
    ``web/cached_data/embed/word_index_<i>.json`` (relative to
    ``current_path``) for ``i`` from 0 to 998 and adds one document per
    record found. Chunk numbers without a file are skipped.

    Each record is read as a sequence whose first element is the node id and
    whose second element is its embedding. The node id is used both as the
    document title and as its text.

    Raises:
        json.JSONDecodeError: If a chunk file does not contain valid JSON.
    """
    from utils.es import ElasticsearchOperations

    es = ElasticsearchOperations()
    es.delete_index("graph_entity_index")
    # Built with os.path.join so the lookup works on POSIX as well as Windows.
    embed_dir = os.path.join(current_path, "web", "cached_data", "embed")
    for i in range(999):
        chunk_file = os.path.join(embed_dir, f"word_index_{i}.json")
        if not os.path.exists(chunk_file):
            continue
        print("load embed data", chunk_file)
        with open(chunk_file, "r", encoding="utf-8") as f:
            records = json.load(f)
        for item in records:
            node_id = item[0]
            embed = item[1]
            # A third element (attribute embeddings), when present, is not
            # indexed.
            doc = {"title": node_id, "text": node_id, "embedding": embed}
            es.add_document("graph_entity_index", es.get_doc_id(node_id), doc)
            print("index added of ", node_id, "embed length: ", len(embed))
def _extract_last_json_block(markdown_text):
    """Return the lines inside the last ```json fenced block of the text.

    A line starting with ```json opens a block and discards anything
    collected from an earlier block; any other line starting with ``` closes
    the current block. An empty list is returned when no block is found.
    """
    block_lines = []
    inside_block = False
    for line in markdown_text.splitlines():
        if line.startswith("```json"):
            block_lines = []
            inside_block = True
        elif line.startswith("```"):
            inside_block = False
        elif inside_block:
            block_lines.append(line)
    return block_lines


def import_community_report_index():
    """Rebuild ``graph_community_report_index`` from the cached reports.

    Deletes the existing index, then processes every ``.md`` file in
    ``web/cached_data/report`` (relative to ``current_path``). For each
    report, the JSON found in its last ```json fenced block is parsed, the
    ``name`` values of its items are joined with spaces to form the title,
    the title is embedded, and the full report text is indexed under that
    title.

    Reports that contain no JSON block are skipped with a message instead of
    aborting the whole import part-way through.

    Raises:
        json.JSONDecodeError: If a report's JSON block is malformed.
        KeyError: If an item of the JSON block has no ``name`` key.
    """
    from utils.es import ElasticsearchOperations

    embed_helper = EmbedHelper()
    es = ElasticsearchOperations()
    es.delete_index("graph_community_report_index")
    # Built with os.path.join so the lookup works on POSIX as well as Windows.
    report_dir = os.path.join(current_path, "web", "cached_data", "report")
    for filename in os.listdir(report_dir):
        if not filename.endswith(".md"):
            continue
        file_path = os.path.join(report_dir, filename)
        with open(file_path, "r", encoding="utf-8") as f:
            content = f.read()
        json_text = "\n".join(_extract_last_json_block(content))
        if not json_text.strip():
            # json.loads("") would raise; there is nothing to build a title
            # from, so leave this report out of the index.
            print("skip report without json block", file_path)
            continue
        jsondata = json.loads(json_text)
        title = " ".join(item["name"] for item in jsondata)
        doc = {
            "title": title,
            "text": content,
            "embedding": embed_helper.embed_text(title),
        }
        es.add_document(
            "graph_community_report_index", es.get_doc_id(doc["title"]), doc
        )
        print("index added of ", doc["title"], "embed length: ", len(doc["embedding"]))
def build_index():
    """Embed every cached entity and store the results in chunk files.

    Entities come from ``load_entities()``; each one is read as a sequence
    whose first element is the node id and whose second element is a mapping
    of attributes. The node id is embedded, as is every attribute value
    longer than three characters whose key is not ``type`` or
    ``description``. Skipped attributes are reported on stdout.

    Results are written as ``[node_id, embedding, attr_embeddings]`` records
    to ``web/cached_data/embed/word_index_<n>.json`` (relative to
    ``current_path``), 100 records per file. The output directory is created
    up front so a missing folder cannot fail the first flush after the
    embeddings have already been computed.
    """
    print("build index")
    embed_helper = EmbedHelper()
    entities = load_entities()
    # Built with os.path.join so the output works on POSIX as well as Windows.
    embed_dir = os.path.join(current_path, "web", "cached_data", "embed")
    os.makedirs(embed_dir, exist_ok=True)
    chunk_size = 100
    excluded_attrs = ("type", "description")
    records = []
    index = 0
    for count, item in enumerate(entities):
        node_id = item[0]
        print("process node: ", count, node_id)
        attrs = item[1]
        embed = embed_helper.embed_text(node_id)
        attr_embed_list = []
        for attr in attrs:
            value = attrs[attr]
            if len(value) > 3 and attr not in excluded_attrs:
                attr_embed_list.append([value, embed_helper.embed_text(value)])
            else:
                print("skip", attr)
        records.append([node_id, embed, attr_embed_list])
        # Flush a full chunk and start collecting the next one.
        if (count + 1) % chunk_size == 0:
            write_data_file(
                os.path.join(embed_dir, f"word_index_{index}.json"), records
            )
            index += 1
            records = []
    # Flush the remaining records of the last, partially filled chunk.
    write_data_file(os.path.join(embed_dir, f"word_index_{index}.json"), records)
# Usage example: run this script with exactly one argument naming the action
# (test, build, import, import_com or chunc).
if __name__ == "__main__":
    if len(sys.argv) == 2:
        # Each action maps to the steps it runs, in order. "chunc" is
        # accepted but currently does nothing; unknown actions are ignored.
        steps_by_action = {
            "test": (embed_test, search_test),
            "build": (build_index,),
            "import": (import_index,),
            "import_com": (import_community_report_index,),
            "chunc": (),
        }
        for step in steps_by_action.get(sys.argv[1], ()):
            step()
|