Fix: 모델 생성 api에 오타 수정, project_id str로 변환 안한 버그 수정

This commit is contained in:
김진현 2024-09-26 20:00:00 +09:00
parent d6b132c6ce
commit 4d49b925dc
2 changed files with 6 additions and 5 deletions

View File

@ -10,7 +10,7 @@ router = APIRouter()
@router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환") @router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환")
def get_model_info(project_id:int, model_key:str): def get_model_info(project_id:int, model_key:str):
model_path = join_path("resources","projects",project_id, "models", model_key) model_path = join_path("resources","projects", str(project_id), "models", model_key)
try: try:
model = load_model(model_path=model_path) model = load_model(model_path=model_path)
except FileNotFoundError: except FileNotFoundError:
@ -32,9 +32,9 @@ def get_model_list(project_id:int):
@router.post("/projects/{project_id}", status_code=201) @router.post("/projects/{project_id}", status_code=201)
def create_model(project_id: int, request: ModelCreateRequest): def create_model(project_id: int, request: ModelCreateRequest):
if request.project_type not in ["segmentation", "detection", "classfication"]: if request.project_type not in ["segmentation", "detection", "classification"]:
raise HTTPException(status_code=400, raise HTTPException(status_code=400,
detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classfication\".") detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classification\".")
model_key = create_new_model(project_id, request.project_type, request.pretrained) model_key = create_new_model(project_id, request.project_type, request.pretrained)
return {"model_key": model_key} return {"model_key": model_key}

View File

@ -5,8 +5,9 @@ from services.load_model import load_model
def create_new_model(project_id: int, type:str, pretrained:bool): def create_new_model(project_id: int, type:str, pretrained:bool):
suffix = "" suffix = ""
if type in ["seg", "cls"]: type_list = {"segmentation": "seg", "classification": "cls"}
suffix = "-"+type if type in type_list:
suffix = "-"+type_list[type]
# 학습된 기본 모델 로드 # 학습된 기본 모델 로드
if pretrained: if pretrained:
suffix += ".pt" suffix += ".pt"