data_export.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
import csv
import io
import json
import logging
import os
from datetime import datetime
from pathlib import Path as PathLib

from fastapi import APIRouter, Depends, HTTPException, Path, Request
from fastapi.responses import FileResponse, JSONResponse
from sqlalchemy.orm import Session

from config.site import TEMP_STORAGE_PATH
from db.database import get_db
from db.models import DbKgEdge, DbKgNode, DbKgTask
from utils.response import resp_200
  14. router = APIRouter(
  15. prefix="/api/data-export",
  16. )
  17. def generate_csv(rows, fieldnames):
  18. """生成CSV文件流"""
  19. output = io.StringIO()
  20. writer = csv.DictWriter(output, fieldnames=fieldnames)
  21. writer.writeheader()
  22. for row in rows:
  23. writer.writerow(row)
  24. output.seek(0)
  25. return output
  26. @router.get("/all/{category}/{graph_id}")
  27. def export_nodes_csv(category:str, graph_id: int, db: Session = Depends(get_db)):
  28. try:
  29. taskContent = ""
  30. if category == 'graph':
  31. taskContent=json.dumps({"graph_id": graph_id})
  32. if category == 'labeling':
  33. taskContent=json.dumps({"proj_id": graph_id})
  34. print(taskContent)
  35. # 创建任务记录
  36. task = DbKgTask(
  37. proj_id=1,
  38. task_category=category,
  39. task_name="data_export",
  40. task_content=taskContent,
  41. status=0,
  42. created=datetime.now(),
  43. updated=datetime.now()
  44. )
  45. db.add(task)
  46. db.commit()
  47. db.refresh(task)
  48. return resp_200(data={"task_id": task.id, 'error_code': 0, 'error_message':'任务创建成功'})
  49. except Exception as e:
  50. db.rollback()
  51. return resp_200(data={"task_id":0 , 'error_code': 500, 'error_message':str(e)})
  52. @router.get("/download/{filename}")
  53. def download_file(
  54. filename: str = Path(..., regex=r"^[a-zA-Z0-9_\-\.]+$"),
  55. ):
  56. """下载文件接口"""
  57. # 基础路径
  58. base_path = PathLib(TEMP_STORAGE_PATH)
  59. # 拼接完整路径
  60. file_path = base_path / filename
  61. print(f"Attempting to access file at path: {file_path}")
  62. # 安全检查
  63. try:
  64. file_path.resolve().relative_to(base_path)
  65. except ValueError:
  66. raise HTTPException(status_code=400, detail="Invalid file path")
  67. # 检查文件是否存在
  68. if not file_path.exists():
  69. raise HTTPException(status_code=404, detail="File not found")
  70. # 支持断点续传
  71. return FileResponse(
  72. path=file_path,
  73. filename=filename,
  74. headers={"Accept-Ranges": "bytes"},
  75. )
  76. @router.get("/list-routes")
  77. def list_routes(request: Request):
  78. """列出所有已注册的路由"""
  79. route_list = []
  80. for route in request.app.routes:
  81. route_info = {
  82. "path": route.path,
  83. "methods": list(route.methods) if hasattr(route, 'methods') else ['N/A'],
  84. "name": route.name
  85. }
  86. route_list.append(route_info)
  87. return {"routes": route_list}
  88. data_export_router = router