Feat: 모델 다운로드 API 구현
This commit is contained in:
parent
a98b3a0c9c
commit
941df49f13
@ -2,8 +2,9 @@ from fastapi import APIRouter, HTTPException, File, UploadFile
|
|||||||
from schemas.model_create_request import ModelCreateRequest
|
from schemas.model_create_request import ModelCreateRequest
|
||||||
from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model
|
from services.init_model import create_pretrained_model, create_default_model, upload_tmp_model
|
||||||
from services.load_model import load_model
|
from services.load_model import load_model
|
||||||
from utils.file_utils import get_model_paths, delete_file, join_path, save_file
|
from utils.file_utils import get_model_paths, delete_file, join_path, save_file, get_file_name
|
||||||
import re
|
import re
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -80,5 +81,15 @@ def upload_model(project_id:int, file: UploadFile = File(...)):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/download")
|
@router.get("/download")
|
||||||
def model_download():
|
async def download_model(model_path: str):
|
||||||
pass
|
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||||
|
if not re.match(pattern, model_path):
|
||||||
|
raise HTTPException(status_code=400,
|
||||||
|
detail= "Invalid path format")
|
||||||
|
try:
|
||||||
|
filename = get_file_name(model_path)
|
||||||
|
# 파일 응답 반환
|
||||||
|
return FileResponse(model_path, media_type='application/octet-stream', filename=filename)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise HTTPException(status_code=404,
|
||||||
|
detail= "모델을 찾을 수 없습니다.")
|
||||||
|
@ -169,3 +169,8 @@ def save_file(path, file):
|
|||||||
|
|
||||||
with open(path, "wb") as buffer:
|
with open(path, "wb") as buffer:
|
||||||
shutil.copyfileobj(file.file, buffer)
|
shutil.copyfileobj(file.file, buffer)
|
||||||
|
|
||||||
|
def get_file_name(path):
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise FileNotFoundError()
|
||||||
|
return os.path.basename(path)
|
Loading…
Reference in New Issue
Block a user