Merge branch 'ai/feat/229-data-process-for-train-stats' into 'ai/develop'
Feat: 학습 통계 데이터 처리 구현 - S11P21S002-229 See merge request s11-s-project/S11P21S002!140
This commit is contained in:
commit
dcd4799d20
@ -10,6 +10,7 @@ from utils.dataset_utils import split_data
|
|||||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
||||||
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -189,7 +190,7 @@ async def detection_train(request: TrainRequest):
|
|||||||
inverted_label_map = {value: key for key, value in request.label_map.items()}
|
inverted_label_map = {value: key for key, value in request.label_map.items()}
|
||||||
|
|
||||||
# 학습 데이터 분류
|
# 학습 데이터 분류
|
||||||
train_data, val_data = split_data(request.data, request.ratio, request.seed)
|
train_data, val_data = split_data(request.data, request.ratio)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await ws_client.connect()
|
await ws_client.connect()
|
||||||
@ -218,15 +219,46 @@ async def detection_train(request: TrainRequest):
|
|||||||
# 웹소켓으로 메시지 전송 (필요할 경우 추가)
|
# 웹소켓으로 메시지 전송 (필요할 경우 추가)
|
||||||
await ws_client.send_message("/app/ai/val/progress", f"검증 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료")
|
await ws_client.send_message("/app/ai/val/progress", f"검증 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료")
|
||||||
|
|
||||||
model = load_detection_model("test-data/model/best.pt")
|
from ultralytics.models.yolo.detect import DetectionTrainer
|
||||||
|
|
||||||
|
def send_data(trainer):
|
||||||
|
# 첫번째 epoch는 스킵
|
||||||
|
if trainer.epoch == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
## 남은 시간 계산(초)
|
||||||
|
left_epochs = trainer.epochs-trainer.epoch
|
||||||
|
left_sec = left_epochs*trainer.epoch_time
|
||||||
|
## 로스 box_loss, cls_loss, dfl_loss
|
||||||
|
loss = trainer.label_loss_items(loss_items=trainer.loss_items)
|
||||||
|
data = {
|
||||||
|
"epoch": trainer.epoch, # 현재 에포크
|
||||||
|
"total_epochs": trainer.epochs, # 전체 에포크
|
||||||
|
"box_loss": loss["box_loss"], # box loss
|
||||||
|
"cls_loss": loss["cls_loss"], # cls loss
|
||||||
|
"dfl_loss": loss["dfl_loss"], # dfl loss
|
||||||
|
"fitness": trainer.fitness, # 적합도
|
||||||
|
"epoch_time": trainer.epoch_time, # 지난 에포크 걸린 시간 (에포크 시작 기준으로 결정)
|
||||||
|
"left_second": left_sec # 남은 시간(초)
|
||||||
|
}
|
||||||
|
# 데이터 전송
|
||||||
|
ws_client.send_message()
|
||||||
|
|
||||||
|
model.add_callback("on_train_epoch_start", send_data)
|
||||||
|
|
||||||
model.train(
|
model.train(
|
||||||
data=join_path(dataset_root_path, "dataset.yaml"),
|
data=join_path(dataset_root_path, "dataset.yaml"),
|
||||||
name=join_path(dataset_root_path, "result"),
|
name=join_path(dataset_root_path, "result"),
|
||||||
epochs=request.epochs,
|
epochs=request.epochs,
|
||||||
batch=request.batch,
|
batch=request.batch,
|
||||||
|
lr0=request.lr0,
|
||||||
|
lrf=request.lrf,
|
||||||
|
optimizer=request.optimizer
|
||||||
)
|
)
|
||||||
# return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
|
|
||||||
return {"status": "Training completed successfully"}
|
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||||
|
|
||||||
|
return {"model_key": model_key, "results": results.results_dict}
|
||||||
|
|
||||||
except WebSocketConnectionException as e:
|
except WebSocketConnectionException as e:
|
||||||
|
|
||||||
@ -247,9 +279,12 @@ async def detection_train(request: TrainRequest):
|
|||||||
name=join_path(dataset_root_path, "result"),
|
name=join_path(dataset_root_path, "result"),
|
||||||
epochs=request.epochs,
|
epochs=request.epochs,
|
||||||
batch=request.batch,
|
batch=request.batch,
|
||||||
|
lr0=request.lr0,
|
||||||
|
lrf=request.lrf,
|
||||||
|
optimizer=request.optimizer
|
||||||
)
|
)
|
||||||
|
|
||||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "last.pt"))
|
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||||
|
|
||||||
return {"model_key": model_key, "results": results.results_dict}
|
return {"model_key": model_key, "results": results.results_dict}
|
||||||
|
|
||||||
@ -261,3 +296,5 @@ async def detection_train(request: TrainRequest):
|
|||||||
finally:
|
finally:
|
||||||
if ws_client.is_connected():
|
if ws_client.is_connected():
|
||||||
await ws_client.close()
|
await ws_client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union, Literal
|
||||||
from schemas.predict_response import LabelData
|
from schemas.predict_response import LabelData
|
||||||
|
|
||||||
class TrainDataInfo(BaseModel):
|
class TrainDataInfo(BaseModel):
|
||||||
@ -11,7 +11,11 @@ class TrainRequest(BaseModel):
|
|||||||
m_key: Optional[str] = Field(None, alias="model_key")
|
m_key: Optional[str] = Field(None, alias="model_key")
|
||||||
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습")
|
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습")
|
||||||
data: List[TrainDataInfo]
|
data: List[TrainDataInfo]
|
||||||
seed: Optional[int] = None # 랜덤 변수 시드
|
|
||||||
ratio: float = 0.8 # 훈련/검증 분할 비율
|
ratio: float = 0.8 # 훈련/검증 분할 비율
|
||||||
|
|
||||||
|
# 학습 파라미터
|
||||||
epochs: int = 50 # 훈련 반복 횟수
|
epochs: int = 50 # 훈련 반복 횟수
|
||||||
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
||||||
|
lr0: float = 0.01 # 초기 학습 가중치
|
||||||
|
lrf: float = 0.01 # lr0 기준으로 학습 가중치의 최종 수렴치 (ex lr0의 0.01배)
|
||||||
|
optimizer: Literal['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'] = 'auto'
|
||||||
|
@ -1,11 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
from typing import List, Any, Optional
|
|
||||||
|
|
||||||
def split_data(data:List[Any], ratio:float, seed:Optional[int] = None):
|
def split_data(data:list, ratio:float):
|
||||||
random.seed(seed)
|
|
||||||
train_size = int(ratio * len(data))
|
train_size = int(ratio * len(data))
|
||||||
random.shuffle(data)
|
random.shuffle(data)
|
||||||
random.seed(None)
|
|
||||||
train_data = data[:train_size]
|
train_data = data[:train_size]
|
||||||
val_data = data[train_size:]
|
val_data = data[train_size:]
|
||||||
return train_data, val_data
|
return train_data, val_data
|
Loading…
Reference in New Issue
Block a user