Merge branch 'ai/feat/229-data-process-for-train-stats' into 'ai/develop'

Feat: 학습 통계 데이터 처리 구현 - S11P21S002-229

See merge request s11-s-project/S11P21S002!140
This commit is contained in:
정현조 2024-09-23 16:28:19 +09:00
commit dcd4799d20
3 changed files with 49 additions and 11 deletions

View File

@ -10,6 +10,7 @@ from utils.dataset_utils import split_data
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
import asyncio import asyncio
import time
router = APIRouter() router = APIRouter()
@ -189,7 +190,7 @@ async def detection_train(request: TrainRequest):
inverted_label_map = {value: key for key, value in request.label_map.items()} inverted_label_map = {value: key for key, value in request.label_map.items()}
# 학습 데이터 분류 # 학습 데이터 분류
train_data, val_data = split_data(request.data, request.ratio, request.seed) train_data, val_data = split_data(request.data, request.ratio)
try: try:
await ws_client.connect() await ws_client.connect()
@ -218,15 +219,46 @@ async def detection_train(request: TrainRequest):
# 웹소켓으로 메시지 전송 (필요할 경우 추가) # 웹소켓으로 메시지 전송 (필요할 경우 추가)
await ws_client.send_message("/app/ai/val/progress", f"검증 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료") await ws_client.send_message("/app/ai/val/progress", f"검증 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료")
model = load_detection_model("test-data/model/best.pt") from ultralytics.models.yolo.detect import DetectionTrainer
def send_data(trainer):
# 첫번째 epoch는 스킵
if trainer.epoch == 0:
return
## 남은 시간 계산(초)
left_epochs = trainer.epochs-trainer.epoch
left_sec = left_epochs*trainer.epoch_time
## 로스 box_loss, cls_loss, dfl_loss
loss = trainer.label_loss_items(loss_items=trainer.loss_items)
data = {
"epoch": trainer.epoch, # 현재 에포크
"total_epochs": trainer.epochs, # 전체 에포크
"box_loss": loss["box_loss"], # box loss
"cls_loss": loss["cls_loss"], # cls loss
"dfl_loss": loss["dfl_loss"], # dfl loss
"fitness": trainer.fitness, # 적합도
"epoch_time": trainer.epoch_time, # 지난 에포크 걸린 시간 (에포크 시작 기준으로 결정)
"left_second": left_sec # 남은 시간(초)
}
# 데이터 전송
ws_client.send_message()
model.add_callback("on_train_epoch_start", send_data)
model.train( model.train(
data=join_path(dataset_root_path, "dataset.yaml"), data=join_path(dataset_root_path, "dataset.yaml"),
name=join_path(dataset_root_path, "result"), name=join_path(dataset_root_path, "result"),
epochs=request.epochs, epochs=request.epochs,
batch=request.batch, batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
) )
# return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
return {"status": "Training completed successfully"} model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
return {"model_key": model_key, "results": results.results_dict}
except WebSocketConnectionException as e: except WebSocketConnectionException as e:
@ -247,9 +279,12 @@ async def detection_train(request: TrainRequest):
name=join_path(dataset_root_path, "result"), name=join_path(dataset_root_path, "result"),
epochs=request.epochs, epochs=request.epochs,
batch=request.batch, batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
) )
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "last.pt")) model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
return {"model_key": model_key, "results": results.results_dict} return {"model_key": model_key, "results": results.results_dict}
@ -261,3 +296,5 @@ async def detection_train(request: TrainRequest):
finally: finally:
if ws_client.is_connected(): if ws_client.is_connected():
await ws_client.close() await ws_client.close()

View File

@ -1,5 +1,5 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional, Union from typing import List, Optional, Union, Literal
from schemas.predict_response import LabelData from schemas.predict_response import LabelData
class TrainDataInfo(BaseModel): class TrainDataInfo(BaseModel):
@ -11,7 +11,11 @@ class TrainRequest(BaseModel):
m_key: Optional[str] = Field(None, alias="model_key") m_key: Optional[str] = Field(None, alias="model_key")
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습") label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습")
data: List[TrainDataInfo] data: List[TrainDataInfo]
seed: Optional[int] = None # 랜덤 변수 시드
ratio: float = 0.8 # 훈련/검증 분할 비율 ratio: float = 0.8 # 훈련/검증 분할 비율
# 학습 파라미터
epochs: int = 50 # 훈련 반복 횟수 epochs: int = 50 # 훈련 반복 횟수
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지 batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
lr0: float = 0.01 # 초기 학습 가중치
lrf: float = 0.01 # lr0 기준으로 학습 가중치의 최종 수렴치 (ex lr0의 0.01배)
optimizer: Literal['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'] = 'auto'

View File

@ -1,11 +1,8 @@
import random import random
from typing import List, Any, Optional
def split_data(data:List[Any], ratio:float, seed:Optional[int] = None): def split_data(data:list, ratio:float):
random.seed(seed)
train_size = int(ratio * len(data)) train_size = int(ratio * len(data))
random.shuffle(data) random.shuffle(data)
random.seed(None)
train_data = data[:train_size] train_data = data[:train_size]
val_data = data[train_size:] val_data = data[train_size:]
return train_data, val_data return train_data, val_data