123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
import json
import logging
import os
import sys

# Make the current working directory importable so the project's `agent`
# package can be found (assumes the script is launched from the project
# root — confirm).
current_path = os.getcwd()
sys.path.append(current_path)

from agent.db.database import SessionLocal
from agent.libs.graph import GraphBusiness

# Module-level logger so the functions in this file can log when the module is
# imported, not only when it is executed as a script.
logger = logging.getLogger(__name__)

# Shared graph business object used to create nodes and edges. Its session is
# opened at import time and is not closed anywhere in this file.
graphBiz = GraphBusiness(db=SessionLocal())
# Counter state for get_hi_lo_id(): `hi_index` selects a block of 10000 ids and
# `low_index` is the position inside the current block.
hi_index = 1
low_index = 1


def get_hi_lo_id():
    """Advance the module-level hi/lo counters and return the resulting id.

    ``low_index`` is incremented on every call. Once it has already reached
    10000, ``hi_index`` is incremented instead and ``low_index`` restarts at 1.
    The id is ``hi_index * 10000 + low_index`` computed after that update, so
    successive calls yield 10002, 10003, ... with no gaps or repeats.
    """
    global hi_index, low_index
    if low_index >= 10000:
        hi_index, low_index = hi_index + 1, 1
    else:
        low_index += 1
    return hi_index * 10000 + low_index
def load_json_from_file(filename: str):
    """Parse the first ```json fenced block found in a UTF-8 text file.

    Blank lines are ignored. Lines between an opening fence line starting with
    ```json and the next line starting with ``` are joined with newlines and
    parsed with ``json.loads``. A second ```json line before the block is
    closed restarts collection. Fence markers may be indented.

    Args:
        filename: path of the text file to read.

    Returns:
        The parsed JSON value. Returns None when the file contains no complete
        ```json block, when the block is not valid JSON, or when the file is
        not valid UTF-8.

    Raises:
        OSError: if the file cannot be opened or read.
    """
    # Resolved locally so logging works even when the module is imported
    # rather than run as a script.
    logger = logging.getLogger(__name__)
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
    except UnicodeDecodeError as e:
        # Callers treat None as "skip this file"; don't abort the whole build.
        logger.warning(f"file is not valid UTF-8: {filename}: {e}")
        return None

    buffer = []
    in_json_block = False
    # Split on "\n" only: splitlines() would also break on U+2028/U+2029,
    # which may legally appear unescaped inside JSON strings.
    for line in content.split("\n"):
        stripped = line.strip()
        if not stripped:
            continue
        if stripped.startswith("```json"):
            buffer = []
            in_json_block = True
        elif stripped.startswith("```"):
            if in_json_block:
                try:
                    return json.loads("\n".join(buffer))
                except json.JSONDecodeError as e:
                    logger.warning(f"JSON格式错误: {e}")
                    return None
            # A fence outside a ```json block (e.g. another language): ignore.
        elif in_json_block:
            buffer.append(line)
    return None
def parse_json(data):
    """Pad two-item entities in ``data['entities']`` with a trailing "" in place.

    Entities whose length is not 2 are left unchanged. Nothing happens when
    ``data`` has no ``'entities'`` member. Always returns None.
    """
    if 'entities' not in data:
        return
    for item in data['entities']:
        if len(item) == 2:
            item.append("")
-
def _build_first_char_tree(nodes, root_name):
    """Group entity dicts under ``root_name`` by the first character of their name.

    Args:
        nodes: iterable of entity dicts, each with a ``'name'`` string.
        root_name: label of the tree root.

    Returns:
        ``{"name": root_name, "sNode": [group, ...]}``. Each group is
        ``{"name": <upper-cased first character>, "sNode": [leaf, ...]}`` and
        each leaf is ``{"name": <entity name>, "sNode": []}``. Groups are
        ordered by key and leaves by name. Entities with an empty name are
        skipped, because they have no first character.
    """
    named = [node for node in nodes if node['name']]
    # Sort by the grouping key first so every name sharing an upper-cased first
    # character ends up in one group. Sorting by raw name alone interleaves
    # "B...", "a...", "b..." and yields duplicate "B" groups.
    named.sort(key=lambda node: (node['name'][0].upper(), node['name']))
    tree = {"name": root_name, "sNode": []}
    group = None
    for node in named:
        key = node['name'][0].upper()
        if group is None or group["name"] != key:
            group = {"name": key, "sNode": []}
            tree["sNode"].append(group)
        group["sNode"].append({"name": node['name'], "sNode": []})
    return tree


def _save_tree_structure(user_id, graph_id, tree):
    """Store ``tree`` as JSON in a new TreeStructure row; on failure, roll back and log."""
    from agent.models.db.tree_structure import TreeStructure
    logger = logging.getLogger(__name__)
    db = SessionLocal()
    try:
        tree_record = TreeStructure(
            user_id=user_id,
            graph_id=graph_id,
            content=json.dumps(tree, ensure_ascii=False),
        )
        db.add(tree_record)
        db.commit()
    except Exception as e:
        db.rollback()
        logger.exception(f"Failed to save tree structure: {e}")
    finally:
        db.close()


def import_entities(graph_id, entities_list, relations_list):
    """Create graph nodes and edges for extracted knowledge and record ownership.

    Steps:
      1. Look up job ``graph_id`` and the first role of the job's creator.
      2. Record the sub-graph, every created node and every created edge as
         data of that user/role via UserDataRelationBusiness.
      3. Create one node per entity and one edge per relation. Node names are
         capped at 64 characters; the full name is kept in ``props``.
      4. Save one first-character tree each for 疾病 (disease) and 症状
         (symptom) entities as TreeStructure rows.

    Args:
        graph_id: job id, also used as the graph id for nodes and edges.
        entities_list: dict mapping entity text to
            ``{"id", "name", "type": [str, ...]}``. Each entity whose node is
            created gets a ``"db_id"`` key added in place.
        relations_list: dict of relation dicts with ``"source_name"``,
            ``"source_type"``, ``"target_name"``, ``"target_type"`` and
            ``"type"``. The names are keys of ``entities_list``.

    Returns:
        ``entities_list`` (the same dict, mutated). This is also returned on
        the early exits when the job or the creator's roles cannot be found.
    """
    from contextlib import ExitStack, closing

    from agent.libs.agent import AgentBusiness
    from agent.libs.user import UserBusiness
    from agent.libs.user_data_relation import UserDataRelationBusiness

    logger = logging.getLogger(__name__)

    # Each business object gets its own session; ExitStack closes all of them
    # on every exit path, including the early returns.
    with ExitStack() as sessions:
        def new_session():
            return sessions.enter_context(closing(SessionLocal()))

        # Fetch the job record.
        job = AgentBusiness(db=new_session()).get_job(graph_id)
        if not job:
            logger.error(f"Job not found with id: {graph_id}")
            return entities_list
        # The creator's user id is parsed from the second '/'-separated field
        # of job_creator (format assumed from this parsing — confirm).
        user_id = int(job.job_creator.split('/')[1])
        # Fetch the creator and their first role.
        user = UserBusiness(db=new_session()).get_user(user_id)
        if not user or not user.roles:
            logger.error(f"User {user_id} has no roles assigned")
            return entities_list
        role = user.roles[0]
        role_id, role_name = role.id, role.name
        user_name = user.name

        relation_biz = UserDataRelationBusiness(db=new_session())
        # Record the sub-graph itself as this user's data.
        relation_biz.create_relation(user_id, 'sub_graph', graph_id, role_id)

        for ent in entities_list.values():
            full_name = ent['name']
            types = ent['type']
            logger.info(f"create node: {ent}")
            node = graphBiz.create_node(
                graph_id=graph_id,
                name=full_name[:64],  # node names are capped at 64 characters
                category=types[0],
                props={'types': ",".join(types), 'full_name': full_name},
            )
            if node:
                ent["db_id"] = node.id
                # Record the node as this user's data.
                relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id, user_name, role_name)

        for rel in relations_list.values():
            source_db_id = entities_list.get(rel['source_name'], {}).get('db_id')
            target_db_id = entities_list.get(rel['target_name'], {}).get('db_id')
            if source_db_id is None or target_db_id is None:
                # An endpoint node was not created (create_node returned a
                # falsy value); skip the edge instead of raising KeyError.
                logger.warning(
                    f"skip edge {rel['source_name']}->{rel['target_name']}: endpoint node missing"
                )
                continue
            relation_type = rel['type']
            edge = graphBiz.create_edge(
                graph_id=graph_id,
                src_id=source_db_id,
                dest_id=target_db_id,
                name=relation_type,
                category=relation_type,
                props={
                    "src_type": rel['source_type'],
                    "dest_type": rel['target_type'],
                },
            )
            logger.info(f"create edge: {source_db_id}->{target_db_id}")
            if edge:
                # Record the edge as this user's data.
                relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id, user_name, role_name)

    # Build and store one navigation tree per category. Each tree is built only
    # from entities of its own type (疾病 = disease, 症状 = symptom).
    for category in ("疾病", "症状"):
        nodes = [ent for ent in entities_list.values() if category in ent.get('type', [])]
        if nodes:
            _save_tree_structure(user_id, graph_id, _build_first_char_tree(nodes, category))

    return entities_list
if __name__ == "__main__":
    # Takes exactly one argument, the job directory. Its last path component is
    # the numeric job id, which is also passed on as the graph id.
    if len(sys.argv) != 2:
        print("Usage: python standard_kb_cbuild.py <path_of_job>")
        sys.exit(-1)
    job_path = sys.argv[1]
    if not os.path.exists(job_path):
        print(f"job path not exists: {job_path}")
        sys.exit(-1)
    kb_path = os.path.join(job_path, "kb_extract")
    if not os.path.exists(kb_path):
        print(f"kb path not exists: {kb_path}")
        sys.exit(-1)
    kb_build_path = os.path.join(job_path, "kb_build")
    # normpath drops a trailing separator, so ".../123/" still yields "123".
    job_id = int(os.path.basename(os.path.normpath(job_path)))
    os.makedirs(kb_build_path, exist_ok=True)

    log_path = os.path.join(job_path, "logs")
    # FileHandler does not create missing directories.
    os.makedirs(log_path, exist_ok=True)
    print(f"log path: {log_path}")
    handler = logging.FileHandler(os.path.join(log_path, "graph_build.log"), mode='a', encoding="utf-8")
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.getLogger().addHandler(handler)
    # Module-level name: the functions in this file log through `logger`.
    logger = logging.getLogger(__name__)
    # Without an explicit level this logger inherits the root default (WARNING),
    # and every logger.info() would be dropped before reaching the handler.
    logger.setLevel(logging.INFO)

    # entities_list: entity text -> {"id", "name", "type": [types...]}
    # relations_list: dedup key -> relation dict consumed by import_entities
    entities_list = {}
    relations_list = {}
    for root, _dirs, files in os.walk(kb_path):
        for file in files:
            if not file.endswith(".txt"):
                continue
            logger.info(f"Processing {file}")
            data = load_json_from_file(filename=os.path.join(root, file))
            if data is None:
                continue
            if not isinstance(data, dict):
                logger.info(f"unexpected JSON content in {file}")
                continue

            if 'entities' in data:
                entities = data['entities']
                for entity in entities:
                    text = entity['text']
                    ent_type = entity['type']
                    known = entities_list.get(text)
                    if known is None:
                        entities_list[text] = {"id": get_hi_lo_id(), "name": text, "type": [ent_type]}
                    elif ent_type not in known['type']:
                        # Same text seen with another type: accumulate types.
                        known['type'].append(ent_type)
            else:
                logger.info(f"entities not found in {file}")

            if "relations" in data:
                # Relation endpoints are indices into this file's own entity list.
                file_entities = data.get('entities') or []
                for relation in data['relations']:
                    source_idx = relation['source']
                    target_idx = relation['target']
                    rel_type = relation['type']
                    # Reject non-int and negative indices too; a negative index
                    # would silently wrap to the end of the list.
                    if not all(isinstance(i, int) and 0 <= i < len(file_entities)
                               for i in (source_idx, target_idx)):
                        logger.info(f"source/target of relation {relation} not found")
                        continue
                    source_ent = file_entities[source_idx]
                    target_ent = file_entities[target_idx]
                    source_text, source_type = source_ent['text'], source_ent['type']
                    target_text, target_type = target_ent['text'], target_ent['type']

                    source = entities_list.get(source_text)
                    target = entities_list.get(target_text)
                    if not (source and target):
                        logger.info(f"relation {source_text}-{rel_type}->{target_text} not found")
                        continue
                    relation_key = f"{source['id']}/{source_type}-{rel_type}->{target['id']}/{target_type}"
                    if relation_key in relations_list:
                        continue
                    relations_list[relation_key] = {"source_id": source['id'],
                                                    "source_name": source_text,
                                                    "source_type": source_type,
                                                    "target_id": target['id'],
                                                    "target_name": target_text,
                                                    "target_type": target_type,
                                                    "type": rel_type}
            else:
                logger.info(f"relations not found in {file}")
            print(f"Done {file}")

    with open(os.path.join(kb_build_path, "entities.json"), "w", encoding="utf-8") as f:
        json.dump(list(entities_list.values()), f, ensure_ascii=False, indent=4)
    with open(os.path.join(kb_build_path, "relations.json"), "w", encoding="utf-8") as f:
        json.dump(list(relations_list.values()), f, ensure_ascii=False, indent=4)
    import_entities(job_id, entities_list, relations_list)
    print("Done")
|