From d6b132c6ce45162175487d48e37c6655b5d73198 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Thu, 26 Sep 2024 19:53:56 +0900 Subject: [PATCH] =?UTF-8?q?Feat:=20=EC=84=B8=EA=B7=B8=EB=A9=98=ED=85=8C?= =?UTF-8?q?=EC=9D=B4=EC=85=98=20train=20response=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/segmentation.py | 94 ++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 42 deletions(-) diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py index cc5887b..948e28c 100644 --- a/ai/app/api/yolo/segmentation.py +++ b/ai/app/api/yolo/segmentation.py @@ -3,6 +3,7 @@ from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest from schemas.predict_response import PredictResponse, LabelData from schemas.train_report_data import ReportData +from schemas.train_response import TrainResponse from services.load_model import load_segmentation_model from services.create_model import save_model from utils.dataset_utils import split_data @@ -91,40 +92,50 @@ def get_random_color(): @router.post("/train") -async def segmentation_train(request: TrainRequest, http_request: Request): - +async def segmentation_train(request: TrainRequest): + send_slack_message(f"train 요청{request}", status="success") - # Authorization 헤더에서 Bearer 토큰 추출 - auth_header = http_request.headers.get("Authorization") - token = auth_header.split(" ")[1] if auth_header and auth_header.startswith("Bearer ") else None + try: + # 레이블 맵 + inverted_label_map = {value: key for key, value in request.label_map.items()} if request.label_map else None - # 레이블 맵 - inverted_label_map = {value: key for key, value in request.label_map.items()} if request.label_map else None + # 데이터셋 루트 경로 얻기 + dataset_root_path = get_dataset_root_path(request.project_id) - # 데이터셋 루트 경로 얻기 - dataset_root_path = get_dataset_root_path(request.project_id) + # 모델 로드 + model = get_model(request) - # 모델 로드 - model = get_model(request) + # 학습할 모델 카테고리, 카테고리가 추가되는 경우 추가 작업 필요 + model_categories = model.names + + # 데이터 전처리 + preprocess_dataset(dataset_root_path, model_categories, request.data, request.ratio, inverted_label_map) - # 학습할 모델 카테고리, 카테고리가 추가되는 경우 추가 작업 필요 - model_categories = model.names - - # 데이터 전처리 - preprocess_dataset(dataset_root_path, model_categories, request.data, request.ratio, inverted_label_map) + # 학습 + results = run_train(request, model,dataset_root_path) - # 학습 - results = run_train(request,token,model,dataset_root_path) + # best 모델 저장 + model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) + + result = results.results_dict - # best 모델 저장 - model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) + response = TrainResponse( + modelKey=model_key, + precision= result["metrics/precision(M)"], + recall= result["metrics/recall(M)"], + mAP50= result["metrics/mAP50(M)"], + mAP5095= result["metrics/mAP50-95(M)"], + fitness= result["fitness"] + ) + send_slack_message(f"train 성공{response}", status="success") + + return response - response = {"model_key": model_key, "results": results.results_dict} - - send_slack_message(f"train 성공{response}", status="success") - - return response + except HTTPException as e: + raise e + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) def preprocess_dataset(dataset_root_path, model_categories, data, ratio, label_map): @@ -150,7 +161,7 @@ def preprocess_dataset(dataset_root_path, model_categories, data, ratio, label_m except Exception as e: raise HTTPException(status_code=500, detail="preprocess dataset exception: " + str(e)) -def run_train(request, token, model, dataset_root_path): +def run_train(request, model, dataset_root_path): try: # 데이터 전송 콜백함수 def send_data(trainer): @@ -168,7 +179,7 @@ def run_train(request, token, model, dataset_root_path): data = ReportData( epoch=trainer.epoch, # 현재 에포크 total_epochs=trainer.epochs, # 전체 에포크 - seg_loss=loss["train/seg_loss"], # seg loss + box_loss=loss["train/box_loss"], # box loss cls_loss=loss["train/cls_loss"], # cls loss dfl_loss=loss["train/dfl_loss"], # dfl loss fitness=trainer.fitness, # 적합도 @@ -176,7 +187,7 @@ def run_train(request, token, model, dataset_root_path): left_seconds=left_seconds # 남은 시간(초) ) # 데이터 전송 - send_data_call_api(request.project_id, request.m_id, data, token) + send_data_call_api(request.project_id, request.m_id, data) except Exception as e: raise HTTPException(status_code=500, detail=f"send_data exception: {e}") @@ -184,23 +195,19 @@ def run_train(request, token, model, dataset_root_path): model.add_callback("on_train_epoch_start", send_data) # 학습 실행 - try: - results = model.train( - data=join_path(dataset_root_path, "dataset.yaml"), - name=join_path(dataset_root_path, "result"), - epochs=request.epochs, - batch=request.batch, - lr0=request.lr0, - lrf=request.lrf, - optimizer=request.optimizer - ) - except Exception as e: - raise HTTPException(status_code=500, detail=f"model train exception: {e}") - + results = model.train( + data=join_path(dataset_root_path, "dataset.yaml"), + name=join_path(dataset_root_path, "result"), + epochs=request.epochs, + batch=request.batch, + lr0=request.lr0, + lrf=request.lrf, + optimizer=request.optimizer + ) + # 마지막 에포크 전송 model.trainer.epoch += 1 send_data(model.trainer) - return results except HTTPException as e: @@ -211,3 +218,6 @@ def run_train(request, token, model, dataset_root_path): + + +