2024-09-19 16:49:11 +09:00
|
|
|
from pydantic import BaseModel, Field
|
2024-09-09 17:46:15 +09:00
|
|
|
from typing import List, Optional, Union
|
2024-09-08 16:46:52 +09:00
|
|
|
from schemas.predict_response import LabelData
|
|
|
|
|
|
|
|
class TrainDataInfo(BaseModel):
    """A single labelled training sample: an image reference and its label."""

    # Location of the training image (presumably downloaded by the trainer -- confirm).
    image_url: str = Field(...)
    # Annotation for the image, using the same LabelData schema as prediction responses.
    label: LabelData = Field(...)
|
|
|
|
|
|
|
|
class TrainRequest(BaseModel):
    """Request body for starting a training run over a project's labelled samples."""

    # Identifier of the project being trained.
    project_id: int
    # Labelled samples to be split into training and validation sets.
    data: List[TrainDataInfo]
    # Random seed; None when the caller does not supply one.
    seed: Optional[int] = None
    # Train/validation split ratio (presumably the training share, given the 0.8
    # default -- confirm). Constrained to [0, 1]: anything outside that range is
    # not a meaningful split fraction and previously passed validation silently.
    ratio: float = Field(0.8, ge=0, le=1)
    # Number of training epochs; a run needs at least one.
    epochs: int = Field(50, ge=1)
    # Training batch size [int], or automatic GPU-utilization fraction [float].
    # Default (-1): automatic mode that keeps GPU usage at about 60%.
    # NOTE(review): pydantic v1 tries Union members left to right, so an int such
    # as 16 would be coerced to 16.0 there; pydantic v2's smart union keeps ints
    # as int. Confirm which pydantic major version this service runs on.
    batch: Union[float, int] = -1
    # Optional model path, exposed to clients under the key "model_path".
    # The attribute is named `path`, likely to avoid pydantic v2's protected
    # "model_" field-name namespace -- confirm before renaming.
    path: Optional[str] = Field(None, alias="model_path")
|