standard_kb_build.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import os,sys
  2. import logging
  3. import json
  4. current_path = os.getcwd()
  5. sys.path.append(current_path)
  6. from agent.db.database import SessionLocal
  7. from agent.libs.graph import GraphBusiness
  8. graphBiz = GraphBusiness(db=SessionLocal())
  9. hi_index = 1
  10. low_index = 1
  11. def get_hi_lo_id():
  12. global hi_index, low_index
  13. if low_index < 10000:
  14. low_index += 1
  15. return hi_index * 10000 + low_index
  16. else:
  17. hi_index += 1
  18. low_index = 1
  19. return hi_index * 10000 + low_index
  20. def load_json_from_file(filename: str):
  21. """检查JSON文件格式是否正确"""
  22. try:
  23. with open(filename, 'r', encoding='utf-8') as f:
  24. content = f.read()
  25. buffer = []
  26. json_started = False
  27. for line in content.split("\n"):
  28. if line.strip()=="":
  29. continue
  30. if line.startswith("```json"):
  31. buffer = []
  32. json_started = True
  33. continue
  34. if line.startswith("```"):
  35. if json_started:
  36. return json.loads("\n".join(buffer))
  37. json_started = False
  38. buffer.append(line)
  39. return None
  40. except json.JSONDecodeError as e:
  41. logger.info(f"JSON格式错误: {e}")
  42. return None
  43. def parse_json(data):
  44. if 'entities' in data:
  45. entities = data['entities']
  46. for entity in entities:
  47. if len(entity) == 2:
  48. entity.append("")
  49. def import_entities(graph_id, entities_list, relations_list):
  50. from agent.libs.user_data_relation import UserDataRelationBusiness
  51. from agent.models.db.user import User, Role
  52. from agent.libs.user import UserBusiness
  53. # 获取job信息
  54. job = graphBiz.db.query(graphBiz.DbJob).filter(graphBiz.DbJob.id == graph_id).first()
  55. if not job:
  56. logger.error(f"Job not found with id: {graph_id}")
  57. return entities
  58. # 从job_creator中提取user_id
  59. user_id = int(job.job_creator.split('/')[1])
  60. # 获取用户角色
  61. user_biz = UserBusiness(graphBiz.db)
  62. user = user_biz.get_user(user_id)
  63. if not user or not user.roles:
  64. logger.error(f"User {user_id} has no roles assigned")
  65. return entities
  66. role_id = user.roles[0].id
  67. user_name = user.name
  68. role_name = user.roles[0].name
  69. # 创建用户数据关系业务对象
  70. relation_biz = UserDataRelationBusiness(graphBiz.db)
  71. for text, ent in entities_list.items():
  72. id = ent['id']
  73. name = ent['name']
  74. type = ent['type']
  75. full_name = name
  76. if len(name) > 64:
  77. name = name[:64]
  78. logger.info(f"create node: {ent}")
  79. node = graphBiz.create_node(graph_id=graph_id, name=name, category=type[0], props={'types':",".join(type),'full_name':full_name})
  80. if node:
  81. ent["db_id"] = node.id
  82. # 创建节点数据关联
  83. relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id, user_name, role_name)
  84. for text, relations in relations_list.items():
  85. source_name = relations['source_name']
  86. source_type = relations['source_type']
  87. target_name = relations['target_name']
  88. target_type = relations['target_type']
  89. relation_type = relations['type']
  90. source_db_id = entities_list[source_name]['db_id']
  91. target_db_id = entities_list[target_name]['db_id']
  92. edge = graphBiz.create_edge(graph_id=graph_id,
  93. src_id=source_db_id,
  94. dest_id=target_db_id,
  95. name=relation_type,
  96. category=relation_type,
  97. props={
  98. "src_type":source_type,
  99. "dest_type":target_type,
  100. })
  101. logger.info(f"create edge: {source_db_id}->{target_db_id}")
  102. # 创建边数据关联
  103. if edge:
  104. relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id, user_name, role_name)
  105. return entities
  106. if __name__ == "__main__":
  107. if len(sys.argv) != 2:
  108. print("Usage: python standard_kb_cbuild.py <path_of_job> <graph_id>")
  109. sys.exit(-1)
  110. job_path = sys.argv[1]
  111. if not os.path.exists(job_path):
  112. print(f"job path not exists: {job_path}")
  113. sys.exit(-1)
  114. kb_path = os.path.join(job_path,"kb_extract")
  115. if not os.path.exists(kb_path):
  116. print(f"kb path not exists: {kb_path}")
  117. sys.exit(-1)
  118. kb_build_path = os.path.join(job_path,"kb_build")
  119. job_id = int(job_path.split("/")[-1])
  120. os.makedirs(kb_build_path,exist_ok=True)
  121. log_path = os.path.join(job_path,"logs")
  122. print(f"log path: {log_path}")
  123. handler = logging.FileHandler(f"{log_path}/graph_build.log", mode='a',encoding="utf-8")
  124. handler.setLevel(logging.INFO)
  125. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  126. handler.setFormatter(formatter)
  127. logging.getLogger().addHandler(handler)
  128. logger = logging.getLogger(__name__)
  129. entities_list = {}
  130. relations_list = {}
  131. for root,dirs,files in os.walk(kb_path):
  132. for file in files:
  133. if file.endswith(".txt"):
  134. logger.info(f"Processing {file}")
  135. data = load_json_from_file(filename=os.path.join(root,file))
  136. if data is None:
  137. continue
  138. if 'entities' in data:
  139. entities = data['entities']
  140. for entity in entities:
  141. text = entity['text']
  142. type = entity['type']
  143. position = entity['position']
  144. if text in entities_list:
  145. ent = entities_list[text]
  146. if type in ent['type']:
  147. continue
  148. ent['type'].append(type)
  149. else:
  150. ent = {"id": get_hi_lo_id(), "name":text,"type":[type]}
  151. entities_list[text] = ent
  152. else:
  153. logger.info(f"entities not found in {file}")
  154. if "relations" in data:
  155. relations = data['relations']
  156. for relation in relations:
  157. source_idx = relation['source']
  158. target_idx = relation['target']
  159. type = relation['type']
  160. if source_idx >= len(data['entities']) or target_idx >= len(data['entities']):
  161. logger.info(f"source/target of relation {relation} not found")
  162. continue
  163. source_ent = data['entities'][source_idx]
  164. target_ent = data['entities'][target_idx]
  165. source_text = source_ent['text']
  166. source_type = source_ent['type']
  167. target_text = target_ent['text']
  168. target_type = target_ent['type']
  169. if source_text in entities_list:
  170. source_ent = entities_list[source_text]
  171. else:
  172. source_ent = None
  173. if target_text in entities_list:
  174. target_ent = entities_list[target_text]
  175. else:
  176. target_ent = None
  177. if source_ent and target_ent:
  178. source_id = source_ent['id']
  179. target_id = target_ent['id']
  180. relation_key = f"{source_id}/{source_type}-{type}->{target_id}/{target_type}"
  181. if relation_key in relations_list:
  182. continue
  183. relations_list[relation_key] = {"source_id":source_id,
  184. "source_name":source_text,
  185. "source_type":source_type,
  186. "target_id":target_id,
  187. "target_name":target_text,
  188. "target_type":target_type,
  189. "type":type}
  190. else:
  191. logger.info(f"relation {relation_key} not found")
  192. else:
  193. logger.info(f"relations not found in {file}")
  194. print(f"Done {file}")
  195. with open(os.path.join(kb_build_path,"entities.json"), "w", encoding="utf-8") as f:
  196. f.write(json.dumps(list(entities_list.values()), ensure_ascii=False,indent=4))
  197. with open(os.path.join(kb_build_path,"relations.json"), "w", encoding="utf-8") as f:
  198. f.write(json.dumps(list(relations_list.values()), ensure_ascii=False,indent=4))
  199. import_entities(job_id, entities_list, relations_list)
  200. print("Done")