diff --git a/ai/app/api/yolo/model.py b/ai/app/api/yolo/model.py index a87b073..49beb9b 100644 --- a/ai/app/api/yolo/model.py +++ b/ai/app/api/yolo/model.py @@ -2,8 +2,9 @@ from fastapi import APIRouter, HTTPException, File, UploadFile from schemas.model_create_request import ModelCreateRequest from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model from services.load_model import load_model -from utils.file_utils import get_model_paths, delete_file, join_path, save_file +from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name import re +from fastapi.responses import FileResponse router = APIRouter() @@ -80,5 +81,15 @@ def upload_model(project_id:int, file: UploadFile = File(...)): @router.get("/download") -def model_download(): - pass +async def download_model(model_path: str): + pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$' + if not re.match(pattern, model_path): + raise HTTPException(status_code=400, + detail= "Invalid path format") + try: + filename = get_file_name(model_path) + # 파일 응답 반환 + return FileResponse(model_path, media_type='application/octet-stream', filename=filename) + except FileNotFoundError: + raise HTTPException(status_code=404, + detail= "모델을 찾을 수 없습니다.") diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py index ee14181..2e5bac6 100644 --- a/ai/app/utils/file_utils.py +++ b/ai/app/utils/file_utils.py @@ -168,4 +168,9 @@ def save_file(path, file): os.makedirs(dir_path, exist_ok=True) with open(path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) \ No newline at end of file + shutil.copyfileobj(file.file, buffer) + +def get_file_name(path): + if not os.path.exists(path): + raise FileNotFoundError() + return os.path.basename(path) \ No newline at end of file