2024-09-19 16:49:11 +09:00
|
|
|
from pydantic import BaseModel, Field
|
2024-09-23 16:21:01 +09:00
|
|
|
from typing import List, Optional, Union, Literal
|
2024-09-08 16:46:52 +09:00
|
|
|
from schemas.predict_response import LabelData
|
|
|
|
|
|
|
|
class TrainDataInfo(BaseModel):
|
|
|
|
image_url: str
|
2024-09-25 16:06:45 +09:00
|
|
|
data_url: str
|
2024-09-08 16:46:52 +09:00
|
|
|
|
|
|
|
class TrainRequest(BaseModel):
|
|
|
|
project_id: int
|
2024-09-25 15:24:14 +09:00
|
|
|
m_key: str = Field("yolo8", alias="model_key")
|
|
|
|
m_id: int = Field(..., alias="model_id") # 학습 중 에포크 결과를 보낼때 model_id를 보냄
|
2024-09-27 11:18:52 +09:00
|
|
|
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk , None일 경우 모델 레이블 카테고리 idx로 레이블링")
|
2024-09-08 16:46:52 +09:00
|
|
|
data: List[TrainDataInfo]
|
2024-09-09 17:46:15 +09:00
|
|
|
ratio: float = 0.8 # 훈련/검증 분할 비율
|
2024-09-23 16:21:01 +09:00
|
|
|
|
|
|
|
# 학습 파라미터
|
2024-09-09 17:46:15 +09:00
|
|
|
epochs: int = 50 # 훈련 반복 횟수
|
|
|
|
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
2024-09-23 16:21:01 +09:00
|
|
|
lr0: float = 0.01 # 초기 학습 가중치
|
|
|
|
lrf: float = 0.01 # lr0 기준으로 학습 가중치의 최종 수렴치 (ex lr0의 0.01배)
|
|
|
|
optimizer: Literal['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'] = 'auto'
|