Fix: 모델 생성 api에 오타 수정, project_id str로 변환 안한 버그 수정
This commit is contained in:
parent
d6b132c6ce
commit
4d49b925dc
@ -10,7 +10,7 @@ router = APIRouter()
|
|||||||
|
|
||||||
@router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환")
|
@router.get("/info/projects/{project_id}/models/{model_key}", summary= "모델 관련 정보 반환")
|
||||||
def get_model_info(project_id:int, model_key:str):
|
def get_model_info(project_id:int, model_key:str):
|
||||||
model_path = join_path("resources","projects",project_id, "models", model_key)
|
model_path = join_path("resources","projects", str(project_id), "models", model_key)
|
||||||
try:
|
try:
|
||||||
model = load_model(model_path=model_path)
|
model = load_model(model_path=model_path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@ -32,9 +32,9 @@ def get_model_list(project_id:int):
|
|||||||
|
|
||||||
@router.post("/projects/{project_id}", status_code=201)
|
@router.post("/projects/{project_id}", status_code=201)
|
||||||
def create_model(project_id: int, request: ModelCreateRequest):
|
def create_model(project_id: int, request: ModelCreateRequest):
|
||||||
if request.project_type not in ["segmentation", "detection", "classfication"]:
|
if request.project_type not in ["segmentation", "detection", "classification"]:
|
||||||
raise HTTPException(status_code=400,
|
raise HTTPException(status_code=400,
|
||||||
detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classfication\".")
|
detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classification\".")
|
||||||
model_key = create_new_model(project_id, request.project_type, request.pretrained)
|
model_key = create_new_model(project_id, request.project_type, request.pretrained)
|
||||||
return {"model_key": model_key}
|
return {"model_key": model_key}
|
||||||
|
|
||||||
|
@ -5,8 +5,9 @@ from services.load_model import load_model
|
|||||||
|
|
||||||
def create_new_model(project_id: int, type:str, pretrained:bool):
|
def create_new_model(project_id: int, type:str, pretrained:bool):
|
||||||
suffix = ""
|
suffix = ""
|
||||||
if type in ["seg", "cls"]:
|
type_list = {"segmentation": "seg", "classification": "cls"}
|
||||||
suffix = "-"+type
|
if type in type_list:
|
||||||
|
suffix = "-"+type_list[type]
|
||||||
# 학습된 기본 모델 로드
|
# 학습된 기본 모델 로드
|
||||||
if pretrained:
|
if pretrained:
|
||||||
suffix += ".pt"
|
suffix += ".pt"
|
||||||
|
Loading…
Reference in New Issue
Block a user