Merge branch 'ai/feat/229-data-process-for-train-stats' into 'ai/develop'
Feat: 학습 통계 데이터 처리 구현 - S11P21S002-229 See merge request s11-s-project/S11P21S002!140
This commit is contained in:
commit
dcd4799d20
@ -10,6 +10,7 @@ from utils.dataset_utils import split_data
|
||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
||||
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
@ -189,7 +190,7 @@ async def detection_train(request: TrainRequest):
|
||||
inverted_label_map = {value: key for key, value in request.label_map.items()}
|
||||
|
||||
# 학습 데이터 분류
|
||||
train_data, val_data = split_data(request.data, request.ratio, request.seed)
|
||||
train_data, val_data = split_data(request.data, request.ratio)
|
||||
|
||||
try:
|
||||
await ws_client.connect()
|
||||
@ -218,15 +219,46 @@ async def detection_train(request: TrainRequest):
|
||||
# 웹소켓으로 메시지 전송 (필요할 경우 추가)
|
||||
await ws_client.send_message("/app/ai/val/progress", f"검증 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료")
|
||||
|
||||
model = load_detection_model("test-data/model/best.pt")
|
||||
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||
|
||||
def send_data(trainer):
|
||||
# 첫번째 epoch는 스킵
|
||||
if trainer.epoch == 0:
|
||||
return
|
||||
|
||||
## 남은 시간 계산(초)
|
||||
left_epochs = trainer.epochs-trainer.epoch
|
||||
left_sec = left_epochs*trainer.epoch_time
|
||||
## 로스 box_loss, cls_loss, dfl_loss
|
||||
loss = trainer.label_loss_items(loss_items=trainer.loss_items)
|
||||
data = {
|
||||
"epoch": trainer.epoch, # 현재 에포크
|
||||
"total_epochs": trainer.epochs, # 전체 에포크
|
||||
"box_loss": loss["box_loss"], # box loss
|
||||
"cls_loss": loss["cls_loss"], # cls loss
|
||||
"dfl_loss": loss["dfl_loss"], # dfl loss
|
||||
"fitness": trainer.fitness, # 적합도
|
||||
"epoch_time": trainer.epoch_time, # 지난 에포크 걸린 시간 (에포크 시작 기준으로 결정)
|
||||
"left_second": left_sec # 남은 시간(초)
|
||||
}
|
||||
# 데이터 전송
|
||||
ws_client.send_message()
|
||||
|
||||
model.add_callback("on_train_epoch_start", send_data)
|
||||
|
||||
model.train(
|
||||
data=join_path(dataset_root_path, "dataset.yaml"),
|
||||
name=join_path(dataset_root_path, "result"),
|
||||
epochs=request.epochs,
|
||||
batch=request.batch,
|
||||
lr0=request.lr0,
|
||||
lrf=request.lrf,
|
||||
optimizer=request.optimizer
|
||||
)
|
||||
# return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
|
||||
return {"status": "Training completed successfully"}
|
||||
|
||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||
|
||||
return {"model_key": model_key, "results": results.results_dict}
|
||||
|
||||
except WebSocketConnectionException as e:
|
||||
|
||||
@ -247,9 +279,12 @@ async def detection_train(request: TrainRequest):
|
||||
name=join_path(dataset_root_path, "result"),
|
||||
epochs=request.epochs,
|
||||
batch=request.batch,
|
||||
lr0=request.lr0,
|
||||
lrf=request.lrf,
|
||||
optimizer=request.optimizer
|
||||
)
|
||||
|
||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "last.pt"))
|
||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||
|
||||
return {"model_key": model_key, "results": results.results_dict}
|
||||
|
||||
@ -261,3 +296,5 @@ async def detection_train(request: TrainRequest):
|
||||
finally:
|
||||
if ws_client.is_connected():
|
||||
await ws_client.close()
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Literal
|
||||
from schemas.predict_response import LabelData
|
||||
|
||||
class TrainDataInfo(BaseModel):
|
||||
@ -11,7 +11,11 @@ class TrainRequest(BaseModel):
|
||||
m_key: Optional[str] = Field(None, alias="model_key")
|
||||
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습")
|
||||
data: List[TrainDataInfo]
|
||||
seed: Optional[int] = None # 랜덤 변수 시드
|
||||
ratio: float = 0.8 # 훈련/검증 분할 비율
|
||||
|
||||
# 학습 파라미터
|
||||
epochs: int = 50 # 훈련 반복 횟수
|
||||
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
||||
lr0: float = 0.01 # 초기 학습 가중치
|
||||
lrf: float = 0.01 # lr0 기준으로 학습 가중치의 최종 수렴치 (ex lr0의 0.01배)
|
||||
optimizer: Literal['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'] = 'auto'
|
||||
|
@ -1,11 +1,8 @@
|
||||
import random
|
||||
from typing import List, Any, Optional
|
||||
|
||||
def split_data(data:List[Any], ratio:float, seed:Optional[int] = None):
|
||||
random.seed(seed)
|
||||
def split_data(data:list, ratio:float):
|
||||
train_size = int(ratio * len(data))
|
||||
random.shuffle(data)
|
||||
random.seed(None)
|
||||
train_data = data[:train_size]
|
||||
val_data = data[train_size:]
|
||||
return train_data, val_data
|
Loading…
Reference in New Issue
Block a user