labeling.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from fastapi import APIRouter, Depends, HTTPException, UploadFile
  2. from sqlalchemy.orm import Session
  3. from sqlalchemy.sql import select, or_, and_, func, distinct
  4. from db.schemas import KgTask, KgProj, KgProjCreate, KgTaskCreate, KgTaskUpdate
  5. from db.models import DbKgProj, DbKgTask
  6. from db.database import get_db
  7. from models.response import ResponseModel
  8. from utils.response import resp_200
  9. from typing import List
  10. from datetime import datetime
  11. from math import ceil
  12. from config.site import FILE_STORAGE_PATH
  13. import os
  14. import json
  15. router = APIRouter()
  16. def create_task( proj_id, task_category, content: str,db:Session):
  17. print("create task")
  18. data = DbKgTask()
  19. data.proj_id = proj_id
  20. data.task_category = task_category
  21. data.task_log = ""
  22. data.task_content = content
  23. data.status = 0
  24. data.created = datetime.now()
  25. data.updated = datetime.now()
  26. db.add(data)
  27. db.commit()
  28. db.refresh(data)
  29. return data
  30. # 标注数据文件上传
  31. @router.post("/api/labeling-file-upload/{proj_id}")
  32. async def create_upload_file(proj_id: int, file: UploadFile, db: Session = Depends(get_db)):
  33. path = FILE_STORAGE_PATH + "/tasks_file"
  34. if not os.path.exists(path):
  35. os.makedirs(path)
  36. # 打印文件名称
  37. new_filename = path + "/" + file.filename
  38. # 将上传的文件保存到服务本地
  39. with open(f"{new_filename}", 'wb') as f:
  40. # 一次读取1024字节,循环读取写入
  41. for chunk in iter(lambda: file.file.read(1024), b''):
  42. f.write(chunk)
  43. with open(f"{new_filename}", 'r', encoding="utf-8") as f:
  44. buf_str = ""
  45. for line in f.readlines():
  46. line = line.strip()
  47. if len(buf_str) > 0:
  48. buf_str = buf_str + "\n" + line
  49. else:
  50. buf_str = buf_str + line
  51. while len(buf_str) > 256:
  52. chunk = buf_str[0:256]
  53. buf_str = buf_str[256:]
  54. if chunk[-1] != '\n' and chunk[-1] != "。":
  55. last_end1 = chunk.rfind("。")
  56. last_end2 = chunk.rfind("\n")
  57. if last_end1 == -1 and last_end2 == -1:
  58. last_end1 = len(chunk)
  59. #print("found end char: ", last_end1, last_end2)
  60. #print("*" * 60)
  61. if last_end2 > last_end1:
  62. last_end1 = last_end2
  63. buf_str = chunk[last_end1 + 1:] + buf_str
  64. chunk = chunk[0:last_end1 + 1]
  65. #print(len(chunk),chunk)
  66. create_task(proj_id, "NLP", chunk, db)
  67. if len(buf_str) > 0:
  68. create_task(proj_id, "NLP", buf_str, db)
  69. return resp_200(data={"filename": file.filename})
  70. labeling_router = router