data_import.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from fastapi import APIRouter, Depends, File, UploadFile, HTTPException
  2. from fastapi.responses import JSONResponse
  3. from sqlalchemy.orm import Session
  4. import csv
  5. import io
  6. from db.database import get_db
  7. from db.models import DbKgNode, DbKgEdge
  8. from utils.response import resp_200
  9. router = APIRouter(
  10. prefix="/api/data-import",
  11. )
  12. def save_nodes_to_db(rows, db):
  13. try:
  14. for row in rows:
  15. # Check if node already exists
  16. exists = db.query(DbKgNode).filter(
  17. DbKgNode.name == row['name'],
  18. DbKgNode.category == row['category'],
  19. DbKgNode.graph_id == int(row['graph_id'])
  20. ).first()
  21. if not exists:
  22. node = DbKgNode(
  23. category=row['category'],
  24. name=row['name'],
  25. graph_id=int(row['graph_id'])
  26. )
  27. db.add(node)
  28. db.commit()
  29. except Exception as e:
  30. db.rollback()
  31. raise e
  32. finally:
  33. db.close()
  34. def save_edges_to_db(rows, db):
  35. try:
  36. for row in rows:
  37. # Find source node
  38. source_node = db.query(DbKgNode).filter(
  39. DbKgNode.category == row['src_category'],
  40. DbKgNode.name == row['src_name'],
  41. DbKgNode.graph_id == int(row['graph_id'])
  42. ).first()
  43. # Find target node
  44. target_node = db.query(DbKgNode).filter(
  45. DbKgNode.category == row['target_category'],
  46. DbKgNode.name == row['target_name'],
  47. DbKgNode.graph_id == int(row['graph_id'])
  48. ).first()
  49. if not source_node or not target_node:
  50. continue # Skip if either node not found
  51. # Check if edge already exists
  52. exists = db.query(DbKgEdge).filter(
  53. DbKgEdge.graph_id == int(row['graph_id']),
  54. DbKgEdge.src_id == source_node.id,
  55. DbKgEdge.dest_id == target_node.id
  56. ).first()
  57. if not exists:
  58. edge = DbKgEdge(
  59. source_id=source_node.id,
  60. target_id=target_node.id,
  61. category=row['category'],
  62. name=row['name'],
  63. graph_id=int(row['graph_id'])
  64. )
  65. db.add(edge)
  66. db.commit()
  67. except Exception as e:
  68. db.rollback()
  69. raise e
  70. finally:
  71. db.close()
  72. @router.post("/nodes/{graph_id}")
  73. async def upload_nodes_csv(graph_id:int, file: UploadFile = File(...), db:Session = Depends(get_db)):
  74. if not file.filename.endswith('.csv'):
  75. return resp_200(data={"total": 0, 'error_code': 500, 'error_message':"Only CSV files are allowed"})
  76. try:
  77. contents = await file.read()
  78. csv_file = io.StringIO(contents.decode('utf-8'))
  79. reader = csv.DictReader(csv_file)
  80. for row in reader:
  81. row['graph_id'] = graph_id
  82. rows = [row for row in reader]
  83. save_nodes_to_db(rows)
  84. return resp_200(data={"total": len(rows)})
  85. except Exception as e:
  86. return resp_200(data={"total": 0, 'error_code': 500, 'error_message':str(e)})
  87. @router.post("/edges/{graph_id}")
  88. async def upload_edges_csv(graph_id:int,file: UploadFile = File(...)):
  89. if not file.filename.endswith('.csv'):
  90. return resp_200(data={"total": 0, 'error_code': 500, 'error_message':"Only CSV files are allowed"})
  91. try:
  92. contents = await file.read()
  93. csv_file = io.StringIO(contents.decode('utf-8'))
  94. reader = csv.DictReader(csv_file)
  95. for row in reader:
  96. row['graph_id'] = graph_id
  97. rows = [row for row in reader]
  98. save_edges_to_db(rows)
  99. return JSONResponse(content={"message": f"Successfully processed {len(rows)} edges"})
  100. except Exception as e:
  101. raise HTTPException(status_code=500, detail=str(e))
# Second name for `router`; presumably what the application imports when it
# registers these routes -- confirm against the caller.
data_import_router = router