Refactor: model_path 파라미터를 model_key로 변경

This commit is contained in:
김진현 2024-09-20 13:54:46 +09:00
parent 9ce4c986d2
commit 9b8d0927d7
8 changed files with 41 additions and 48 deletions

View File

@ -6,7 +6,7 @@ from schemas.train_request import TrainRequest
from schemas.predict_response import PredictResponse, LabelData
from services.load_model import load_detection_model
from utils.dataset_utils import split_data
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
import asyncio
@ -26,7 +26,8 @@ async def detection_predict(request: PredictRequest):
# 모델 로드
try:
model = load_detection_model(request.path)
model_path = request.m_key and get_model_path(request.project_id, request.m_key)
model = load_detection_model(model_path=model_path)
except Exception as e:
raise HTTPException(status_code=500, detail="load model exception: " + str(e))

View File

@ -2,15 +2,15 @@ from fastapi import APIRouter, HTTPException, File, UploadFile
from schemas.model_create_request import ModelCreateRequest
from services.create_model import create_new_model, upload_tmp_model
from services.load_model import load_model
from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name
from utils.file_utils import get_model_keys, delete_file, join_path, save_file, get_file_name
import re
from fastapi.responses import FileResponse
router = APIRouter()
# modelType(detection/segmentation/classification), (default, pretrained), labelCategories
@router.get("/info")
def get_model_info(model_path:str):
@router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환")
def get_model_info(project_id:int, model_key:str):
model_path = join_path("resources","projects",project_id, "models", model_key)
try:
model = load_model(model_path=model_path)
except FileNotFoundError:
@ -18,42 +18,38 @@ def get_model_info(model_path:str):
detail= "모델을 찾을 수 없습니다.")
except Exception as e:
raise HTTPException(status_code=500, detail="model load exception: " + str(e))
pretrained = model.names == {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
return {"type": model.task, "pretrained":pretrained, "labelCategories":model.names}
# TODO: 학습치 등등 추가 예정
return {"type": model.task, "labelCategories":model.names}
# project_id => model path 리스트 를 가져오는 함수
@router.get("/list")
@router.get("/projects/{project_id}", summary="project id 에 해당하는 모델 id 리스트")
def get_model_list(project_id:int):
try:
return get_model_paths(project_id)
return get_model_keys(project_id)
except FileNotFoundError:
raise HTTPException(status_code=404,
detail= "프로젝트가 찾을 수 없거나 생성된 모델이 없습니다.")
@router.post("/create", status_code=201)
def create_model(request: ModelCreateRequest):
if request.type not in ["seg", "det", "cls"]:
@router.post("/projects/{project_id}", status_code=201)
def create_model(project_id: int, request: ModelCreateRequest):
if request.project_type not in ["segmentation", "detection", "classfication"]:
raise HTTPException(status_code=400,
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
model_path = create_new_model(request.project_id, request.type, request.pretrained)
return {"model_path": model_path}
detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classfication\".")
model_key = create_new_model(project_id, request.project_type, request.pretrained)
return {"model_key": model_key}
@router.delete("/delete", status_code=204)
def delete_model(model_path:str):
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
if not re.match(pattern, model_path):
raise HTTPException(status_code=400,
detail= "Invalid path format")
@router.delete("/projects/{project_id}/models/{model_key}", status_code=204)
def delete_model(project_id:int, model_key:str):
model_path = join_path("resources", "projects", project_id, "models", model_key)
try:
delete_file(model_path)
except FileNotFoundError:
raise HTTPException(status_code=404,
detail= "모델을 찾을 수 없습니다.")
@router.post("/upload")
@router.post("/upload/projects/{project_id}")
def upload_model(project_id:int, file: UploadFile = File(...)):
# 확장자 확인
# 확장자 확인 확장자 새로 추가한다면 여기에 추가
if not file.filename.endswith(".pt"):
raise HTTPException(status_code=400, detail="Only .pt files are allowed.")
@ -76,12 +72,9 @@ def upload_model(project_id:int, file: UploadFile = File(...)):
delete_file(tmp_path)
@router.get("/download")
def download_model(model_path: str):
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
if not re.match(pattern, model_path):
raise HTTPException(status_code=400,
detail= "Invalid path format")
@router.get("/download/projects/{project_id}/models/{model_key}")
def download_model(project_id:int, model_key:str):
model_path = join_path("resources", "projects", project_id, "models", model_key)
try:
filename = get_file_name(model_path)
# 파일 응답 반환

View File

@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException
from schemas.predict_request import PredictRequest
from schemas.predict_response import PredictResponse, LabelData
from services.load_model import load_segmentation_model
from utils.file_utils import get_model_path
from typing import List
router = APIRouter()
@ -9,10 +10,11 @@ router = APIRouter()
@router.post("/predict", response_model=List[PredictResponse])
def predict(request: PredictRequest):
version = "0.1.0"
# 모델 로드
try:
model = load_segmentation_model()
model_path = request.m_key and get_model_path(request.project_id, request.m_key)
model = load_segmentation_model(model_path)
except Exception as e:
raise HTTPException(status_code=500, detail="load model exception: "+str(e))

View File

@ -1,6 +1,5 @@
from pydantic import BaseModel
class ModelCreateRequest(BaseModel):
project_id: int
type: str
project_type: str
pretrained:bool = True

View File

@ -11,10 +11,10 @@ class LabelCategory(BaseModel):
class PredictRequest(BaseModel):
project_id: int
m_key: Optional[str] = Field(None, alias="model_key")
image_list: List[ImageInfo]
version: str = "latest"
conf_threshold: float = 0.25
iou_threshold: float = 0.45
classes: Optional[List[int]] = None
path: Optional[str] = Field(None, alias="model_path")
label_categories: Optional[List[LabelCategory]] = None

View File

@ -27,7 +27,7 @@ def create_new_model(project_id: int, type:str, pretrained:bool):
# 기본 모델 저장
model.save(filename=model_path)
return model_path
return f"{unique_id}.pt"
def upload_tmp_model(project_id: int, tmp_path:str):
# 모델 불러오기
@ -46,4 +46,4 @@ def upload_tmp_model(project_id: int, tmp_path:str):
# 기본 모델 저장
model.save(filename=model_path)
return model_path
return f"{unique_id}.pt"

View File

@ -7,7 +7,7 @@ import os
import torch
import re
def load_detection_model(model_path:str):
def load_detection_model(model_path:str|None):
"""
지정된 경로에서 YOLO 모델을 로드합니다.
@ -29,7 +29,7 @@ def load_detection_model(model_path:str):
raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.")
return model
def load_segmentation_model(model_path: str):
def load_segmentation_model(model_path: str|None):
if model_path:
model = YOLO(model_path)
else:
@ -41,11 +41,6 @@ def load_segmentation_model(model_path: str):
return model
def load_model(model_path: str):
# model_path 검증
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
if not re.match(pattern, model_path):
raise Exception("Invalid path format")
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found at path: {model_path}")

View File

@ -150,12 +150,12 @@ def join_path(path, *paths):
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
return os.path.join(path, *paths)
def get_model_paths(project_id:int):
def get_model_keys(project_id:int):
path = os.path.join("resources","projects",str(project_id), "models")
if not os.path.exists(path):
raise FileNotFoundError()
files = os.listdir(path)
return [os.path.join(path, file) for file in files if file.endswith(".pt")]
return files
def delete_file(path):
if not os.path.exists(path):
@ -173,4 +173,7 @@ def save_file(path, file):
def get_file_name(path):
if not os.path.exists(path):
raise FileNotFoundError()
return os.path.basename(path)
return os.path.basename(path)
def get_model_path(project_id:int, model_key:str):
return os.path.join("resources", "projects", str(project_id), "models", model_key)