Merge branch 'ai/refactor/model-path' into 'ai/develop'
Refactor: model_path 파라미터를 model_key로 변경 See merge request s11-s-project/S11P21S002!109
This commit is contained in:
commit
43f3d4104e
@ -6,7 +6,7 @@ from schemas.train_request import TrainRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from services.load_model import load_detection_model
|
||||
from utils.dataset_utils import split_data
|
||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
|
||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
||||
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
||||
import asyncio
|
||||
|
||||
@ -26,7 +26,8 @@ async def detection_predict(request: PredictRequest):
|
||||
|
||||
# 모델 로드
|
||||
try:
|
||||
model = load_detection_model(request.path)
|
||||
model_path = request.m_key and get_model_path(request.project_id, request.m_key)
|
||||
model = load_detection_model(model_path=model_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="load model exception: " + str(e))
|
||||
|
||||
|
@ -2,15 +2,15 @@ from fastapi import APIRouter, HTTPException, File, UploadFile
|
||||
from schemas.model_create_request import ModelCreateRequest
|
||||
from services.create_model import create_new_model, upload_tmp_model
|
||||
from services.load_model import load_model
|
||||
from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name
|
||||
from utils.file_utils import get_model_keys, delete_file, join_path, save_file, get_file_name
|
||||
import re
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# modelType(detection/segmentation/classification), (default, pretrained), labelCategories
|
||||
@router.get("/info")
|
||||
def get_model_info(model_path:str):
|
||||
@router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환")
|
||||
def get_model_info(project_id:int, model_key:str):
|
||||
model_path = join_path("resources","projects",project_id, "models", model_key)
|
||||
try:
|
||||
model = load_model(model_path=model_path)
|
||||
except FileNotFoundError:
|
||||
@ -18,42 +18,38 @@ def get_model_info(model_path:str):
|
||||
detail= "모델을 찾을 수 없습니다.")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="model load exception: " + str(e))
|
||||
pretrained = model.names == {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'}
|
||||
|
||||
return {"type": model.task, "pretrained":pretrained, "labelCategories":model.names}
|
||||
# TODO: 학습치 등등 추가 예정
|
||||
return {"type": model.task, "labelCategories":model.names}
|
||||
|
||||
# project_id => model path 리스트 를 가져오는 함수
|
||||
@router.get("/list")
|
||||
@router.get("/projects/{project_id}", summary="project id 에 해당하는 모델 id 리스트")
|
||||
def get_model_list(project_id:int):
|
||||
try:
|
||||
return get_model_paths(project_id)
|
||||
return get_model_keys(project_id)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404,
|
||||
detail= "프로젝트가 찾을 수 없거나 생성된 모델이 없습니다.")
|
||||
|
||||
@router.post("/create", status_code=201)
|
||||
def create_model(request: ModelCreateRequest):
|
||||
if request.type not in ["seg", "det", "cls"]:
|
||||
@router.post("/projects/{project_id}", status_code=201)
|
||||
def create_model(project_id: int, request: ModelCreateRequest):
|
||||
if request.project_type not in ["segmentation", "detection", "classfication"]:
|
||||
raise HTTPException(status_code=400,
|
||||
detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".")
|
||||
model_path = create_new_model(request.project_id, request.type, request.pretrained)
|
||||
return {"model_path": model_path}
|
||||
detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classfication\".")
|
||||
model_key = create_new_model(project_id, request.project_type, request.pretrained)
|
||||
return {"model_key": model_key}
|
||||
|
||||
@router.delete("/delete", status_code=204)
|
||||
def delete_model(model_path:str):
|
||||
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||
if not re.match(pattern, model_path):
|
||||
raise HTTPException(status_code=400,
|
||||
detail= "Invalid path format")
|
||||
@router.delete("/projects/{project_id}/models/{model_key}", status_code=204)
|
||||
def delete_model(project_id:int, model_key:str):
|
||||
model_path = join_path("resources", "projects", project_id, "models", model_key)
|
||||
try:
|
||||
delete_file(model_path)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404,
|
||||
detail= "모델을 찾을 수 없습니다.")
|
||||
|
||||
@router.post("/upload")
|
||||
@router.post("/upload/projects/{project_id}")
|
||||
def upload_model(project_id:int, file: UploadFile = File(...)):
|
||||
# 확장자 확인
|
||||
# 확장자 확인 확장자 새로 추가한다면 여기에 추가
|
||||
if not file.filename.endswith(".pt"):
|
||||
raise HTTPException(status_code=400, detail="Only .pt files are allowed.")
|
||||
|
||||
@ -76,12 +72,9 @@ def upload_model(project_id:int, file: UploadFile = File(...)):
|
||||
delete_file(tmp_path)
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
def download_model(model_path: str):
|
||||
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||
if not re.match(pattern, model_path):
|
||||
raise HTTPException(status_code=400,
|
||||
detail= "Invalid path format")
|
||||
@router.get("/download/projects/{project_id}/models/{model_key}")
|
||||
def download_model(project_id:int, model_key:str):
|
||||
model_path = join_path("resources", "projects", project_id, "models", model_key)
|
||||
try:
|
||||
filename = get_file_name(model_path)
|
||||
# 파일 응답 반환
|
||||
|
@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException
|
||||
from schemas.predict_request import PredictRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from services.load_model import load_segmentation_model
|
||||
from utils.file_utils import get_model_path
|
||||
from typing import List
|
||||
|
||||
router = APIRouter()
|
||||
@ -9,10 +10,11 @@ router = APIRouter()
|
||||
@router.post("/predict", response_model=List[PredictResponse])
|
||||
def predict(request: PredictRequest):
|
||||
version = "0.1.0"
|
||||
|
||||
|
||||
# 모델 로드
|
||||
try:
|
||||
model = load_segmentation_model()
|
||||
model_path = request.m_key and get_model_path(request.project_id, request.m_key)
|
||||
model = load_segmentation_model(model_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="load model exception: "+str(e))
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ModelCreateRequest(BaseModel):
|
||||
project_id: int
|
||||
type: str
|
||||
project_type: str
|
||||
pretrained:bool = True
|
@ -11,10 +11,10 @@ class LabelCategory(BaseModel):
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
project_id: int
|
||||
m_key: Optional[str] = Field(None, alias="model_key")
|
||||
image_list: List[ImageInfo]
|
||||
version: str = "latest"
|
||||
conf_threshold: float = 0.25
|
||||
iou_threshold: float = 0.45
|
||||
classes: Optional[List[int]] = None
|
||||
path: Optional[str] = Field(None, alias="model_path")
|
||||
label_categories: Optional[List[LabelCategory]] = None
|
||||
|
@ -27,7 +27,7 @@ def create_new_model(project_id: int, type:str, pretrained:bool):
|
||||
# 기본 모델 저장
|
||||
model.save(filename=model_path)
|
||||
|
||||
return model_path
|
||||
return f"{unique_id}.pt"
|
||||
|
||||
def upload_tmp_model(project_id: int, tmp_path:str):
|
||||
# 모델 불러오기
|
||||
@ -46,4 +46,4 @@ def upload_tmp_model(project_id: int, tmp_path:str):
|
||||
# 기본 모델 저장
|
||||
model.save(filename=model_path)
|
||||
|
||||
return model_path
|
||||
return f"{unique_id}.pt"
|
@ -7,7 +7,7 @@ import os
|
||||
import torch
|
||||
import re
|
||||
|
||||
def load_detection_model(model_path:str):
|
||||
def load_detection_model(model_path:str|None):
|
||||
"""
|
||||
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||
|
||||
@ -29,7 +29,7 @@ def load_detection_model(model_path:str):
|
||||
raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.")
|
||||
return model
|
||||
|
||||
def load_segmentation_model(model_path: str):
|
||||
def load_segmentation_model(model_path: str|None):
|
||||
if model_path:
|
||||
model = YOLO(model_path)
|
||||
else:
|
||||
@ -41,11 +41,6 @@ def load_segmentation_model(model_path: str):
|
||||
return model
|
||||
|
||||
def load_model(model_path: str):
|
||||
# model_path 검증
|
||||
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||
if not re.match(pattern, model_path):
|
||||
raise Exception("Invalid path format")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||
|
||||
|
@ -150,12 +150,12 @@ def join_path(path, *paths):
|
||||
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
|
||||
return os.path.join(path, *paths)
|
||||
|
||||
def get_model_paths(project_id:int):
|
||||
def get_model_keys(project_id:int):
|
||||
path = os.path.join("resources","projects",str(project_id), "models")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError()
|
||||
files = os.listdir(path)
|
||||
return [os.path.join(path, file) for file in files if file.endswith(".pt")]
|
||||
return files
|
||||
|
||||
def delete_file(path):
|
||||
if not os.path.exists(path):
|
||||
@ -173,4 +173,7 @@ def save_file(path, file):
|
||||
def get_file_name(path):
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError()
|
||||
return os.path.basename(path)
|
||||
return os.path.basename(path)
|
||||
|
||||
def get_model_path(project_id:int, model_key:str):
|
||||
return os.path.join("resources", "projects", str(project_id), "models", model_key)
|
Loading…
Reference in New Issue
Block a user