standard_kb_build.py 12 KB


  1. import os,sys
  2. import logging
  3. import json
  4. current_path = os.getcwd()
  5. sys.path.append(current_path)
  6. from agent.db.database import SessionLocal
  7. from agent.libs.graph import GraphBusiness
  8. graphBiz = GraphBusiness(db=SessionLocal())
  # State of the hi/lo id counter used by get_hi_lo_id(); both halves start at 1.
  hi_index, low_index = 1, 1


  def get_hi_lo_id():
      """Return the next id from the module-level hi/lo counter.

      The id is ``hi_index * 10000 + low_index``.  While ``low_index`` is
      below 10000 it is incremented by one; once it has reached 10000 the
      next call bumps ``hi_index`` and restarts ``low_index`` at 1.  The
      increment happens before the id is computed, so the first id handed
      out is 10002.

      Returns:
          int: the newly allocated id.
      """
      global hi_index, low_index
      if low_index >= 10000:
          # Low half exhausted: advance the high half and restart the low half.
          hi_index += 1
          low_index = 1
      else:
          low_index += 1
      return hi_index * 10000 + low_index
  def load_json_from_file(filename: str):
      """Extract and parse the first fenced ```json block of a text file.

      The file is read as UTF-8 and scanned line by line.  Lines between an
      opening fence that starts with ```json and the next line that starts
      with ``` are joined and parsed with ``json.loads``.  Blank lines are
      ignored.  Fence markers are recognised even when they are indented.

      Args:
          filename: path of the text file to read.

      Returns:
          The decoded JSON value, or ``None`` when the file has no complete
          fenced json block or the block is not valid JSON.

      Raises:
          OSError: if the file cannot be opened (not handled here).
      """
      # The module-level ``logger`` is only created when the file runs as a
      # script, so fetch the logger here to stay usable when imported.
      log = logging.getLogger(__name__)

      with open(filename, 'r', encoding='utf-8') as f:
          content = f.read()

      collected = None  # stays None until an opening ```json fence is seen
      for raw_line in content.split("\n"):
          marker = raw_line.strip()
          if not marker:
              continue
          if marker.startswith("```json"):
              # A new opening fence always restarts the collection.
              collected = []
              continue
          if marker.startswith("```"):
              if collected is None:
                  # Fence of some other block before any json block: ignore.
                  continue
              try:
                  return json.loads("\n".join(collected))
              except json.JSONDecodeError as e:
                  log.info(f"JSON格式错误: {e}")
                  return None
          if collected is not None:
              collected.append(raw_line)
      # No fenced json block, or the block was never closed.
      return None
  def parse_json(data):
      """Pad two-element entity entries of ``data['entities']`` in place.

      Every entry whose length is exactly 2 gets an empty string appended;
      all other entries are left alone.  Nothing happens when *data* has no
      ``'entities'`` key.  The function returns ``None``.
      """
      if 'entities' not in data:
          return
      for item in data['entities']:
          if len(item) == 2:
              item.append("")
  def _build_letter_tree(root_name, nodes):
      """Return a two-level tree of *nodes* grouped by the first character of their name."""
      tree = {"name": root_name, "sNode": []}
      current_letter = None
      letter_group = None
      for node in sorted(nodes, key=lambda n: n['name']):
          # Slicing instead of indexing keeps an empty name from raising.
          first_letter = node['name'][:1].upper()
          if first_letter != current_letter:
              current_letter = first_letter
              letter_group = {"name": current_letter, "sNode": []}
              tree["sNode"].append(letter_group)
          letter_group["sNode"].append({"name": node['name'], "sNode": []})
      return tree


  def _save_tree_structure(user_id, graph_id, tree):
      """Store *tree* as a JSON string in a new TreeStructure row; failures are logged, not raised."""
      from agent.models.db.tree_structure import TreeStructure
      db = SessionLocal()
      try:
          db.add(TreeStructure(
              user_id=user_id,
              graph_id=graph_id,
              content=json.dumps(tree, ensure_ascii=False),
          ))
          db.commit()
      except Exception as e:
          db.rollback()
          logging.getLogger(__name__).error(f"Failed to save tree structure: {e}")
      finally:
          db.close()


  def import_entities(graph_id, entities_list, relations_list):
      """Write the collected entities and relations into the graph database.

      Steps, in order:
        1. look up the job for *graph_id* and derive the owning user and the
           user's first role;
        2. register a ``'sub_graph'`` data relation for the user;
        3. create one node per entity, storing the node id back into the
           entity dict under ``'db_id'``, and register a ``'DbKgNode'``
           relation for it;
        4. create one edge per relation and register a ``'DbKgEdge'``
           relation for it; relations whose endpoints have no ``'db_id'``
           are skipped;
        5. store one first-letter tree for the entities typed ``'疾病'`` and
           one for those typed ``'症状'``.

      Args:
          graph_id: id passed to ``get_job`` and used as the graph id of all
              nodes, edges and tree rows.
          entities_list: dict mapping entity text to a dict with at least
              ``'name'`` and ``'type'`` (a list of type names).  Mutated:
              ``'db_id'`` is added to every entity whose node was created.
          relations_list: dict whose values carry ``'source_name'``,
              ``'source_type'``, ``'target_name'``, ``'target_type'`` and
              ``'type'``.

      Returns:
          *entities_list* (also on the early-exit paths).
      """
      from agent.libs.user_data_relation import UserDataRelationBusiness
      from agent.libs.user import UserBusiness
      from agent.libs.agent import AgentBusiness

      log = logging.getLogger(__name__)

      # Fetch the job record.
      agent_biz = AgentBusiness(db=SessionLocal())
      job = agent_biz.get_job(graph_id)
      if not job:
          log.error(f"Job not found with id: {graph_id}")
          return entities_list

      # The user id is the second '/'-separated field of job_creator.
      # NOTE(review): assumes job_creator looks like "<prefix>/<user_id>" - confirm.
      user_id = int(job.job_creator.split('/')[1])

      # Resolve the user's role; the first role is used for all relations.
      user_biz = UserBusiness(db=SessionLocal())
      user = user_biz.get_user(user_id)
      if not user or not user.roles:
          log.error(f"User {user_id} has no roles assigned")
          return entities_list
      role_id = user.roles[0].id
      user_name = user.name
      role_name = user.roles[0].name

      # NOTE(review): the sessions handed to the *Business objects are never
      # closed here; whether those classes own the session is not visible.
      relation_biz = UserDataRelationBusiness(db=SessionLocal())

      # Link the sub-graph itself to the user.
      relation_biz.create_relation(user_id, 'sub_graph', graph_id, role_id)

      # Create the nodes.
      for ent in entities_list.values():
          full_name = ent['name']
          type_names = ent['type']
          # Node names are capped at 64 characters; the untruncated name is
          # kept in the props.
          short_name = full_name[:64]
          log.info(f"create node: {ent}")
          node = graphBiz.create_node(
              graph_id=graph_id,
              name=short_name,
              category=type_names[0],
              props={'types': ",".join(type_names), 'full_name': full_name},
          )
          if node:
              ent["db_id"] = node.id
              # Link the node to the user.
              relation_biz.create_relation(user_id, 'DbKgNode', node.id, role_id, user_name, role_name)

      # Create the edges.
      for rel in relations_list.values():
          source_name = rel['source_name']
          target_name = rel['target_name']
          relation_type = rel['type']
          source_db_id = entities_list.get(source_name, {}).get('db_id')
          target_db_id = entities_list.get(target_name, {}).get('db_id')
          if source_db_id is None or target_db_id is None:
              # An endpoint was never created (or is unknown); skip the edge
              # instead of aborting the whole import with a KeyError.
              log.warning(f"skip edge {source_name}-{relation_type}->{target_name}: endpoint node missing")
              continue
          edge = graphBiz.create_edge(
              graph_id=graph_id,
              src_id=source_db_id,
              dest_id=target_db_id,
              name=relation_type,
              category=relation_type,
              props={
                  "src_type": rel['source_type'],
                  "dest_type": rel['target_type'],
              },
          )
          log.info(f"create edge: {source_db_id}->{target_db_id}")
          # Link the edge to the user.
          if edge:
              relation_biz.create_relation(user_id, 'DbKgEdge', edge.id, role_id, user_name, role_name)

      # Build and store one first-letter tree per category.
      # NOTE(review): the tree labelled '症状' used to be filled with the
      # '疾病' entities, which looks like a copy-paste slip; each tree now
      # holds the entities of its own category - confirm.
      for category in ('疾病', '症状'):
          members = [ent for ent in entities_list.values() if category in ent.get('type', [])]
          if members:
              _save_tree_structure(user_id, graph_id, _build_letter_tree(category, members))

      return entities_list
  if __name__ == "__main__":
      # Script entry point.
      #
      # Reads every ``*.txt`` file below ``<job_path>/kb_extract``, merges the
      # entities and relations of their fenced json blocks, writes
      # ``entities.json`` and ``relations.json`` to ``<job_path>/kb_build`` and
      # imports the result with import_entities().  The job id is the last
      # component of the job path, so the path is the only argument.
      if len(sys.argv) != 2:
          print("Usage: python standard_kb_build.py <path_of_job>")
          sys.exit(-1)
      job_path = sys.argv[1]
      if not os.path.exists(job_path):
          print(f"job path not exists: {job_path}")
          sys.exit(-1)
      kb_path = os.path.join(job_path, "kb_extract")
      if not os.path.exists(kb_path):
          print(f"kb path not exists: {kb_path}")
          sys.exit(-1)
      kb_build_path = os.path.join(job_path, "kb_build")
      try:
          # normpath drops a trailing separator so "jobs/12/" still yields 12.
          job_id = int(os.path.basename(os.path.normpath(job_path)))
      except ValueError:
          print(f"job path must end with a numeric job id: {job_path}")
          sys.exit(-1)
      os.makedirs(kb_build_path, exist_ok=True)

      # File logging; the directory must exist before FileHandler opens the file.
      log_path = os.path.join(job_path, "logs")
      os.makedirs(log_path, exist_ok=True)
      print(f"log path: {log_path}")
      handler = logging.FileHandler(os.path.join(log_path, "graph_build.log"), mode='a', encoding="utf-8")
      handler.setLevel(logging.INFO)
      formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
      handler.setFormatter(formatter)
      logging.getLogger().addHandler(handler)
      logger = logging.getLogger(__name__)
      # Without an explicit level the logger inherits the root default
      # (WARNING) and every info() record is dropped before the handler.
      logger.setLevel(logging.INFO)

      entities_list = {}   # entity text -> {"id", "name", "type": [type names]}
      relations_list = {}  # relation key -> relation record
      for root, dirs, files in os.walk(kb_path):
          for file in files:
              if not file.endswith(".txt"):
                  continue
              logger.info(f"Processing {file}")
              data = load_json_from_file(filename=os.path.join(root, file))
              if data is None:
                  continue
              if not isinstance(data, dict):
                  logger.info(f"unexpected json root in {file}")
                  continue

              # Merge the entities: one record per text, collecting all types.
              if 'entities' in data:
                  entities = data['entities']
                  for entity in entities:
                      text = entity['text']
                      ent_type = entity['type']
                      known = entities_list.get(text)
                      if known is None:
                          entities_list[text] = {"id": get_hi_lo_id(), "name": text, "type": [ent_type]}
                      elif ent_type not in known['type']:
                          known['type'].append(ent_type)
              else:
                  entities = []
                  logger.info(f"entities not found in {file}")

              # Merge the relations; source/target are indexes into this
              # file's entity list.
              if "relations" in data:
                  relations = data['relations']
                  for relation in relations:
                      source_idx = relation['source']
                      target_idx = relation['target']
                      rel_type = relation['type']
                      source_ok = isinstance(source_idx, int) and 0 <= source_idx < len(entities)
                      target_ok = isinstance(target_idx, int) and 0 <= target_idx < len(entities)
                      if not (source_ok and target_ok):
                          logger.info(f"source/target of relation {relation} not found")
                          continue
                      source_text = entities[source_idx]['text']
                      source_type = entities[source_idx]['type']
                      target_text = entities[target_idx]['text']
                      target_type = entities[target_idx]['type']
                      source_ent = entities_list.get(source_text)
                      target_ent = entities_list.get(target_text)
                      if source_ent and target_ent:
                          source_id = source_ent['id']
                          target_id = target_ent['id']
                          # The key makes identical relations from different
                          # files collapse into one record.
                          relation_key = f"{source_id}/{source_type}-{rel_type}->{target_id}/{target_type}"
                          if relation_key in relations_list:
                              continue
                          relations_list[relation_key] = {
                              "source_id": source_id,
                              "source_name": source_text,
                              "source_type": source_type,
                              "target_id": target_id,
                              "target_name": target_text,
                              "target_type": target_type,
                              "type": rel_type,
                          }
                      else:
                          logger.info(f"relation {source_text}-{rel_type}->{target_text} not found")
              else:
                  logger.info(f"relations not found in {file}")
              print(f"Done {file}")

      # Persist the merged result, then import it into the graph database.
      with open(os.path.join(kb_build_path, "entities.json"), "w", encoding="utf-8") as f:
          json.dump(list(entities_list.values()), f, ensure_ascii=False, indent=4)
      with open(os.path.join(kb_build_path, "relations.json"), "w", encoding="utf-8") as f:
          json.dump(list(relations_list.values()), f, ensure_ascii=False, indent=4)
      import_entities(job_id, entities_list, relations_list)
      print("Done")