From 9b8d0927d78f64790330f9d1a842772be56a23b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Fri, 20 Sep 2024 13:54:46 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20model=5Fpath=20=ED=8C=8C=EB=9D=BC?= =?UTF-8?q?=EB=AF=B8=ED=84=B0=EB=A5=BC=20model=5Fkey=EB=A1=9C=20=EB=B3=80?= =?UTF-8?q?=EA=B2=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 5 ++- ai/app/api/yolo/model.py | 51 +++++++++++--------------- ai/app/api/yolo/segmentation.py | 6 ++- ai/app/schemas/model_create_request.py | 3 +- ai/app/schemas/predict_request.py | 2 +- ai/app/services/create_model.py | 4 +- ai/app/services/load_model.py | 9 +---- ai/app/utils/file_utils.py | 9 +++-- 8 files changed, 41 insertions(+), 48 deletions(-) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index db9a28a..e556d3e 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -6,7 +6,7 @@ from schemas.train_request import TrainRequest from schemas.predict_response import PredictResponse, LabelData from services.load_model import load_detection_model from utils.dataset_utils import split_data -from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path +from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path from utils.websocket_utils import WebSocketClient, WebSocketConnectionException import asyncio @@ -26,7 +26,8 @@ async def detection_predict(request: PredictRequest): # 모델 로드 try: - model = load_detection_model(request.path) + model_path = request.m_key and get_model_path(request.project_id, request.m_key) + model = load_detection_model(model_path=model_path) except Exception as e: raise HTTPException(status_code=500, detail="load model exception: " + str(e)) diff --git a/ai/app/api/yolo/model.py b/ai/app/api/yolo/model.py index c595c88..a8a22f0 100644 --- a/ai/app/api/yolo/model.py +++ b/ai/app/api/yolo/model.py @@ -2,15 +2,15 @@ from fastapi import APIRouter, HTTPException, File, UploadFile from schemas.model_create_request import ModelCreateRequest from services.create_model import create_new_model, upload_tmp_model from services.load_model import load_model -from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name +from utils.file_utils import get_model_keys, delete_file, join_path, save_file, get_file_name import re from fastapi.responses import FileResponse router = APIRouter() -# modelType(detection/segmentation/classification), (default, pretrained), labelCategories -@router.get("/info") -def get_model_info(model_path:str): +@router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환") +def get_model_info(project_id:int, model_key:str): + model_path = join_path("resources","projects",project_id, "models", model_key) try: model = load_model(model_path=model_path) except FileNotFoundError: @@ -18,42 +18,38 @@ def get_model_info(model_path:str): detail= "모델을 찾을 수 없습니다.") except Exception as e: raise HTTPException(status_code=500, detail="model load exception: " + str(e)) - pretrained = model.names == {0: 'person', 1: 'bicycle', 2: 'car', 3: 'motorcycle', 4: 'airplane', 5: 'bus', 6: 'train', 7: 'truck', 8: 'boat', 9: 'traffic light', 10: 'fire hydrant', 11: 'stop sign', 12: 'parking meter', 13: 'bench', 14: 'bird', 15: 'cat', 16: 'dog', 17: 'horse', 18: 'sheep', 19: 'cow', 20: 'elephant', 21: 'bear', 22: 'zebra', 23: 'giraffe', 24: 'backpack', 25: 'umbrella', 26: 'handbag', 27: 'tie', 28: 'suitcase', 29: 'frisbee', 30: 'skis', 31: 'snowboard', 32: 'sports ball', 33: 'kite', 34: 'baseball bat', 35: 'baseball glove', 36: 'skateboard', 37: 'surfboard', 38: 'tennis racket', 39: 'bottle', 40: 'wine glass', 41: 'cup', 42: 'fork', 43: 'knife', 44: 'spoon', 45: 'bowl', 46: 'banana', 47: 'apple', 48: 'sandwich', 49: 'orange', 50: 'broccoli', 51: 'carrot', 52: 'hot dog', 53: 'pizza', 54: 'donut', 55: 'cake', 56: 'chair', 57: 'couch', 58: 'potted plant', 59: 'bed', 60: 'dining table', 61: 'toilet', 62: 'tv', 63: 'laptop', 64: 'mouse', 65: 'remote', 66: 'keyboard', 67: 'cell phone', 68: 'microwave', 69: 'oven', 70: 'toaster', 71: 'sink', 72: 'refrigerator', 73: 'book', 74: 'clock', 75: 'vase', 76: 'scissors', 77: 'teddy bear', 78: 'hair drier', 79: 'toothbrush'} - - return {"type": model.task, "pretrained":pretrained, "labelCategories":model.names} + # TODO: 학습치 등등 추가 예정 + return {"type": model.task, "labelCategories":model.names} # project_id => model path 리스트 를 가져오는 함수 -@router.get("/list") +@router.get("/projects/{project_id}", summary="project id 에 해당하는 모델 id 리스트") def get_model_list(project_id:int): try: - return get_model_paths(project_id) + return get_model_keys(project_id) except FileNotFoundError: raise HTTPException(status_code=404, detail= "프로젝트가 찾을 수 없거나 생성된 모델이 없습니다.") -@router.post("/create", status_code=201) -def create_model(request: ModelCreateRequest): - if request.type not in ["seg", "det", "cls"]: +@router.post("/projects/{project_id}", status_code=201) +def create_model(project_id: int, request: ModelCreateRequest): + if request.project_type not in ["segmentation", "detection", "classfication"]: raise HTTPException(status_code=400, - detail= f"Invalid type '{request.type}'. Must be one of \"seg\", \"det\", \"cls\".") - model_path = create_new_model(request.project_id, request.type, request.pretrained) - return {"model_path": model_path} + detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classfication\".") + model_key = create_new_model(project_id, request.project_type, request.pretrained) + return {"model_key": model_key} -@router.delete("/delete", status_code=204) -def delete_model(model_path:str): - pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$' - if not re.match(pattern, model_path): - raise HTTPException(status_code=400, - detail= "Invalid path format") +@router.delete("/projects/{project_id}/models/{model_key}", status_code=204) +def delete_model(project_id:int, model_key:str): + model_path = join_path("resources", "projects", project_id, "models", model_key) try: delete_file(model_path) except FileNotFoundError: raise HTTPException(status_code=404, detail= "모델을 찾을 수 없습니다.") -@router.post("/upload") +@router.post("/upload/projects/{project_id}") def upload_model(project_id:int, file: UploadFile = File(...)): - # 확장자 확인 + # 확장자 확인 확장자 새로 추가한다면 여기에 추가 if not file.filename.endswith(".pt"): raise HTTPException(status_code=400, detail="Only .pt files are allowed.") @@ -76,12 +72,9 @@ def upload_model(project_id:int, file: UploadFile = File(...)): delete_file(tmp_path) -@router.get("/download") -def download_model(model_path: str): - pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$' - if not re.match(pattern, model_path): - raise HTTPException(status_code=400, - detail= "Invalid path format") +@router.get("/download/projects/{project_id}/models/{model_key}") +def download_model(project_id:int, model_key:str): + model_path = join_path("resources", "projects", project_id, "models", model_key) try: filename = get_file_name(model_path) # 파일 응답 반환 diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py index 8995b2f..5bc8f8f 100644 --- a/ai/app/api/yolo/segmentation.py +++ b/ai/app/api/yolo/segmentation.py @@ -2,6 +2,7 @@ from fastapi import APIRouter, HTTPException from schemas.predict_request import PredictRequest from schemas.predict_response import PredictResponse, LabelData from services.load_model import load_segmentation_model +from utils.file_utils import get_model_path from typing import List router = APIRouter() @@ -9,10 +10,11 @@ router = APIRouter() @router.post("/predict", response_model=List[PredictResponse]) def predict(request: PredictRequest): version = "0.1.0" - + # 모델 로드 try: - model = load_segmentation_model() + model_path = request.m_key and get_model_path(request.project_id, request.m_key) + model = load_segmentation_model(model_path) except Exception as e: raise HTTPException(status_code=500, detail="load model exception: "+str(e)) diff --git a/ai/app/schemas/model_create_request.py b/ai/app/schemas/model_create_request.py index 27d6c82..70aa483 100644 --- a/ai/app/schemas/model_create_request.py +++ b/ai/app/schemas/model_create_request.py @@ -1,6 +1,5 @@ from pydantic import BaseModel class ModelCreateRequest(BaseModel): - project_id: int - type: str + project_type: str pretrained:bool = True \ No newline at end of file diff --git a/ai/app/schemas/predict_request.py b/ai/app/schemas/predict_request.py index af9640b..e65166a 100644 --- a/ai/app/schemas/predict_request.py +++ b/ai/app/schemas/predict_request.py @@ -11,10 +11,10 @@ class LabelCategory(BaseModel): class PredictRequest(BaseModel): project_id: int + m_key: Optional[str] = Field(None, alias="model_key") image_list: List[ImageInfo] version: str = "latest" conf_threshold: float = 0.25 iou_threshold: float = 0.45 classes: Optional[List[int]] = None - path: Optional[str] = Field(None, alias="model_path") label_categories: Optional[List[LabelCategory]] = None diff --git a/ai/app/services/create_model.py b/ai/app/services/create_model.py index 8fe4b52..bbeabdf 100644 --- a/ai/app/services/create_model.py +++ b/ai/app/services/create_model.py @@ -27,7 +27,7 @@ def create_new_model(project_id: int, type:str, pretrained:bool): # 기본 모델 저장 model.save(filename=model_path) - return model_path + return f"{unique_id}.pt" def upload_tmp_model(project_id: int, tmp_path:str): # 모델 불러오기 @@ -46,4 +46,4 @@ def upload_tmp_model(project_id: int, tmp_path:str): # 기본 모델 저장 model.save(filename=model_path) - return model_path \ No newline at end of file + return f"{unique_id}.pt" \ No newline at end of file diff --git a/ai/app/services/load_model.py b/ai/app/services/load_model.py index fed0d27..28ff08c 100644 --- a/ai/app/services/load_model.py +++ b/ai/app/services/load_model.py @@ -7,7 +7,7 @@ import os import torch import re -def load_detection_model(model_path:str): +def load_detection_model(model_path:str|None): """ 지정된 경로에서 YOLO 모델을 로드합니다. @@ -29,7 +29,7 @@ def load_detection_model(model_path:str): raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.") return model -def load_segmentation_model(model_path: str): +def load_segmentation_model(model_path: str|None): if model_path: model = YOLO(model_path) else: @@ -41,11 +41,6 @@ def load_segmentation_model(model_path: str): return model def load_model(model_path: str): - # model_path 검증 - pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$' - if not re.match(pattern, model_path): - raise Exception("Invalid path format") - if not os.path.exists(model_path): raise FileNotFoundError(f"Model file not found at path: {model_path}") diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py index 2e5bac6..f1f0f3f 100644 --- a/ai/app/utils/file_utils.py +++ b/ai/app/utils/file_utils.py @@ -150,12 +150,12 @@ def join_path(path, *paths): """os.path.join()과 같은 기능, os import 하기 싫어서 만듦""" return os.path.join(path, *paths) -def get_model_paths(project_id:int): +def get_model_keys(project_id:int): path = os.path.join("resources","projects",str(project_id), "models") if not os.path.exists(path): raise FileNotFoundError() files = os.listdir(path) - return [os.path.join(path, file) for file in files if file.endswith(".pt")] + return files def delete_file(path): if not os.path.exists(path): @@ -173,4 +173,7 @@ def save_file(path, file): def get_file_name(path): if not os.path.exists(path): raise FileNotFoundError() - return os.path.basename(path) \ No newline at end of file + return os.path.basename(path) + +def get_model_path(project_id:int, model_key:str): + return os.path.join("resources", "projects", str(project_id), "models", model_key) \ No newline at end of file