Refactor: 기본 모델 생성 구현 및 모델 생성 리팩토링
This commit is contained in:
parent
941df49f13
commit
bcae5d5255
@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter, HTTPException, File, UploadFile
|
||||
from schemas.model_create_request import ModelCreateRequest
|
||||
from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model
|
||||
from services.create_model import create_new_model, upload_tmp_model
|
||||
from services.load_model import load_model
|
||||
from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name
|
||||
import re
|
||||
@ -32,19 +32,15 @@ def get_model_list(project_id:int):
|
||||
detail= "프로젝트가 찾을 수 없거나 생성된 모델이 없습니다.")
|
||||
|
||||
@router.post("/create", status_code=201)
|
||||
def model_create(request: ModelCreateRequest):
|
||||
def create_model(request: ModelCreateRequest):
|
||||
if request.type not in ["seg", "det", "cls"]:
|
||||
raise HTTPException(status_code=400,
|
||||
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
|
||||
if request.pretrained:
|
||||
model_path = create_pretrained_model(request.project_id, request.type)
|
||||
else:
|
||||
labels = list(map(lambda x:x.label, request.labelCategories))
|
||||
model_path = create_default_model(request.project_id, request.type, labels)
|
||||
model_path = create_new_model(request.project_id, request.type, request.pretrained)
|
||||
return {"model_path": model_path}
|
||||
|
||||
@router.delete("/delete", status_code=204)
|
||||
def model_delete(model_path:str):
|
||||
def delete_model(model_path:str):
|
||||
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||
if not re.match(pattern, model_path):
|
||||
raise HTTPException(status_code=400,
|
||||
@ -81,7 +77,7 @@ def upload_model(project_id:int, file: UploadFile = File(...)):
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
async def download_model(model_path: str):
|
||||
def download_model(model_path: str):
|
||||
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||
if not re.match(pattern, model_path):
|
||||
raise HTTPException(status_code=400,
|
||||
|
@ -1,13 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class LabelCategory(BaseModel):
|
||||
id:int
|
||||
label:str
|
||||
|
||||
|
||||
class ModelCreateRequest(BaseModel):
|
||||
project_id: int
|
||||
type: str
|
||||
pretrained:bool = True
|
||||
labelCategories:Optional[list[LabelCategory]] = None
|
||||
pretrained:bool = True
|
@ -3,12 +3,16 @@ import os
|
||||
import uuid
|
||||
from services.load_model import load_model
|
||||
|
||||
def create_pretrained_model(project_id: int, type:str):
|
||||
def create_new_model(project_id: int, type:str, pretrained:bool):
|
||||
suffix = ""
|
||||
if type in ["seg", "cls"]:
|
||||
suffix = "-"+type
|
||||
# 학습된 기본 모델 로드
|
||||
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}.pt"))
|
||||
if pretrained:
|
||||
suffix += ".pt"
|
||||
else:
|
||||
suffix += ".yaml"
|
||||
model = YOLO(os.path.join("resources", "models" ,f"yolov8n{suffix}"))
|
||||
|
||||
# 모델을 저장할 폴더 경로
|
||||
base_path = os.path.join("resources","projects",str(project_id),"models")
|
||||
@ -25,9 +29,6 @@ def create_pretrained_model(project_id: int, type:str):
|
||||
|
||||
return model_path
|
||||
|
||||
def create_default_model(project_id: int, type:str, labels:list[str]):
|
||||
pass
|
||||
|
||||
def upload_tmp_model(project_id: int, tmp_path:str):
|
||||
# 모델 불러오기
|
||||
model = load_model(tmp_path)
|
Loading…
Reference in New Issue
Block a user