standard_kb_build.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. import os,sys
  2. import logging
  3. import json
  4. current_path = os.getcwd()
  5. sys.path.append(current_path)
  6. from agent.db.database import SessionLocal
  7. from agent.libs.graph import GraphBusiness
  8. graphBiz = GraphBusiness(db=SessionLocal())
  9. hi_index = 1
  10. low_index = 1
  11. def get_hi_lo_id():
  12. global hi_index, low_index
  13. if low_index < 10000:
  14. low_index += 1
  15. return hi_index * 10000 + low_index
  16. else:
  17. hi_index += 1
  18. low_index = 1
  19. return hi_index * 10000 + low_index
  20. def load_json_from_file(filename: str):
  21. """检查JSON文件格式是否正确"""
  22. try:
  23. with open(filename, 'r', encoding='utf-8') as f:
  24. content = f.read()
  25. buffer = []
  26. json_started = False
  27. for line in content.split("\n"):
  28. if line.strip()=="":
  29. continue
  30. if line.startswith("```json"):
  31. buffer = []
  32. json_started = True
  33. continue
  34. if line.startswith("```"):
  35. if json_started:
  36. return json.loads("\n".join(buffer))
  37. json_started = False
  38. buffer.append(line)
  39. return None
  40. except json.JSONDecodeError as e:
  41. logger.info(f"JSON格式错误: {e}")
  42. return None
  43. def parse_json(data):
  44. if 'entities' in data:
  45. entities = data['entities']
  46. for entity in entities:
  47. if len(entity) == 2:
  48. entity.append("")
  49. def import_entities(graph_id, entities_list, relations_list):
  50. from agent.libs.user_data_relation import UserDataRelationBusiness
  51. from agent.models.db.user import User, Role
  52. from agent.libs.user import UserBusiness
  53. # 获取job信息
  54. job = graphBiz.db.query(graphBiz.DbJob).filter(graphBiz.DbJob.id == graph_id).first()
  55. if not job:
  56. logger.error(f"Job not found with id: {graph_id}")
  57. return entities
  58. # 从job_creator中提取user_id
  59. user_id = int(job.job_creator.split('/')[1])
  60. # 获取用户角色
  61. user_biz = UserBusiness(graphBiz.db)
  62. user = user_biz.get_user(user_id)
  63. if not user or not user.roles:
  64. logger.error(f"User {user_id} has no roles assigned")
  65. return entities
  66. role_id = user.roles[0].id
  67. # 创建用户数据关系业务对象
  68. relation_biz = UserDataRelationBusiness(graphBiz.db)
  69. for text, ent in entities_list.items():
  70. id = ent['id']
  71. name = ent['name']
  72. type = ent['type']
  73. full_name = name
  74. if len(name) > 64:
  75. name = name[:64]
  76. logger.info(f"create node: {ent}")
  77. node = graphBiz.create_node(graph_id=graph_id, name=name, category=type[0], props={'types':",".join(type),'full_name':full_name})
  78. if node:
  79. ent["db_id"] = node.id
  80. # 创建节点数据关联
  81. relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id)
  82. for text, relations in relations_list.items():
  83. source_name = relations['source_name']
  84. source_type = relations['source_type']
  85. target_name = relations['target_name']
  86. target_type = relations['target_type']
  87. relation_type = relations['type']
  88. source_db_id = entities_list[source_name]['db_id']
  89. target_db_id = entities_list[target_name]['db_id']
  90. edge = graphBiz.create_edge(graph_id=graph_id,
  91. src_id=source_db_id,
  92. dest_id=target_db_id,
  93. name=relation_type,
  94. category=relation_type,
  95. props={
  96. "src_type":source_type,
  97. "dest_type":target_type,
  98. })
  99. logger.info(f"create edge: {source_db_id}->{target_db_id}")
  100. # 创建边数据关联
  101. if edge:
  102. relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id)
  103. return entities
  104. if __name__ == "__main__":
  105. if len(sys.argv) != 2:
  106. print("Usage: python standard_kb_cbuild.py <path_of_job> <graph_id>")
  107. sys.exit(-1)
  108. job_path = sys.argv[1]
  109. if not os.path.exists(job_path):
  110. print(f"job path not exists: {job_path}")
  111. sys.exit(-1)
  112. kb_path = os.path.join(job_path,"kb_extract")
  113. if not os.path.exists(kb_path):
  114. print(f"kb path not exists: {kb_path}")
  115. sys.exit(-1)
  116. kb_build_path = os.path.join(job_path,"kb_build")
  117. job_id = int(job_path.split("/")[-1])
  118. os.makedirs(kb_build_path,exist_ok=True)
  119. log_path = os.path.join(job_path,"logs")
  120. print(f"log path: {log_path}")
  121. handler = logging.FileHandler(f"{log_path}/graph_build.log", mode='a',encoding="utf-8")
  122. handler.setLevel(logging.INFO)
  123. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  124. handler.setFormatter(formatter)
  125. logging.getLogger().addHandler(handler)
  126. logger = logging.getLogger(__name__)
  127. entities_list = {}
  128. relations_list = {}
  129. for root,dirs,files in os.walk(kb_path):
  130. for file in files:
  131. if file.endswith(".txt"):
  132. logger.info(f"Processing {file}")
  133. data = load_json_from_file(filename=os.path.join(root,file))
  134. if data is None:
  135. continue
  136. if 'entities' in data:
  137. entities = data['entities']
  138. for entity in entities:
  139. text = entity['text']
  140. type = entity['type']
  141. position = entity['position']
  142. if text in entities_list:
  143. ent = entities_list[text]
  144. if type in ent['type']:
  145. continue
  146. ent['type'].append(type)
  147. else:
  148. ent = {"id": get_hi_lo_id(), "name":text,"type":[type]}
  149. entities_list[text] = ent
  150. else:
  151. logger.info(f"entities not found in {file}")
  152. if "relations" in data:
  153. relations = data['relations']
  154. for relation in relations:
  155. source_idx = relation['source']
  156. target_idx = relation['target']
  157. type = relation['type']
  158. if source_idx >= len(data['entities']) or target_idx >= len(data['entities']):
  159. logger.info(f"source/target of relation {relation} not found")
  160. continue
  161. source_ent = data['entities'][source_idx]
  162. target_ent = data['entities'][target_idx]
  163. source_text = source_ent['text']
  164. source_type = source_ent['type']
  165. target_text = target_ent['text']
  166. target_type = target_ent['type']
  167. if source_text in entities_list:
  168. source_ent = entities_list[source_text]
  169. else:
  170. source_ent = None
  171. if target_text in entities_list:
  172. target_ent = entities_list[target_text]
  173. else:
  174. target_ent = None
  175. if source_ent and target_ent:
  176. source_id = source_ent['id']
  177. target_id = target_ent['id']
  178. relation_key = f"{source_id}/{source_type}-{type}->{target_id}/{target_type}"
  179. if relation_key in relations_list:
  180. continue
  181. relations_list[relation_key] = {"source_id":source_id,
  182. "source_name":source_text,
  183. "source_type":source_type,
  184. "target_id":target_id,
  185. "target_name":target_text,
  186. "target_type":target_type,
  187. "type":type}
  188. else:
  189. logger.info(f"relation {relation_key} not found")
  190. else:
  191. logger.info(f"relations not found in {file}")
  192. print(f"Done {file}")
  193. with open(os.path.join(kb_build_path,"entities.json"), "w", encoding="utf-8") as f:
  194. f.write(json.dumps(list(entities_list.values()), ensure_ascii=False,indent=4))
  195. with open(os.path.join(kb_build_path,"relations.json"), "w", encoding="utf-8") as f:
  196. f.write(json.dumps(list(relations_list.values()), ensure_ascii=False,indent=4))
  197. import_entities(job_id, entities_list, relations_list)
  198. print("Done")