Feat: 모델 업로드 API 구현
This commit is contained in:
parent
0503da4e41
commit
a98b3a0c9c
@ -1,8 +1,8 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile
|
||||
from schemas.model_create_request import ModelCreateRequest
|
||||
from services.init_model import create_pretrained_model, create_default_model
|
||||
from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model
|
||||
from services.load_model import load_model
|
||||
from utils.file_utils import get_model_paths, delete_file
|
||||
from utils.file_utils import get_model_paths, delete_file, join_path, save_file
|
||||
import re
|
||||
|
||||
router = APIRouter()
|
||||
@ -55,8 +55,29 @@ def model_delete(model_path:str):
|
||||
detail= "모델을 찾을 수 없습니다.")
|
||||
|
||||
@router.post("/upload")
|
||||
def model_upload():
|
||||
pass
|
||||
def upload_model(project_id:int, file: UploadFile = File(...)):
|
||||
# 확장자 확인
|
||||
if not file.filename.endswith(".pt"):
|
||||
raise HTTPException(status_code=400, detail="Only .pt files are allowed.")
|
||||
|
||||
tmp_path = join_path("resources", "models", "tmp-"+file.filename)
|
||||
|
||||
# 임시로 파일 저장
|
||||
try:
|
||||
save_file(tmp_path, file)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="file save exception: "+str(e))
|
||||
|
||||
# YOLO 모델 변환 및 저장
|
||||
try:
|
||||
model_path = upload_tmp_model(project_id, tmp_path)
|
||||
return {"model_path": model_path}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="file save exception: "+str(e))
|
||||
finally:
|
||||
# 임시파일 삭제
|
||||
delete_file(tmp_path)
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
def model_download():
|
||||
|
@ -1,8 +1,9 @@
|
||||
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||
import os
|
||||
import uuid
|
||||
from services.load_model import load_model
|
||||
|
||||
def create_pretrained_model(project_id: id, type:str):
|
||||
def create_pretrained_model(project_id: int, type:str):
|
||||
suffix = ""
|
||||
if type in ["seg", "cls"]:
|
||||
suffix = "-"+type
|
||||
@ -24,5 +25,24 @@ def create_pretrained_model(project_id: id, type:str):
|
||||
|
||||
return model_path
|
||||
|
||||
def create_default_model(project_id: id, type:str, labels:list[str]):
|
||||
def create_default_model(project_id: int, type:str, labels:list[str]):
|
||||
pass
|
||||
|
||||
def upload_tmp_model(project_id: int, tmp_path:str):
|
||||
# 모델 불러오기
|
||||
model = load_model(tmp_path)
|
||||
|
||||
# 모델을 저장할 폴더 경로
|
||||
base_path = os.path.join("resources","projects",str(project_id),"models")
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
||||
# 고유값 id 생성
|
||||
unique_id = uuid.uuid4()
|
||||
while os.path.exists(os.path.join(base_path, f"{unique_id}.pt")):
|
||||
unique_id = uuid.uuid4()
|
||||
model_path = os.path.join(base_path, f"{unique_id}.pt")
|
||||
|
||||
# 기본 모델 저장
|
||||
model.save(filename=model_path)
|
||||
|
||||
return model_path
|
@ -158,4 +158,14 @@ def get_model_paths(project_id:int):
|
||||
return [os.path.join(path, file) for file in files if file.endswith(".pt")]
|
||||
|
||||
def delete_file(path):
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError()
|
||||
os.remove(path)
|
||||
|
||||
def save_file(path, file):
|
||||
# 경로에서 디렉토리 부분만 추출 (파일명을 제외한 경로)
|
||||
dir_path = os.path.dirname(path)
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
with open(path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
Loading…
Reference in New Issue
Block a user