From 4d49b925dc83529d83d29e749b4b3e16fb93160a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Thu, 26 Sep 2024 20:00:00 +0900 Subject: [PATCH] =?UTF-8?q?Fix:=20=EB=AA=A8=EB=8D=B8=20=EC=83=9D=EC=84=B1?= =?UTF-8?q?=20api=EC=97=90=20=EC=98=A4=ED=83=80=20=EC=88=98=EC=A0=95,=20pr?= =?UTF-8?q?oject=5Fid=20str=EB=A1=9C=20=EB=B3=80=ED=99=98=20=EC=95=88?= =?UTF-8?q?=ED=95=9C=20=EB=B2=84=EA=B7=B8=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/model.py | 6 +++--- ai/app/services/create_model.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ai/app/api/yolo/model.py b/ai/app/api/yolo/model.py index b816f21..489912f 100644 --- a/ai/app/api/yolo/model.py +++ b/ai/app/api/yolo/model.py @@ -10,7 +10,7 @@ router = APIRouter() @router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환") def get_model_info(project_id:int, model_key:str): - model_path = join_path("resources","projects",project_id, "models", model_key) + model_path = join_path("resources","projects", str(project_id), "models", model_key) try: model = load_model(model_path=model_path) except FileNotFoundError: @@ -32,9 +32,9 @@ def get_model_list(project_id:int): @router.post("/projects/{project_id}", status_code=201) def create_model(project_id: int, request: ModelCreateRequest): - if request.project_type not in ["segmentation", "detection", "classfication"]: + if request.project_type not in ["segmentation", "detection", "classification"]: raise HTTPException(status_code=400, - detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classfication\".") + detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classification\".") model_key = create_new_model(project_id, request.project_type, request.pretrained) return {"model_key": model_key} diff --git a/ai/app/services/create_model.py b/ai/app/services/create_model.py index 3dea563..b5c6b8e 100644 --- a/ai/app/services/create_model.py +++ b/ai/app/services/create_model.py @@ -5,8 +5,9 @@ from services.load_model import load_model def create_new_model(project_id: int, type:str, pretrained:bool): suffix = "" - if type in ["seg", "cls"]: - suffix = "-"+type + type_list = {"segmentation": "seg", "classification": "cls"} + if type in type_list: + suffix = "-"+type_list[type] # 학습된 기본 모델 로드 if pretrained: suffix += ".pt"