# standard_kb_build.py
  1. import os,sys
  2. import logging
  3. import json
  4. current_path = os.getcwd()
  5. sys.path.append(current_path)
  6. from agent.db.database import SessionLocal
  7. from agent.libs.graph import GraphBusiness
  8. graphBiz = GraphBusiness(db=SessionLocal())
  9. hi_index = 1
  10. low_index = 1
  11. def get_hi_lo_id():
  12. global hi_index, low_index
  13. if low_index < 10000:
  14. low_index += 1
  15. return hi_index * 10000 + low_index
  16. else:
  17. hi_index += 1
  18. low_index = 1
  19. return hi_index * 10000 + low_index
  20. def load_json_from_file(filename: str):
  21. """检查JSON文件格式是否正确"""
  22. try:
  23. with open(filename, 'r', encoding='utf-8') as f:
  24. content = f.read()
  25. buffer = []
  26. json_started = False
  27. for line in content.split("\n"):
  28. if line.strip()=="":
  29. continue
  30. if line.startswith("```json"):
  31. buffer = []
  32. json_started = True
  33. continue
  34. if line.startswith("```"):
  35. if json_started:
  36. return json.loads("\n".join(buffer))
  37. json_started = False
  38. buffer.append(line)
  39. return None
  40. except json.JSONDecodeError as e:
  41. logger.info(f"JSON格式错误: {e}")
  42. return None
  43. def parse_json(data):
  44. if 'entities' in data:
  45. entities = data['entities']
  46. for entity in entities:
  47. if len(entity) == 2:
  48. entity.append("")
  49. def import_entities(graph_id, entities_list, relations_list):
  50. for text, ent in entities_list.items():
  51. id = ent['id']
  52. name = ent['name']
  53. type = ent['type']
  54. full_name = name
  55. if len(name) > 64:
  56. name = name[:64]
  57. logger.info(f"create node: {ent}")
  58. node = graphBiz.create_node(graph_id=graph_id, name=name, category=type[0], props={'types':",".join(type),'full_name':full_name})
  59. if node:
  60. ent["db_id"] = node.id
  61. for text, relations in relations_list.items():
  62. source_name = relations['source_name']
  63. source_type = relations['source_type']
  64. target_name = relations['target_name']
  65. target_type = relations['target_type']
  66. relation_type = relations['type']
  67. source_db_id = entities_list[source_name]['db_id']
  68. target_db_id = entities_list[target_name]['db_id']
  69. graphBiz.create_edge(graph_id=graph_id,
  70. src_id=source_db_id,
  71. dest_id=target_db_id,
  72. name=relation_type,
  73. category=relation_type,
  74. props={
  75. "src_type":source_type,
  76. "dest_type":target_type,
  77. })
  78. logger.info(f"create edge: {source_db_id}->{target_db_id}")
  79. return entities
  80. if __name__ == "__main__":
  81. if len(sys.argv) != 2:
  82. print("Usage: python standard_kb_cbuild.py <path_of_job> <graph_id>")
  83. sys.exit(-1)
  84. job_path = sys.argv[1]
  85. if not os.path.exists(job_path):
  86. print(f"job path not exists: {job_path}")
  87. sys.exit(-1)
  88. kb_path = os.path.join(job_path,"kb_extract")
  89. if not os.path.exists(kb_path):
  90. print(f"kb path not exists: {kb_path}")
  91. sys.exit(-1)
  92. kb_build_path = os.path.join(job_path,"kb_build")
  93. job_id = int(job_path.split("/")[-1])
  94. os.makedirs(kb_build_path,exist_ok=True)
  95. log_path = os.path.join(job_path,"logs")
  96. print(f"log path: {log_path}")
  97. handler = logging.FileHandler(f"{log_path}/graph_build.log", mode='a',encoding="utf-8")
  98. handler.setLevel(logging.INFO)
  99. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  100. handler.setFormatter(formatter)
  101. logging.getLogger().addHandler(handler)
  102. logger = logging.getLogger(__name__)
  103. entities_list = {}
  104. relations_list = {}
  105. for root,dirs,files in os.walk(kb_path):
  106. for file in files:
  107. if file.endswith(".txt"):
  108. logger.info(f"Processing {file}")
  109. data = load_json_from_file(filename=os.path.join(root,file))
  110. if data is None:
  111. continue
  112. if 'entities' in data:
  113. entities = data['entities']
  114. for entity in entities:
  115. text = entity['text']
  116. type = entity['type']
  117. position = entity['position']
  118. if text in entities_list:
  119. ent = entities_list[text]
  120. if type in ent['type']:
  121. continue
  122. ent['type'].append(type)
  123. else:
  124. ent = {"id": get_hi_lo_id(), "name":text,"type":[type]}
  125. entities_list[text] = ent
  126. else:
  127. logger.info(f"entities not found in {file}")
  128. if "relations" in data:
  129. relations = data['relations']
  130. for relation in relations:
  131. source_idx = relation['source']
  132. target_idx = relation['target']
  133. type = relation['type']
  134. if source_idx >= len(data['entities']) or target_idx >= len(data['entities']):
  135. logger.info(f"source/target of relation {relation} not found")
  136. continue
  137. source_ent = data['entities'][source_idx]
  138. target_ent = data['entities'][target_idx]
  139. source_text = source_ent['text']
  140. source_type = source_ent['type']
  141. target_text = target_ent['text']
  142. target_type = target_ent['type']
  143. if source_text in entities_list:
  144. source_ent = entities_list[source_text]
  145. else:
  146. source_ent = None
  147. if target_text in entities_list:
  148. target_ent = entities_list[target_text]
  149. else:
  150. target_ent = None
  151. if source_ent and target_ent:
  152. source_id = source_ent['id']
  153. target_id = target_ent['id']
  154. relation_key = f"{source_id}/{source_type}-{type}->{target_id}/{target_type}"
  155. if relation_key in relations_list:
  156. continue
  157. relations_list[relation_key] = {"source_id":source_id,
  158. "source_name":source_text,
  159. "source_type":source_type,
  160. "target_id":target_id,
  161. "target_name":target_text,
  162. "target_type":target_type,
  163. "type":type}
  164. else:
  165. logger.info(f"relation {relation_key} not found")
  166. else:
  167. logger.info(f"relations not found in {file}")
  168. print(f"Done {file}")
  169. with open(os.path.join(kb_build_path,"entities.json"), "w", encoding="utf-8") as f:
  170. f.write(json.dumps(list(entities_list.values()), ensure_ascii=False,indent=4))
  171. with open(os.path.join(kb_build_path,"relations.json"), "w", encoding="utf-8") as f:
  172. f.write(json.dumps(list(relations_list.values()), ensure_ascii=False,indent=4))
  173. import_entities(job_id, entities_list, relations_list)
  174. print("Done")