123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116 |
- from fastapi import APIRouter, Depends, File, UploadFile, HTTPException
- from fastapi.responses import JSONResponse
- from sqlalchemy.orm import Session
- import csv
- import io
- from db.database import get_db
- from db.models import DbKgNode, DbKgEdge
- from utils.response import resp_200
# Router for CSV-based bulk import of knowledge-graph nodes and edges.
router = APIRouter(
    prefix="/api/data-import",
)
def save_nodes_to_db(rows, db):
    """Insert knowledge-graph nodes in bulk, skipping duplicates.

    Args:
        rows: iterable of dicts, each with 'name', 'category' and
            'graph_id' keys (graph_id may be a string; it is int()-coerced).
        db: SQLAlchemy session injected by the caller's `get_db` dependency.
            The session's lifecycle is owned by that dependency.

    Raises:
        Exception: any database error is re-raised after rolling back.
    """
    try:
        for row in rows:
            # Skip rows whose (name, category, graph_id) already exists.
            exists = db.query(DbKgNode).filter(
                DbKgNode.name == row['name'],
                DbKgNode.category == row['category'],
                DbKgNode.graph_id == int(row['graph_id'])
            ).first()

            if not exists:
                db.add(DbKgNode(
                    category=row['category'],
                    name=row['name'],
                    graph_id=int(row['graph_id'])
                ))
        # One commit for the whole batch keeps the import atomic.
        db.commit()
    except Exception:
        db.rollback()
        raise
    # NOTE(review): removed the original `finally: db.close()` — the session
    # comes from the FastAPI `get_db` dependency, which is responsible for
    # closing it; closing it here would break any later use of the same
    # session within the request.
def save_edges_to_db(rows, db):
    """Insert knowledge-graph edges in bulk, skipping duplicates.

    Each row must carry 'src_category', 'src_name', 'target_category',
    'target_name', 'category', 'name' and 'graph_id'. Rows whose endpoint
    nodes cannot be found are silently skipped.

    Args:
        rows: iterable of edge dicts as described above.
        db: SQLAlchemy session injected by the caller's `get_db` dependency.
            The session's lifecycle is owned by that dependency.

    Raises:
        Exception: any database error is re-raised after rolling back.
    """
    try:
        for row in rows:
            graph_id = int(row['graph_id'])

            # Resolve the source node of this edge.
            source_node = db.query(DbKgNode).filter(
                DbKgNode.category == row['src_category'],
                DbKgNode.name == row['src_name'],
                DbKgNode.graph_id == graph_id
            ).first()

            # Resolve the target node of this edge.
            target_node = db.query(DbKgNode).filter(
                DbKgNode.category == row['target_category'],
                DbKgNode.name == row['target_name'],
                DbKgNode.graph_id == graph_id
            ).first()

            if not source_node or not target_node:
                continue  # Skip if either endpoint node is missing.

            # Skip edges that already exist for this graph.
            exists = db.query(DbKgEdge).filter(
                DbKgEdge.graph_id == graph_id,
                DbKgEdge.src_id == source_node.id,
                DbKgEdge.dest_id == target_node.id
            ).first()

            if not exists:
                # BUG FIX: the original passed source_id=/target_id= kwargs,
                # but the duplicate check above shows the model's columns are
                # `src_id` and `dest_id` — the old constructor call would
                # raise a TypeError on the first new edge.
                db.add(DbKgEdge(
                    src_id=source_node.id,
                    dest_id=target_node.id,
                    category=row['category'],
                    name=row['name'],
                    graph_id=graph_id
                ))
        # One commit for the whole batch keeps the import atomic.
        db.commit()
    except Exception:
        db.rollback()
        raise
    # NOTE(review): removed the original `finally: db.close()` — the session
    # is owned by the `get_db` dependency, which closes it after the request.
@router.post("/nodes/{graph_id}")
async def upload_nodes_csv(graph_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Upload a CSV of nodes and import them into graph `graph_id`.

    The CSV must have `name` and `category` columns; `graph_id` is taken
    from the URL path and stamped onto every row. Returns a `resp_200`
    payload with the number of rows processed, or an error payload on
    failure (non-CSV file, decode error, DB error).
    """
    if not file.filename.endswith('.csv'):
        return resp_200(data={"total": 0, 'error_code': 500, 'error_message': "Only CSV files are allowed"})
    try:
        contents = await file.read()
        csv_file = io.StringIO(contents.decode('utf-8'))
        reader = csv.DictReader(csv_file)

        # BUG FIX: the original looped over `reader` to tag graph_id and
        # then built `rows` from a *second* pass — csv readers are
        # single-pass, so `rows` was always empty. Build the list in one pass.
        rows = []
        for row in reader:
            row['graph_id'] = graph_id
            rows.append(row)

        # BUG FIX: the original called save_nodes_to_db(rows) without the
        # required session argument, which raised a TypeError at runtime.
        save_nodes_to_db(rows, db)
        return resp_200(data={"total": len(rows)})

    except Exception as e:
        return resp_200(data={"total": 0, 'error_code': 500, 'error_message': str(e)})
@router.post("/edges/{graph_id}")
async def upload_edges_csv(graph_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)):
    """Upload a CSV of edges and import them into graph `graph_id`.

    The CSV must carry the columns consumed by `save_edges_to_db`
    (src/target name+category, edge name/category); `graph_id` comes from
    the URL path and is stamped onto every row.

    NOTE(review): the success/error response shapes (JSONResponse /
    HTTPException) differ from the nodes endpoint's `resp_200` convention;
    kept as-is to avoid breaking existing API clients — consider unifying.
    """
    if not file.filename.endswith('.csv'):
        return resp_200(data={"total": 0, 'error_code': 500, 'error_message': "Only CSV files are allowed"})
    try:
        contents = await file.read()
        csv_file = io.StringIO(contents.decode('utf-8'))
        reader = csv.DictReader(csv_file)

        # BUG FIX: the original consumed `reader` once to tag graph_id and
        # then listed it again — csv readers are single-pass, so `rows` was
        # always empty. Build the list in a single pass instead.
        rows = []
        for row in reader:
            row['graph_id'] = graph_id
            rows.append(row)

        # BUG FIX: the original signature declared no `db` dependency and
        # called save_edges_to_db(rows) without a session — a guaranteed
        # TypeError. Inject the session and pass it through.
        save_edges_to_db(rows, db)
        return JSONResponse(content={"message": f"Successfully processed {len(rows)} edges"})

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Public alias under which the application mounts this router.
data_import_router = router
|