diff --git a/ai/app/api/yolo/classfication.py b/ai/app/api/yolo/classfication.py index a48c0ad..8496d14 100644 --- a/ai/app/api/yolo/classfication.py +++ b/ai/app/api/yolo/classfication.py @@ -1,32 +1,32 @@ -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, HTTPException +from api.yolo.detection import get_classes, run_predictions, get_random_color, split_data from schemas.predict_request import PredictRequest -from schemas.train_request import TrainRequest +from schemas.train_request import TrainRequest, TrainDataInfo from schemas.predict_response import PredictResponse, LabelData from schemas.train_report_data import ReportData +from schemas.train_response import TrainResponse from services.load_model import load_classification_model from services.create_model import save_model from utils.file_utils import get_dataset_root_path, process_directories_in_cls, process_image_and_label_in_cls, join_path from utils.slackMessage import send_slack_message from utils.api_utils import send_data_call_api -import random - router = APIRouter() @router.post("/predict") async def classification_predict(request: PredictRequest): - send_slack_message(f"cls predict 요청: {request}", status="success") + send_slack_message(f"predict 요청: {request}", status="success") # 모델 로드 - model = get_model(request) - - # 모델 레이블 카테고리 연결 - classes = list(request.label_map) if request.label_map else None + model = get_model(request.project_id, request.m_key) # 이미지 데이터 정리 url_list = list(map(lambda x:x.image_url, request.image_list)) + # 이 값을 모델에 입력하면 해당하는 클래스 id만 출력됨 + classes = get_classes(request.label_map, model.names) + # 추론 results = run_predictions(model, url_list, request, classes) @@ -40,20 +40,7 @@ def get_model(request: PredictRequest): try: return load_classification_model(request.project_id, request.m_key) except Exception as e: - raise HTTPException(status_code=500, detail="load model exception: " + str(e)) - -# 추론 실행 함수 -def run_predictions(model, image, request, classes): - try: - return model.predict( - source=image, - iou=request.iou_threshold, - conf=request.conf_threshold, - classes=classes - ) - except Exception as e: - raise HTTPException(status_code=500, detail="model predict exception: " + str(e)) - + raise HTTPException(status_code=500, detail="exception in get_model(): " + str(e)) # 추론 결과 처리 함수 def process_prediction_result(result, image, label_map): @@ -68,7 +55,7 @@ def process_prediction_result(result, image, label_map): "points": [ [0, 0] ], - "group_id": label_map[summary['class']] if label_map else summary['class'], + "group_id": label_map[summary['name']], "shape_type": "point", "flags": {} } @@ -80,71 +67,68 @@ def process_prediction_result(result, image, label_map): imageDepth=result.orig_img.shape[2] ) except Exception as e: - raise HTTPException(status_code=500, detail="model predict exception: " + str(e)) + raise HTTPException(status_code=500, detail="exception in process_prediction_result(): " + str(e)) return PredictResponse( image_id=image.image_id, data=label_data.model_dump_json() ) -def get_random_color(): - random_number = random.randint(0, 0xFFFFFF) - return f"#{random_number:06X}" - - - @router.post("/train") async def classification_train(request: TrainRequest): - send_slack_message(f"cls train 요청{request}", status="success") + send_slack_message(f"train 요청{request}", status="success") - # 데이터셋 루트 경로 얻기 + # 데이터셋 루트 경로 얻기 (프로젝트 id 기반) dataset_root_path = get_dataset_root_path(request.project_id) # 모델 로드 - model = get_model(request) + model = get_model(request.project_id, request.m_key) - # 학습할 모델 카테고리, 카테고리가 추가되는 경우 추가 작업 필요 - model_categories = model.names + # 이 값을 학습할때 넣으면 이 카테고리들이 학습됨 + names = list(request.label_map) - # 데이터 전처리 - preprocess_dataset(dataset_root_path, model_categories, request.data, request.ratio) + # 데이터 전처리: 학습할 디렉토리 & 데이터셋 설정 파일을 생성 + process_directories_in_cls(dataset_root_path, names) + + # 데이터 전처리: 데이터를 학습데이터와 테스트 데이터로 분류 + train_data, test_data = split_data(request.data, request.ratio) + + # 데이터 전처리: 데이터 이미지 및 레이블 다운로드 + download_data(train_data, test_data, dataset_root_path) # 학습 - results = run_train(request,model,dataset_root_path) + results = run_train(request, model,dataset_root_path) # best 모델 저장 model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) - - response = {"model_key": model_key, "results": results.results_dict} - - send_slack_message(f"train 성공{response}", status="success") + result = results.results_dict + + response = TrainResponse( + modelKey=model_key, + precision= 0, + recall= 0, + mAP50= 0, + mAP5095= 0, + accuracy=result["accuracy_top1"], + fitness= result["fitness"] + ) + + send_slack_message(f"train 성공{response}", status="success") + return response - -def preprocess_dataset(dataset_root_path, model_categories, data, ratio): +def download_data(train_data:list[TrainDataInfo], test_data:list[TrainDataInfo], dataset_root_path:str): try: - # 디렉토리 생성 및 초기화 - process_directories_in_cls(dataset_root_path, model_categories) - - # 학습 데이터 분류 - train_data, test_data = split_data(data, ratio) - if not train_data or not test_data: - raise HTTPException(status_code=400, detail="data split exception: data size is too small or \"ratio\" has invalid value") - - # 학습 데이터 처리 for data in train_data: process_image_and_label_in_cls(data, dataset_root_path, "train") - # 검증 데이터 처리 for data in test_data: process_image_and_label_in_cls(data, dataset_root_path, "test") - - except HTTPException as e: - raise e # HTTP 예외를 다시 발생 except Exception as e: - raise HTTPException(status_code=500, detail="preprocess dataset exception: " + str(e)) + raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e)) + def run_train(request, model, dataset_root_path): try: @@ -164,6 +148,7 @@ def run_train(request, model, dataset_root_path): data = ReportData( epoch=trainer.epoch, # 현재 에포크 total_epochs=trainer.epochs, # 전체 에포크 + seg_loss=0, # seg loss box_loss=0, # box loss cls_loss=loss["train/loss"], # cls loss dfl_loss=0, # dfl loss @@ -174,7 +159,7 @@ def run_train(request, model, dataset_root_path): # 데이터 전송 send_data_call_api(request.project_id, request.m_id, data) except Exception as e: - raise HTTPException(status_code=500, detail=f"send_data exception: {e}") + raise HTTPException(status_code=500, detail="exception in send_data: "+ str(e)) # 콜백 등록 model.add_callback("on_train_epoch_start", send_data) @@ -198,6 +183,6 @@ def run_train(request, model, dataset_root_path): except HTTPException as e: raise e # HTTP 예외를 다시 발생 except Exception as e: - raise HTTPException(status_code=500, detail=f"run_train exception: {e}") + raise HTTPException(status_code=500, detail="exception in run_train(): "+str(e)) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index c30282a..b3ca8f7 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -144,6 +144,7 @@ async def detection_train(request: TrainRequest): recall= result["metrics/recall(B)"], mAP50= result["metrics/mAP50(B)"], mAP5095= result["metrics/mAP50-95(B)"], + accuracy=0, fitness= result["fitness"] ) send_slack_message(f"train 성공{response}", status="success") diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py index 5df4ca1..89458ed 100644 --- a/ai/app/api/yolo/segmentation.py +++ b/ai/app/api/yolo/segmentation.py @@ -114,6 +114,7 @@ async def segmentation_train(request: TrainRequest): recall= result["metrics/recall(M)"], mAP50= result["metrics/mAP50(M)"], mAP5095= result["metrics/mAP50-95(M)"], + accuracy = 0, fitness= result["fitness"] ) send_slack_message(f"train 성공{response}", status="success") diff --git a/ai/app/schemas/train_response.py b/ai/app/schemas/train_response.py index 05b6403..222b3ce 100644 --- a/ai/app/schemas/train_response.py +++ b/ai/app/schemas/train_response.py @@ -6,4 +6,5 @@ class TrainResponse(BaseModel): recall: float mAP50: float mAP5095: float + accuracy: float fitness: float \ No newline at end of file diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py index 4a9faa9..a4992ea 100644 --- a/ai/app/utils/file_utils.py +++ b/ai/app/utils/file_utils.py @@ -118,10 +118,10 @@ def get_file_name(path): raise FileNotFoundError() return os.path.basename(path) -def process_directories_in_cls(dataset_root_path:str, model_categories:dict[int,str]): +def process_directories_in_cls(dataset_root_path:str, model_categories:list[str]): """classification 학습을 위한 디렉토리 생성""" make_dir(dataset_root_path, init=False) - for category in model_categories.values(): + for category in model_categories: make_dir(os.path.join(dataset_root_path, "train", category), init=True) make_dir(os.path.join(dataset_root_path, "test", category), init=True) if os.path.exists(os.path.join(dataset_root_path, "result")): @@ -140,4 +140,11 @@ def process_image_and_label_in_cls(data:TrainDataInfo, dataset_root_path:str, ch label_path = os.path.join(dataset_root_path,child_path,label_name) # url로부터 이미지 다운로드 - urllib.request.urlretrieve(data.image_url, os.path.join(label_path, img_name)) \ No newline at end of file + if os.path.exists(label_path): + urllib.request.urlretrieve(data.image_url, os.path.join(label_path, img_name)) + else: + # raise FileNotFoundError("failed download") + print("Not Found Label Category. Failed Download") + # 레이블 데이터 중에서 프로젝트 카테고리에 해당되지않는 데이터가 있는 경우 처리 1. 에러 raise 2. 무시(+ warning) + + \ No newline at end of file