worlabel/ai/app/schemas/train_request.py

18 lines
832 B
Python
Raw Normal View History

from pydantic import BaseModel, Field
from typing import List, Optional, Union
from schemas.predict_response import LabelData
class TrainDataInfo(BaseModel):
image_url: str
label: LabelData
class TrainRequest(BaseModel):
project_id: int
m_key: Optional[str] = Field(None, alias="model_key")
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습")
data: List[TrainDataInfo]
seed: Optional[int] = None # 랜덤 변수 시드
ratio: float = 0.8 # 훈련/검증 분할 비율
epochs: int = 50 # 훈련 반복 횟수
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지