From 6c9782a80742e6db6f2adb4cbdfce57004bf913a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?=
Date: Wed, 25 Sep 2024 23:51:19 +0900
Subject: [PATCH] refactor: detection_train refactoring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ai/app/api/yolo/detection.py | 134 ++++++++++++++++++++++-------------
 ai/app/utils/api_utils.py    |   2 +-
 2 files changed, 87 insertions(+), 49 deletions(-)

diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py
index 2892335..ead3833 100644
--- a/ai/app/api/yolo/detection.py
+++ b/ai/app/api/yolo/detection.py
@@ -59,16 +59,13 @@ def run_predictions(model, image, request, classes):
 # Inference result processing function
 def process_prediction_result(result, image, label_map):
     try:
-        random_number = random.randint(0, 0xFFFFFF)
-        color = f"#{random_number:06X}"
-
         label_data = LabelData(
             version="0.0.0",
             task_type="det",
             shapes=[
                 {
                     "label": summary['name'],
-                    "color": color,
+                    "color": get_random_color(),
                     "points": [
                         [summary['box']['x1'], summary['box']['y1']],
                         [summary['box']['x2'], summary['box']['y2']]
@@ -92,6 +89,10 @@ def process_prediction_result(result, image, label_map):
         data=label_data.model_dump_json()
     )
 
+def get_random_color():
+    random_number = random.randint(0, 0xFFFFFF)
+    return f"#{random_number:06X}"
+
 
 
 @router.post("/train")
@@ -103,76 +104,113 @@ async def detection_train(request: TrainRequest, http_request: Request):
     auth_header = http_request.headers.get("Authorization")
     token = auth_header.split(" ")[1] if auth_header and auth_header.startswith("Bearer ") else None
 
+    # Label map
+    inverted_label_map = {value: key for key, value in request.label_map.items()} if request.label_map else None
+
     # Get the dataset root path
     dataset_root_path = get_dataset_root_path(request.project_id)
 
     # Load the model
     model = get_model(request)
 
-    # Model categories to train; laid out so new categories can be added
-    names = model.names
+    # Model categories to train; extra work is needed when categories are added
+    model_categories = model.names
+
+    # Preprocess the dataset
+    preprocess_dataset(dataset_root_path, model_categories, request.data, request.ratio, inverted_label_map)
 
-    # Create and initialize directories
-    process_directories(dataset_root_path, names)
+    # Train
+    results = run_train(request, token, model, dataset_root_path)
 
-    # Label map
-    inverted_label_map = {value: key for key, value in request.label_map.items()} if request.label_map else None
+    # Save the last model
+    model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
 
-    # Split the training data
-    train_data, val_data = split_data(request.data, request.ratio)
+    response = {"model_key": model_key, "results": results.results_dict}
+    send_slack_message(f"train 성공{response}", status="success")
+
+    return response
+
+
+def preprocess_dataset(dataset_root_path, model_categories, data, ratio, label_map):
     try:
+        # Create and initialize directories
+        process_directories(dataset_root_path, model_categories)
+
+        # Split the training data
+        train_data, val_data = split_data(data, ratio)
+        if not train_data or not val_data:
+            raise HTTPException(status_code=400, detail="data split exception: data size is too small or \"ratio\" has invalid value")
+
         # Process the training data
         for data in train_data:
-            process_image_and_label(data, dataset_root_path, "train", inverted_label_map)
+            process_image_and_label(data, dataset_root_path, "train", label_map)
 
         # Process the validation data
         for data in val_data:
-            process_image_and_label(data, dataset_root_path, "val", inverted_label_map)
+            process_image_and_label(data, dataset_root_path, "val", label_map)
 
+    except HTTPException as e:
+        raise e  # Re-raise the HTTP exception
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="preprocess dataset exception: " + str(e))
+
+def run_train(request, token, model, dataset_root_path):
+    try:
+        # Callback function that sends progress data
         def send_data(trainer):
-            # Skip the first epoch
-            if trainer.epoch == 0:
-                return
+            try:
+                # Skip the first epoch
+                if trainer.epoch == 0:
+                    return
 
-            ## Calculate the remaining time (seconds)
-            left_epochs = trainer.epochs-trainer.epoch
-            left_seconds = left_epochs*trainer.epoch_time
-            ## Losses: box_loss, cls_loss, dfl_loss
-            loss = trainer.label_loss_items(loss_items=trainer.loss_items)
-            data = ReportData(
-                epoch= trainer.epoch,              # current epoch
-                total_epochs= trainer.epochs,      # total epochs
-                box_loss= loss["train/box_loss"],  # box loss
-                cls_loss= loss["train/cls_loss"],  # cls loss
-                dfl_loss= loss["train/dfl_loss"],  # dfl loss
-                fitness= trainer.fitness,          # fitness
-                epoch_time= trainer.epoch_time,    # time spent on the last epoch (measured from epoch start)
-                left_seconds= left_seconds         # remaining time (seconds)
-            )
-            # Send the data
-            send_data_call_api(request.project_id, request.m_id, data, token)
+                # Calculate the remaining time (seconds)
+                left_epochs = trainer.epochs - trainer.epoch
+                left_seconds = left_epochs * trainer.epoch_time
+                # Losses: box_loss, cls_loss, dfl_loss
+                loss = trainer.label_loss_items(loss_items=trainer.loss_items)
+                data = ReportData(
+                    epoch=trainer.epoch,              # current epoch
+                    total_epochs=trainer.epochs,      # total epochs
+                    box_loss=loss["train/box_loss"],  # box loss
+                    cls_loss=loss["train/cls_loss"],  # cls loss
+                    dfl_loss=loss["train/dfl_loss"],  # dfl loss
+                    fitness=trainer.fitness,          # fitness
+                    epoch_time=trainer.epoch_time,    # time spent on the last epoch (measured from epoch start)
+                    left_seconds=left_seconds         # remaining time (seconds)
+                )
+                # Send the data
+                send_data_call_api(request.project_id, request.m_id, data, token)
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=f"send_data exception: {e}")
+
+        # Register the callback
         model.add_callback("on_train_epoch_start", send_data)
 
-        results = model.train(
-            data=join_path(dataset_root_path, "dataset.yaml"),
-            name=join_path(dataset_root_path, "result"),
-            epochs=request.epochs,
-            batch=request.batch,
-            lr0=request.lr0,
-            lrf=request.lrf,
-            optimizer=request.optimizer
-        )
+        # Run the training
+        try:
+            results = model.train(
+                data=join_path(dataset_root_path, "dataset.yaml"),
+                name=join_path(dataset_root_path, "result"),
+                epochs=request.epochs,
+                batch=request.batch,
+                lr0=request.lr0,
+                lrf=request.lrf,
+                optimizer=request.optimizer
+            )
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"model train exception: {e}")
 
+        # Report the final epoch
         model.trainer.epoch += 1
         send_data(model.trainer)
 
-        model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
-        response = {"model_key": model_key, "results": results.results_dict}
-        send_slack_message(f"train 성공{response}", status="success")
-        return response
+        return results
+
+    except HTTPException as e:
+        raise e  # Re-raise the HTTP exception
     except Exception as e:
+        raise HTTPException(status_code=500, detail=f"run_train exception: {e}")
 
-        raise HTTPException(status_code=500, detail="model train exception: " + str(e))
diff --git a/ai/app/utils/api_utils.py b/ai/app/utils/api_utils.py
index 5f11476..d55c7c9 100644
--- a/ai/app/utils/api_utils.py
+++ b/ai/app/utils/api_utils.py
@@ -3,7 +3,7 @@
 from dotenv import load_dotenv
 import os, httpx
 
-def report_data(project_id:int, model_id:int, data:ReportData, token):
+def send_data_call_api(project_id:int, model_id:int, data:ReportData, token):
     try:
         load_dotenv()
         # Create a .env file in the same directory as main.py and enter values without quotes
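
Note for reviewers: the api_utils.py hunk above only covers the rename of report_data to send_data_call_api; the rest of the helper's body lies outside the diff context. As a minimal sketch of what a helper with this signature could look like, assuming a hypothetical API_BASE_URL environment variable and a hypothetical report route (neither appears in this patch), the real body in api_utils.py may differ:

from dotenv import load_dotenv
import os, httpx

def send_data_call_api(project_id: int, model_id: int, data, token):
    # Illustrative sketch only: the env var name and route below are assumptions.
    load_dotenv()  # .env lives next to main.py; values are entered without quotes
    base_url = os.getenv("API_BASE_URL")  # hypothetical variable name
    url = f"{base_url}/api/projects/{project_id}/models/{model_id}/report"  # hypothetical route
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    # ReportData is a Pydantic model (model_dump_json is used elsewhere in this patch),
    # so model_dump() yields a JSON-serializable payload.
    response = httpx.post(url, json=data.model_dump(), headers=headers, timeout=10)
    response.raise_for_status()  # let failures propagate

Keeping the helper exception-transparent matches the refactored send_data() callback, which catches any failure and wraps it in an HTTPException with status 500.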