Feat: 모델 업로드 API 구현

This commit is contained in:
김진현 2024-09-18 16:09:18 +09:00
parent 0503da4e41
commit a98b3a0c9c
3 changed files with 60 additions and 9 deletions

View File

@ -1,8 +1,8 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, File, UploadFile
from schemas.model_create_request import ModelCreateRequest
from services.init_model import create_pretrained_model, create_default_model
from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model
from services.load_model import load_model
from utils.file_utils import get_model_paths, delete_file
from utils.file_utils import get_model_paths, delete_file, join_path, save_file
import re
router = APIRouter()
@ -55,8 +55,29 @@ def model_delete(model_path:str):
detail= "모델을 찾을 수 없습니다.")
@router.post("/upload")
def model_upload():
pass
def upload_model(project_id:int, file: UploadFile = File(...)):
# 확장자 확인
if not file.filename.endswith(".pt"):
raise HTTPException(status_code=400, detail="Only .pt files are allowed.")
tmp_path = join_path("resources", "models", "tmp-"+file.filename)
# 임시로 파일 저장
try:
save_file(tmp_path, file)
except Exception as e:
raise HTTPException(status_code=500, detail="file save exception: "+str(e))
# YOLO 모델 변환 및 저장
try:
model_path = upload_tmp_model(project_id, tmp_path)
return {"model_path": model_path}
except Exception as e:
raise HTTPException(status_code=500, detail="file save exception: "+str(e))
finally:
# 임시파일 삭제
delete_file(tmp_path)
@router.get("/download")
def model_download():

View File

@ -1,8 +1,9 @@
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
import os
import uuid
from services.load_model import load_model
def create_pretrained_model(project_id: id, type:str):
def create_pretrained_model(project_id: int, type:str):
suffix = ""
if type in ["seg", "cls"]:
suffix = "-"+type
@ -24,5 +25,24 @@ def create_pretrained_model(project_id: id, type:str):
return model_path
def create_default_model(project_id: id, type:str, labels:list[str]):
def create_default_model(project_id: int, type:str, labels:list[str]):
pass
def upload_tmp_model(project_id: int, tmp_path:str):
# 모델 불러오기
model = load_model(tmp_path)
# 모델을 저장할 폴더 경로
base_path = os.path.join("resources","projects",str(project_id),"models")
os.makedirs(base_path, exist_ok=True)
# 고유값 id 생성
unique_id = uuid.uuid4()
while os.path.exists(os.path.join(base_path, f"{unique_id}.pt")):
unique_id = uuid.uuid4()
model_path = os.path.join(base_path, f"{unique_id}.pt")
# 기본 모델 저장
model.save(filename=model_path)
return model_path

View File

@ -158,4 +158,14 @@ def get_model_paths(project_id:int):
return [os.path.join(path, file) for file in files if file.endswith(".pt")]
def delete_file(path):
if not os.path.exists(path):
raise FileNotFoundError()
os.remove(path)
def save_file(path, file):
# 경로에서 디렉토리 부분만 추출 (파일명을 제외한 경로)
dir_path = os.path.dirname(path)
os.makedirs(dir_path, exist_ok=True)
with open(path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)