"""Build and import Elasticsearch indexes for the knowledge graph.

Usage: python <script> [test|build|import|import_com|chunk]
"""
import sys
import os
import json

current_path = os.getcwd()
sys.path.append(current_path)

from libs.embed_helper import EmbedHelper


def embed_test():
    """Sanity-check the embedding helper with a short string."""
    embed_helper = EmbedHelper()
    result = embed_helper.embed_text("你好")  # "hello"
    print(f"result length: {len(result)}")
    print(result)


def search_test():
    """Run a sample search against the graph entity index."""
    from utils.es import ElasticsearchOperations
    es = ElasticsearchOperations()
    # Query for "上呼吸道感染" (upper respiratory tract infection), top 10 hits.
    result = es.search("graph_entity_index", "上呼吸道感染", 10)
    for item in result:
        print(item)


def load_entities():
    print("load entity data")
    with open(f"{current_path}\\web\\cached_data\\entities_med.json", "r", encoding="utf-8") as f:
        entities = json.load(f)
    return entities


def load_relationships():
    print("load relationship data")
    with open(f"{current_path}\\web\\cached_data\\relationship_med.json", "r", encoding="utf-8") as f:
        relationships = json.load(f)
    return relationships


def write_data_file(file_name, data):
    if len(data) == 0:
        return
    print("write data file", file_name)
    with open(file_name, "w", encoding="utf-8") as f:
        f.write(json.dumps(data, ensure_ascii=False, indent=4))


def import_index():
    """Load the cached word_index_*.json embedding files and index each entity."""
    from utils.es import ElasticsearchOperations
    es = ElasticsearchOperations()
    es.delete_index("graph_entity_index")
    # word_index files are written with consecutive indices starting at 0.
    for i in range(999):
        file_path = f"{current_path}\\web\\cached_data\\embed\\word_index_{i}.json"
        if os.path.exists(file_path):
            print("load embed data", file_path)
            with open(file_path, "r", encoding="utf-8") as f:
                records = json.load(f)
            for item in records:
                node_id = item[0]
                embed = item[1]
                doc = {"title": node_id, "text": node_id, "embedding": embed}
                es.add_document("graph_entity_index", es.get_doc_id(node_id), doc)
                print("index added for", node_id, "embed length:", len(embed))
                # attr_embed_list = item[2]


def import_community_report_index():
    """Index community report markdown files: embed the entity names listed in
    each report's fenced json block and store the full report text."""
    from utils.es import ElasticsearchOperations
    embed_helper = EmbedHelper()
    es = ElasticsearchOperations()
    es.delete_index("graph_community_report_index")
    report_dir = f"{current_path}\\web\\cached_data\\report"
    for filename in os.listdir(report_dir):
        if filename.endswith(".md"):
            file_path = os.path.join(report_dir, filename)
            with open(file_path, "r", encoding="utf-8") as f:
                content = f.read()
            # Extract the last fenced ```json block from the report.
            jsonstr = []
            found_json = False
            for line in content.splitlines():
                if line.startswith("```json"):
                    jsonstr = []
                    found_json = True
                    continue
                if line.startswith("```"):
                    found_json = False
                    continue
                if found_json:
                    jsonstr.append(line)
            if not jsonstr:
                # Guard against reports without a json block, which would
                # otherwise crash json.loads below.
                print("no json block found in", filename, ", skipped")
                continue
            doc = {"title": "", "text": content, "embedding": []}
            jsondata = json.loads("\n".join(jsonstr))
            # The document title is the space-joined list of entity names.
            title_list = [item["name"] for item in jsondata]
            doc["title"] = " ".join(title_list)
            doc["embedding"] = embed_helper.embed_text(doc["title"])
            es.add_document("graph_community_report_index", es.get_doc_id(doc["title"]), doc)
            print("index added for", doc["title"], "embed length:", len(doc["embedding"]))


def build_index():
    """Embed every entity name (plus its longer attribute values) and cache
    the vectors to word_index_*.json files in batches of 100."""
    print("build index")
    embed_helper = EmbedHelper()
    entities = load_entities()
    count = 0
    records = []
    index = 0
    for item in entities:
        node_id = item[0]
        print("process node:", count, node_id)
        attrs = item[1]
        embed = embed_helper.embed_text(node_id)
        attr_embed_list = []
        for attr in attrs:
            # Skip very short values and the "type"/"description" attributes.
            if len(attrs[attr]) > 3 and attr not in ["type", "description"]:
                attr_embed = embed_helper.embed_text(attrs[attr])
                attr_embed_list.append([attrs[attr], attr_embed])
            else:
                print("skip", attr)
        records.append([node_id, embed, attr_embed_list])
        count += 1
        if count % 100 == 0:
            write_data_file(f"{current_path}\\web\\cached_data\\embed\\word_index_{index}.json", records)
            index += 1
            records = []
    # Flush any remaining records (write_data_file is a no-op when empty).
    write_data_file(f"{current_path}\\web\\cached_data\\embed\\word_index_{index}.json", records)


# Usage example
if __name__ == "__main__":
    if len(sys.argv) == 2:
        action = sys.argv[1]
        if action == "test":
            embed_test()
            search_test()
        elif action == "build":
            build_index()
        elif action == "import":
            import_index()
        elif action == "import_com":
            import_community_report_index()
        elif action == "chunk":  # placeholder, not implemented yet
            pass
            # build_index()