diff --git a/ai/app/api/yolo/classfication.py b/ai/app/api/yolo/classfication.py
index d1c6a9b..f7d11fc 100644
--- a/ai/app/api/yolo/classfication.py
+++ b/ai/app/api/yolo/classfication.py
@@ -6,7 +6,7 @@ from schemas.train_report_data import ReportData
 from services.load_model import load_classification_model
 from services.create_model import save_model
 from utils.dataset_utils import split_data
-from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
+from utils.file_utils import get_dataset_root_path, process_directories_in_cls, process_image_and_label_in_cls, join_path
 from utils.slackMessage import send_slack_message
 from utils.api_utils import send_data_call_api
 import random
@@ -17,7 +17,7 @@ router = APIRouter()
 
 @router.post("/predict")
 async def classification_predict(request: PredictRequest):
-    send_slack_message(f"predict request: {request}", status="success")
+    send_slack_message(f"cls predict request: {request}", status="success")
 
     # Load the model
     model = get_model(request)
@@ -61,17 +61,16 @@ def process_prediction_result(result, image, label_map):
     try:
         label_data = LabelData(
             version="0.0.0",
-            task_type="det",
+            task_type="cls",
             shapes=[
                 {
                     "label": summary['name'],
                     "color": get_random_color(),
                     "points": [
-                        [summary['box']['x1'], summary['box']['y1']],
-                        [summary['box']['x2'], summary['box']['y2']]
+                        [0, 0]
                     ],
                     "group_id": label_map[summary['class']] if label_map else summary['class'],
-                    "shape_type": "rectangle",
+                    "shape_type": "point",
                     "flags": {}
                 }
                 for summary in result.summary()
@@ -96,16 +95,9 @@ def get_random_color():
 
 
 @router.post("/train")
-async def classification_train(request: TrainRequest, http_request: Request):
+async def classification_train(request: TrainRequest):
 
-    send_slack_message(f"train request: {request}", status="success")
-
-    # Extract the Bearer token from the Authorization header
-    auth_header = http_request.headers.get("Authorization")
-    token = auth_header.split(" ")[1] if auth_header and auth_header.startswith("Bearer ") else None
-
-    # Label map
-    inverted_label_map = {value: key for key, value in request.label_map.items()} if request.label_map else None
+    send_slack_message(f"cls train request: {request}", status="success")
 
     # Get the dataset root path
     dataset_root_path = get_dataset_root_path(request.project_id)
@@ -117,10 +109,10 @@ async def classification_train(request: TrainRequest, http_request: Request):
     model_categories = model.names
 
     # Preprocess the data
-    preprocess_dataset(dataset_root_path, model_categories, request.data, request.ratio, inverted_label_map)
+    preprocess_dataset(dataset_root_path, model_categories, request.data, request.ratio)
 
     # Train
-    results = run_train(request,token,model,dataset_root_path)
+    results = run_train(request,model,dataset_root_path)
 
     # Save the best model
     model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@@ -132,30 +124,30 @@ async def classification_train(request: TrainRequest, http_request: Request):
     return response
 
 
-def preprocess_dataset(dataset_root_path, model_categories, data, ratio, label_map):
+def preprocess_dataset(dataset_root_path, model_categories, data, ratio):
     try:
         # Create and initialize the directories
-        process_directories(dataset_root_path, model_categories)
+        process_directories_in_cls(dataset_root_path, model_categories)
 
         # Split the training data
-        train_data, val_data = split_data(data, ratio)
-        if not train_data or not val_data:
+        train_data, test_data = split_data(data, ratio)
+        if not train_data or not test_data:
             raise HTTPException(status_code=400, detail="data split exception: data size is too small or \"ratio\" has invalid value")
 
         # Process the training data
         for data in train_data:
-            process_image_and_label(data, dataset_root_path, "train", label_map)
+            process_image_and_label_in_cls(data, dataset_root_path, "train")
 
         # Process the validation data
-        for data in val_data:
-            process_image_and_label(data, dataset_root_path, "val", label_map)
+        for data in test_data:
+            process_image_and_label_in_cls(data, dataset_root_path, "test")
 
     except HTTPException as e:
        raise e  # Re-raise the HTTP exception
     except Exception as e:
        raise HTTPException(status_code=500, detail="preprocess dataset exception: " + str(e))
 
-def run_train(request, token, model, dataset_root_path):
+def run_train(request, model, dataset_root_path):
     try:
         # Callback for sending progress data
         def send_data(trainer):
@@ -171,17 +163,17 @@ def run_train(request, token, model, dataset_root_path):
                 # Losses: box_loss, cls_loss, dfl_loss
                 loss = trainer.label_loss_items(loss_items=trainer.loss_items)
                 data = ReportData(
-                    epoch=trainer.epoch,             # current epoch
-                    total_epochs=trainer.epochs,     # total epochs
-                    box_loss=loss["train/box_loss"], # box loss
-                    cls_loss=loss["train/cls_loss"], # cls loss
-                    dfl_loss=loss["train/dfl_loss"], # dfl loss
-                    fitness=trainer.fitness,         # fitness
-                    epoch_time=trainer.epoch_time,   # time taken by the last epoch (measured from epoch start)
-                    left_seconds=left_seconds        # remaining time in seconds
+                    epoch=trainer.epoch,           # current epoch
+                    total_epochs=trainer.epochs,   # total epochs
+                    box_loss=0,                    # box loss
+                    cls_loss=loss["train/loss"],   # cls loss
+                    dfl_loss=0,                    # dfl loss
+                    fitness=trainer.fitness,       # fitness
+                    epoch_time=trainer.epoch_time, # time taken by the last epoch (measured from epoch start)
+                    left_seconds=left_seconds      # remaining time in seconds
                 )
                 # Send the data
-                send_data_call_api(request.project_id, request.m_id, data, token)
+                send_data_call_api(request.project_id, request.m_id, data)
 
             except Exception as e:
                 raise HTTPException(status_code=500, detail=f"send_data exception: {e}")
@@ -189,20 +181,16 @@ def run_train(request, token, model, dataset_root_path):
         model.add_callback("on_train_epoch_start", send_data)
 
         # Run the training
-        try:
-            results = model.train(
-                data=join_path(dataset_root_path, "dataset.yaml"),
-                name=join_path(dataset_root_path, "result"),
-                epochs=request.epochs,
-                batch=request.batch,
-                lr0=request.lr0,
-                lrf=request.lrf,
-                optimizer=request.optimizer
-            )
-        except Exception as e:
-            raise HTTPException(status_code=500, detail=f"model train exception: {e}")
-
-        # Send the final epoch
+        results = model.train(
+            data=dataset_root_path,
+            name=join_path(dataset_root_path, "result"),
+            epochs=request.epochs,
+            batch=request.batch,
+            lr0=request.lr0,
+            lrf=request.lrf,
+            optimizer=request.optimizer
+        )
+        # Send the final epoch
         model.trainer.epoch += 1
         send_data(model.trainer)
 
diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py
index 3674d00..ddaef31 100644
--- a/ai/app/api/yolo/detection.py
+++ b/ai/app/api/yolo/detection.py
@@ -101,7 +101,6 @@ async def detection_train(request: TrainRequest):
 
     send_slack_message(f"train request: {request}", status="success")
 
-    # Extract the Bearer token from the Authorization header
     try:
         # Label map
         inverted_label_map = {value: key for key, value in request.label_map.items()} if request.label_map else None
diff --git a/ai/app/main.py b/ai/app/main.py
index 71c2064..e1b507e 100644
--- a/ai/app/main.py
+++ b/ai/app/main.py
@@ -4,6 +4,7 @@ from fastapi.exceptions import RequestValidationError
 from starlette.exceptions import HTTPException
 from api.yolo.detection import router as yolo_detection_router
 from api.yolo.segmentation import router as yolo_segmentation_router
+from api.yolo.classfication import router as yolo_classification_router
 from api.yolo.model import router as yolo_model_router
 from utils.slackMessage import send_slack_message
@@ -12,6 +13,7 @@ app = FastAPI()
 
 # Register each feature router with the application
 app.include_router(yolo_detection_router, prefix="/api/detection", tags=["Detection"])
 app.include_router(yolo_segmentation_router, prefix="/api/segmentation", tags=["Segmentation"])
+app.include_router(yolo_classification_router, prefix="/api/classification", tags=["Classification"])
 app.include_router(yolo_model_router, prefix="/api/model", tags=["Model"])
 
diff --git a/ai/app/services/load_model.py b/ai/app/services/load_model.py
index e799353..ad12f97 100644
--- a/ai/app/services/load_model.py
+++ b/ai/app/services/load_model.py
@@ -10,7 +10,7 @@ def load_detection_model(project_id:int, model_key:str):
     if model_key in default_model_map:
         model = YOLO(default_model_map[model_key])
     else:
-        model = load_model(model_path=os.path.join("projects",str(project_id),"models", model_key))
+        model = load_model(model_path=os.path.join("resources", "projects",str(project_id),"models", model_key))
 
     # Verify that this is a detection model
     if model.task != "detect":
@@ -23,13 +23,26 @@ def load_segmentation_model(project_id:int, model_key:str):
     if model_key in default_model_map:
         model = YOLO(default_model_map[model_key])
     else:
-        model = load_model(model_path=os.path.join("projects",str(project_id),"models",model_key))
+        model = load_model(model_path=os.path.join("resources", "projects",str(project_id),"models",model_key))
 
     # Verify that this is a segmentation model
     if model.task != "segment":
         raise TypeError(f"Invalid model type: {model.task}. Expected a SegmentationModel.")
     return model
 
+def load_classification_model(project_id:int, model_key:str):
+    default_model_map = {"yolo8": os.path.join("resources","models","yolov8n-cls.pt")}
+    # Check for a default model
+    if model_key in default_model_map:
+        model = YOLO(default_model_map[model_key])
+    else:
+        model = load_model(model_path=os.path.join("resources", "projects",str(project_id),"models",model_key))
+
+    # Verify that this is a classification model
+    if model.task != "classify":
+        raise TypeError(f"Invalid model type: {model.task}. Expected a ClassificationModel.")
+    return model
+
 def load_model(model_path: str):
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found at path: {model_path}")
diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py
index da0db24..e8b553d 100644
--- a/ai/app/utils/file_utils.py
+++ b/ai/app/utils/file_utils.py
@@ -24,7 +24,7 @@ def make_yml(path:str, model_categories):
     data = {
         "train": f"{path}/train",
         "val": f"{path}/val",
-        "nc": 80,
+        "nc": len(model_categories),
         "names": model_categories
     }
     with open(os.path.join(path, "dataset.yaml"), 'w') as f:
@@ -117,3 +117,28 @@ def get_file_name(path):
     if not os.path.exists(path):
         raise FileNotFoundError()
     return os.path.basename(path)
+
+def process_directories_in_cls(dataset_root_path:str, model_categories:dict[int,str]):
+    """Create the directory layout for classification training"""
+    make_dir(dataset_root_path, init=False)
+    for category in model_categories.values():
+        make_dir(os.path.join(dataset_root_path, "train", category), init=True)
+        make_dir(os.path.join(dataset_root_path, "test", category), init=True)
+    if os.path.exists(os.path.join(dataset_root_path, "result")):
+        shutil.rmtree(os.path.join(dataset_root_path, "result"))
+
+def process_image_and_label_in_cls(data:TrainDataInfo, dataset_root_path:str, child_path:str):
+    """Download the image into the directory of its class"""
+    # Extract the file name from the image URL
+    img_name = data.image_url.split('/')[-1]
+
+    # Load the label object
+    label = json.loads(urllib.request.urlopen(data.data_url).read())
+
+    label_name = label["shapes"][0]["label"]
+
+    label_path = os.path.join(dataset_root_path,child_path,label_name)
+
+    # Download the image from the URL into the class directory
+    urllib.request.urlretrieve(data.image_url, os.path.join(label_path, img_name))
+
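
Note on the classification training flow above: Ultralytics classification models are trained from a folder-per-class directory rather than a dataset.yaml, which is why run_train now passes data=dataset_root_path directly. The following is a minimal sketch (not part of this PR) of the layout that process_directories_in_cls and process_image_and_label_in_cls are expected to produce, and of an equivalent standalone training call; the cat/dog categories and the datasets/project_1 path are hypothetical.

# Minimal sketch, assuming model_categories = {0: "cat", 1: "dog"} and
# dataset_root_path = "datasets/project_1" (both hypothetical):
#
# datasets/project_1/
#     train/
#         cat/    <- images whose first shape label is "cat"
#         dog/
#     test/       <- Ultralytics can fall back to this split when no val/ folder exists
#         cat/
#         dog/
from ultralytics import YOLO

model = YOLO("yolov8n-cls.pt")     # same default weights as load_classification_model
results = model.train(
    data="datasets/project_1",     # the dataset root directory, not a dataset.yaml
    epochs=3,
    batch=16,
)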
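
For completeness, a hypothetical client call against the newly registered /api/classification/train route. The field names are inferred only from how the handler reads TrainRequest in this diff (project_id, m_id, data, ratio, epochs, batch, lr0, lrf, optimizer); the actual schema lives under schemas/ and may differ, and any model-key field it defines is omitted because it is not visible here.

# Hypothetical request sketch; host, ids and field names are assumptions, not the confirmed schema.
import requests

payload = {
    "project_id": 1,
    "m_id": 10,
    "data": [
        {
            "image_url": "https://example.com/images/0001.jpg",  # saved into <root>/train|test/<label>/
            "data_url": "https://example.com/labels/0001.json",  # JSON whose shapes[0].label names the class
        },
    ],
    "ratio": 0.8,        # train/test split ratio passed to split_data
    "epochs": 5,
    "batch": 16,
    "lr0": 0.01,
    "lrf": 0.01,
    "optimizer": "auto",
}

resp = requests.post("http://localhost:8000/api/classification/train", json=payload)
print(resp.status_code, resp.json())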