Refactor: 기본 모델 생성 구현 및 모델 생성 리팩토링

This commit is contained in:
김진현 2024-09-19 01:15:41 +09:00
parent 941df49f13
commit bcae5d5255
3 changed files with 12 additions and 22 deletions

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, File, UploadFile
from schemas.model_create_request import ModelCreateRequest
from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model
from services.create_model import create_new_model, upload_tmp_model
from services.load_model import load_model
from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name
import re
@ -32,19 +32,15 @@ def get_model_list(project_id:int):
detail= "프로젝트가 찾을 수 없거나 생성된 모델이 없습니다.")
@router.post("/create", status_code=201)
def model_create(request: ModelCreateRequest):
def create_model(request: ModelCreateRequest):
if request.type not in ["seg", "det", "cls"]:
raise HTTPException(status_code=400,
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
if request.pretrained:
model_path = create_pretrained_model(request.project_id, request.type)
else:
labels = list(map(lambda x:x.label, request.labelCategories))
model_path = create_default_model(request.project_id, request.type, labels)
model_path = create_new_model(request.project_id, request.type, request.pretrained)
return {"model_path": model_path}
@router.delete("/delete", status_code=204)
def model_delete(model_path:str):
def delete_model(model_path:str):
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
if not re.match(pattern, model_path):
raise HTTPException(status_code=400,
@ -81,7 +77,7 @@ def upload_model(project_id:int, file: UploadFile = File(...)):
@router.get("/download")
async def download_model(model_path: str):
def download_model(model_path: str):
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
if not re.match(pattern, model_path):
raise HTTPException(status_code=400,

View File

@ -1,13 +1,6 @@
from pydantic import BaseModel
from typing import Optional
class LabelCategory(BaseModel):
id:int
label:str
class ModelCreateRequest(BaseModel):
project_id: int
type: str
pretrained:bool = True
labelCategories:Optional[list[LabelCategory]] = None
pretrained:bool = True

View File

@ -3,12 +3,16 @@ import os
import uuid
from services.load_model import load_model
def create_pretrained_model(project_id: int, type:str):
def create_new_model(project_id: int, type:str, pretrained:bool):
suffix = ""
if type in ["seg", "cls"]:
suffix = "-"+type
# 학습된 기본 모델 로드
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}.pt"))
if pretrained:
suffix += ".pt"
else:
suffix += ".yaml"
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}"))
# 모델을 저장할 폴더 경로
base_path = os.path.join("resources","projects",str(project_id),"models")
@ -25,9 +29,6 @@ def create_pretrained_model(project_id: int, type:str):
return model_path
def create_default_model(project_id: int, type:str, labels:list[str]):
pass
def upload_tmp_model(project_id: int, tmp_path:str):
# 모델 불러오기
model = load_model(tmp_path)