background_job.py
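"""Background worker that polls DbKgTask rows and runs data-export tasks.

Each pending task pages the nodes and edges of the requested graph out of the
database into CSV files, zips the two files, and writes the archive name back
into the task payload.
"""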

import time
import json
import csv
import os
from datetime import datetime

from sqlalchemy.orm import Session

from db.database import SessionLocal
from db.models import DbKgNode, DbKgEdge, DbKgTask
from utils.files import zip_files


def process_export_task(db: Session, task: DbKgTask):
    try:
        # Parse the task parameters
        params = json.loads(task.task_content)
        graph_id = params["graph_id"]

        # Mark the task as running
        task.status = 1
        task.updated = datetime.now()
        db.commit()

        # Run the export according to the task type
        if task.task_category == "data_export":
            # Make sure the export directory exists
            export_dir = "/home/tmp"
            os.makedirs(export_dir, exist_ok=True)

            # Export node data, one page at a time
            filename = f"nodes_{task.id}.csv"
            fieldnames = ["category", "name"]
            filepath1 = os.path.join(export_dir, filename)
            with open(filepath1, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                start = 0  # first page begins at offset 0
                page_size = 100
                node_query = db.query(DbKgNode).filter(
                    DbKgNode.graph_id == graph_id, DbKgNode.status == 0
                )
                count = node_query.count()
                results = node_query.limit(page_size).offset(start).all()
                while results:
                    print(f"process {start}/{count}")
                    writer.writerows(
                        {"category": node.category, "name": node.name}
                        for node in results
                    )
                    start += len(results)
                    results = node_query.limit(page_size).offset(start).all()

            # Export edge data, one page at a time
            filename = f"edges_{task.id}.csv"
            fieldnames = [
                "src_category", "src_name",
                "dest_category", "dest_name",
                "category", "name", "graph_id",
            ]
            filepath2 = os.path.join(export_dir, filename)
            with open(filepath2, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=fieldnames)
                writer.writeheader()
                start = 0
                page_size = 100
                edge_query = db.query(DbKgEdge).filter(
                    DbKgEdge.graph_id == graph_id, DbKgEdge.status == 0
                )
                count = edge_query.count()
                results = edge_query.limit(page_size).offset(start).all()
                while results:
                    print(f"process {start}/{count}")
                    rows = []
                    for edge in results:
                        src_node = edge.src_node
                        dest_node = edge.dest_node
                        rows.append({
                            "src_category": src_node.category,
                            "src_name": src_node.name,
                            "dest_category": dest_node.category,
                            "dest_name": dest_node.name,
                            "category": edge.category,
                            "name": edge.name,
                            "graph_id": edge.graph_id,
                        })
                    writer.writerows(rows)
                    start += len(results)
                    results = edge_query.limit(page_size).offset(start).all()

            # Mark the task as completed
            task.status = 2
            task.updated = datetime.now()
            db.commit()

            # Bundle both CSVs into a single archive and record its name
            filename = f"nodes_{task.id}.zip"
            filepath3 = os.path.join(export_dir, filename)
            if zip_files(file_paths=[filepath1, filepath2], output_zip_path=filepath3):
                task.status = 999
                params["output_file"] = filename
                task.task_content = json.dumps(params)
                task.updated = datetime.now()
                db.commit()
    except Exception as e:
        # On failure, keep the original parameters and record the error
        task.status = -1
        task.task_content = json.dumps({
            "error": str(e),
            **json.loads(task.task_content),
        })
        task.updated = datetime.now()
        db.commit()
        raise


def task_worker():
    print("connect to database")
    db = SessionLocal()
    try:
        while True:
            # Fetch pending tasks
            tasks = db.query(DbKgTask).filter(
                DbKgTask.proj_id == 1,
                DbKgTask.status == 0,
            ).all()
            for task in tasks:
                print(f"process task {task.id}:{task.task_category}")
                try:
                    process_export_task(db, task)
                except Exception as e:
                    print(f"task failed: {e}")
                    continue
            print("sleep")
            time.sleep(10)
    finally:
        db.close()


if __name__ == "__main__":
    task_worker()
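
# For reference, `zip_files` is imported from utils.files, whose implementation
# is not part of this file. A minimal sketch of what it might look like
# (an assumption, not the actual helper), built on the standard zipfile module:
#
#     import os
#     import zipfile
#
#     def zip_files(file_paths, output_zip_path):
#         # Write each input file into the archive under its base name and
#         # report success; the caller only checks for a truthy return value.
#         with zipfile.ZipFile(output_zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
#             for path in file_paths:
#                 zf.write(path, arcname=os.path.basename(path))
#         return True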