import os
import sys
import json
import logging

sys.path.append(os.getcwd())

from agent.db.database import SessionLocal
from agent.libs.graph import GraphBusiness

# Module-level logger so the helpers below can log before __main__ attaches
# the file handler.
logger = logging.getLogger(__name__)

graphBiz = GraphBusiness(db=SessionLocal())

hi_index = 1
low_index = 1


def get_hi_lo_id():
    """Return the next entity ID from a hi/lo counter pair."""
    global hi_index, low_index
    if low_index < 10000:
        low_index += 1
    else:
        hi_index += 1
        low_index = 1
    return hi_index * 10000 + low_index
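# A sketch of the sequence this yields: the first call returns 10002 and each
# later call adds one (the high word rolls over once the low word reaches
# 10000, i.e. ... 20000, 20001, 20002 ...). IDs are unique per run only.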

def load_json_from_file(filename: str):
    """Extract and parse the first ```json fenced block in a file.

    Returns the parsed object, or None when no complete block is found or
    the JSON inside it is malformed.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
        buffer = []
        json_started = False
        for line in content.split("\n"):
            if line.strip() == "":
                continue
            if line.startswith("```json"):
                buffer = []
                json_started = True
                continue
            if line.startswith("```"):
                if json_started:
                    return json.loads("\n".join(buffer))
                continue
            if json_started:
                buffer.append(line)
        return None
    except json.JSONDecodeError as e:
        logger.error(f"Malformed JSON in {filename}: {e}")
        return None
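# The extractor files this parser reads are expected to look roughly like the
# hypothetical sample below (field names taken from what the main loop reads):
#
#   ```json
#   {
#       "entities": [{"text": "Alice", "type": "Person", "position": 0},
#                    {"text": "Acme", "type": "Org", "position": 12}],
#       "relations": [{"source": 0, "target": 1, "type": "works_at"}]
#   }
#   ```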

# Note: currently unused in this script; pads two-element entity entries
# (e.g. [text, type]) with an empty third field.
def parse_json(data):
    if 'entities' in data:
        for entity in data['entities']:
            if len(entity) == 2:
                entity.append("")

def import_entities(graph_id, entities_list, relations_list):
    """Create a graph node per entity, then an edge per relation."""
    for ent in entities_list.values():
        name = ent['name']
        ent_types = ent['type']
        full_name = name
        if len(name) > 64:
            # Node names are capped at 64 chars; the full name is preserved
            # in the node props.
            name = name[:64]
        logger.info(f"create node: {ent}")
        node = graphBiz.create_node(graph_id=graph_id,
                                    name=name,
                                    category=ent_types[0],
                                    props={'types': ",".join(ent_types),
                                           'full_name': full_name})
        if node:
            ent["db_id"] = node.id
    for relation in relations_list.values():
        source_name = relation['source_name']
        target_name = relation['target_name']
        relation_type = relation['type']
        source_db_id = entities_list[source_name].get('db_id')
        target_db_id = entities_list[target_name].get('db_id')
        if source_db_id is None or target_db_id is None:
            # Node creation failed for one endpoint; skip the edge.
            logger.warning(f"skip relation {relation_type}: "
                           f"no node for {source_name} or {target_name}")
            continue
        graphBiz.create_edge(graph_id=graph_id,
                             src_id=source_db_id,
                             dest_id=target_db_id,
                             name=relation_type,
                             category=relation_type,
                             props={
                                 "src_type": relation['source_type'],
                                 "dest_type": relation['target_type'],
                             })
        logger.info(f"create edge: {source_db_id}->{target_db_id}")
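# A minimal usage sketch (hypothetical ids and names; assumes the dicts have
# the shape built by the __main__ block below):
#
#   ents = {"Alice": {"id": 10002, "name": "Alice", "type": ["Person"]}}
#   import_entities(graph_id=42, entities_list=ents, relations_list={})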

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python standard_kb_cbuild.py <path_of_job>")
        sys.exit(-1)
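    # Example invocation (hypothetical path; the last path component must be
    # the numeric job id, e.g. 42):
    #
    #   python standard_kb_cbuild.py /data/jobs/42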
    job_path = sys.argv[1]
    if not os.path.exists(job_path):
        print(f"job path not exists: {job_path}")
        sys.exit(-1)
    kb_path = os.path.join(job_path, "kb_extract")
    if not os.path.exists(kb_path):
        print(f"kb path not exists: {kb_path}")
        sys.exit(-1)
    kb_build_path = os.path.join(job_path, "kb_build")
    # The job directory is named after the numeric job id, which also serves
    # as the graph id when importing.
    job_id = int(os.path.basename(os.path.normpath(job_path)))
    os.makedirs(kb_build_path, exist_ok=True)
    log_path = os.path.join(job_path, "logs")
    print(f"log path: {log_path}")
    os.makedirs(log_path, exist_ok=True)
    handler = logging.FileHandler(os.path.join(log_path, "graph_build.log"),
                                  mode='a', encoding="utf-8")
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    root_logger = logging.getLogger()
    root_logger.addHandler(handler)
    # The root logger defaults to WARNING; lower it so the INFO lines below
    # actually reach the file.
    root_logger.setLevel(logging.INFO)
    entities_list = {}
    relations_list = {}
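    # entities_list maps entity text -> merged record (types accumulate across
    # files); relations_list maps a "src/stype-rel->dst/dtype" key -> relation
    # record, so duplicates across files collapse to a single node or edge.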
    for root, dirs, files in os.walk(kb_path):
        for file in files:
            if not file.endswith(".txt"):
                continue
            logger.info(f"Processing {file}")
            data = load_json_from_file(filename=os.path.join(root, file))
            if data is None:
                continue
            if 'entities' in data:
                for entity in data['entities']:
                    text = entity['text']
                    ent_type = entity['type']
                    if text in entities_list:
                        # Seen before: merge by accumulating distinct types.
                        ent = entities_list[text]
                        if ent_type in ent['type']:
                            continue
                        ent['type'].append(ent_type)
                    else:
                        ent = {"id": get_hi_lo_id(), "name": text, "type": [ent_type]}
                        entities_list[text] = ent
            else:
                logger.info(f"entities not found in {file}")
- if "relations" in data:
- relations = data['relations']
- for relation in relations:
- source_idx = relation['source']
- target_idx = relation['target']
- type = relation['type']
- if source_idx >= len(data['entities']) or target_idx >= len(data['entities']):
- logger.info(f"source/target of relation {relation} not found")
- continue
- source_ent = data['entities'][source_idx]
- target_ent = data['entities'][target_idx]
- source_text = source_ent['text']
- source_type = source_ent['type']
- target_text = target_ent['text']
- target_type = target_ent['type']
-
- if source_text in entities_list:
- source_ent = entities_list[source_text]
- else:
- source_ent = None
- if target_text in entities_list:
- target_ent = entities_list[target_text]
- else:
- target_ent = None
-
- if source_ent and target_ent:
- source_id = source_ent['id']
- target_id = target_ent['id']
- relation_key = f"{source_id}/{source_type}-{type}->{target_id}/{target_type}"
- if relation_key in relations_list:
- continue
- relations_list[relation_key] = {"source_id":source_id,
- "source_name":source_text,
- "source_type":source_type,
- "target_id":target_id,
- "target_name":target_text,
- "target_type":target_type,
- "type":type}
- else:
- logger.info(f"relation {relation_key} not found")
- else:
- logger.info(f"relations not found in {file}")
- print(f"Done {file}")
    with open(os.path.join(kb_build_path, "entities.json"), "w", encoding="utf-8") as f:
        json.dump(list(entities_list.values()), f, ensure_ascii=False, indent=4)
    with open(os.path.join(kb_build_path, "relations.json"), "w", encoding="utf-8") as f:
        json.dump(list(relations_list.values()), f, ensure_ascii=False, indent=4)
    import_entities(job_id, entities_list, relations_list)
    print("Done")