diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py
index b67914b..07b33b6 100644
--- a/ai/app/api/yolo/detection.py
+++ b/ai/app/api/yolo/detection.py
@@ -1,8 +1,12 @@
 from fastapi import APIRouter, HTTPException
 from schemas.predict_request import PredictRequest
+from schemas.train_request import TrainRequest
 from schemas.predict_response import PredictResponse, LabelData
 from services.ai_service import load_detection_model
+from utils.dataset_utils import split_data
+from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label
 from typing import List
+from PIL import Image
 
 router = APIRouter()
 @router.post("/detection", response_model=List[PredictResponse])
@@ -72,4 +76,27 @@ def predict(request: PredictRequest):
             })
     except Exception as e:
         raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
-    return response
\ No newline at end of file
+    return response
+
+
+@router.post("/detection/train")
+def train(request: TrainRequest):
+    """Build the on-disk detection training dataset for a project.
+
+    Resets the project's ``train``/``val`` directories, splits the requested
+    samples by ``request.ratio`` (shuffled with ``request.seed``), and writes
+    each image with its label file into the matching split directory.
+    """
+    # Resolve the dataset root directory for this project.
+    dataset_root_path = get_dataset_root_path(request.project_id)
+
+    # Create the directory layout, clearing any previous train/val data.
+    process_directories(dataset_root_path)
+
+    # Split the samples into training and validation sets.
+    train_data, val_data = split_data(request.data, request.ratio, request.seed)
+
+    # Write images and labels for each split.
+    for split_name, split_items in (("train", train_data), ("val", val_data)):
+        for data in split_items:
+            process_image_and_label(data, dataset_root_path, split_name)
diff --git a/ai/app/schemas/train_request.py b/ai/app/schemas/train_request.py
new file mode 100644
index 0000000..a239dd2
--- /dev/null
+++ b/ai/app/schemas/train_request.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+from typing import List, Optional
+from schemas.predict_response import LabelData
+
+
+class TrainDataInfo(BaseModel):
+    """One training sample: an image and its label data."""
+
+    image_url: str  # passed to PIL's Image.open by process_image_and_label
+    label: LabelData
+
+
+class TrainRequest(BaseModel):
+    """Request body for the detection training endpoint."""
+
+    project_id: int
+    data: List[TrainDataInfo]
+    seed: Optional[int] = None  # random seed for the train/val shuffle
+    ratio: Optional[float] = 0.8  # fraction of samples used for training
diff --git a/ai/app/utils/dataset_utils.py b/ai/app/utils/dataset_utils.py
new file mode 100644
index 0000000..0d18f8e
--- /dev/null
+++ b/ai/app/utils/dataset_utils.py
@@ -0,0 +1,25 @@
+import random
+from typing import List, Any, Optional
+
+
+def split_data(data: List[Any], ratio: float, seed: Optional[int] = None):
+    """Shuffle ``data`` and split it into training and validation lists.
+
+    Uses a private ``random.Random(seed)`` and shuffles a copy, so neither the
+    caller's list nor the module-level random generator state is modified.
+    The same seed always yields the same split; ``seed=None`` gives a
+    non-reproducible shuffle.
+
+    Args:
+        data: Items to split.
+        ratio: Fraction of items placed in the training list.
+        seed: Optional seed for a reproducible shuffle.
+
+    Returns:
+        A ``(train_data, val_data)`` tuple of lists.
+    """
+    shuffled = list(data)
+    random.Random(seed).shuffle(shuffled)
+    # Clamp so a ratio outside [0, 1] cannot produce a negative slice index.
+    train_size = min(max(int(ratio * len(shuffled)), 0), len(shuffled))
+    return shuffled[:train_size], shuffled[train_size:]
diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py
new file mode 100644
index 0000000..f9605ca
--- /dev/null
+++ b/ai/app/utils/file_utils.py
@@ -0,0 +1,164 @@
+import os
+import shutil
+import yaml
+from PIL import Image
+from schemas.train_request import TrainDataInfo
+
+
+def get_dataset_root_path(project_id):
+    """Return the dataset root path for the given project ID."""
+    return os.path.join('test-data', 'projects', str(project_id), 'train_model')
+
+
+def make_dir(path:str, init: bool):
+    """Create ``path`` (and parents), optionally wiping it first.
+
+    path : directory path
+    init : when true, delete an existing ``path`` before recreating it
+    """
+    if init and os.path.exists(path):
+        shutil.rmtree(path)
+    os.makedirs(path, exist_ok=True)
+
+
+def make_yml(path:str):
+    """Write the dataset YAML (split dirs, class count, class names) to ``path``."""
+    data = {
+        "train": "train",
+        "val": "val",
+        "nc": 80,
+        "names":
+        {
+            0: "person",
+            1: "bicycle",
+            2: "car",
+            3: "motorcycle",
+            4: "airplane",
+            5: "bus",
+            6: "train",
+            7: "truck",
+            8: "boat",
+            9: "traffic light",
+            10: "fire hydrant",
+            11: "stop sign",
+            12: "parking meter",
+            13: "bench",
+            14: "bird",
+            15: "cat",
+            16: "dog",
+            17: "horse",
+            18: "sheep",
+            19: "cow",
+            20: "elephant",
+            21: "bear",
+            22: "zebra",
+            23: "giraffe",
+            24: "backpack",
+            25: "umbrella",
+            26: "handbag",
+            27: "tie",
+            28: "suitcase",
+            29: "frisbee",
+            30: "skis",
+            31: "snowboard",
+            32: "sports ball",
+            33: "kite",
+            34: "baseball bat",
+            35: "baseball glove",
+            36: "skateboard",
+            37: "surfboard",
+            38: "tennis racket",
+            39: "bottle",
+            40: "wine glass",
+            41: "cup",
+            42: "fork",
+            43: "knife",
+            44: "spoon",
+            45: "bowl",
+            46: "banana",
+            47: "apple",
+            48: "sandwich",
+            49: "orange",
+            50: "broccoli",
+            51: "carrot",
+            52: "hot dog",
+            53: "pizza",
+            54: "donut",
+            55: "cake",
+            56: "chair",
+            57: "couch",
+            58: "potted plant",
+            59: "bed",
+            60: "dining table",
+            61: "toilet",
+            62: "tv",
+            63: "laptop",
+            64: "mouse",
+            65: "remote",
+            66: "keyboard",
+            67: "cell phone",
+            68: "microwave",
+            69: "oven",
+            70: "toaster",
+            71: "sink",
+            72: "refrigerator",
+            73: "book",
+            74: "clock",
+            75: "vase",
+            76: "scissors",
+            77: "teddy bear",
+            78: "hair drier",
+            79: "toothbrush"
+        }
+    }
+    with open(path, 'w') as f:
+        yaml.dump(data, f)
+
+
+def process_directories(dataset_root_path:str):
+    """Create the directories and dataset YAML needed for training."""
+    make_dir(dataset_root_path, init=False)
+    make_dir(os.path.join(dataset_root_path, "train"), init=True)
+    make_dir(os.path.join(dataset_root_path, "val"), init=True)
+    make_yml(os.path.join(dataset_root_path, "dataset.yaml"))
+
+
+def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str):
+    """Save one sample's image and its YOLO detection label file.
+
+    The image is written to ``<dataset_root_path>/<child_path>/`` under its
+    original file name, next to a ``.txt`` label file with one
+    ``class x_center y_center width height`` line per shape. Box values are
+    divided by the image size because the YOLO label format expects
+    coordinates normalized to the 0-1 range.
+
+    NOTE(review): assumes ``shape.points`` holds two opposite box corners in
+    pixel coordinates -- confirm against the LabelData producer.
+    """
+    target_dir = os.path.join(dataset_root_path, child_path)
+
+    # Split the source file name into title and extension.
+    img_title, img_ext = os.path.splitext(os.path.basename(data.image_url))
+
+    # Save the image under its original name; the context manager closes the
+    # source file handle once the copy is written.
+    with Image.open(data.image_url) as img:
+        img_width, img_height = img.size
+        img.save(os.path.join(target_dir, img_title + img_ext))
+
+    # Convert each labelled shape into a YOLO detection label line.
+    with open(os.path.join(target_dir, f"{img_title}.txt"), "w") as train_label_txt:
+        for shape in data.label.shapes:
+            x1, y1 = shape.points[0][0], shape.points[0][1]
+            x2, y2 = shape.points[1][0], shape.points[1][1]
+            # Order the corners so width/height are never negative.
+            left, right = min(x1, x2), max(x1, x2)
+            top, bottom = min(y1, y2), max(y1, y2)
+            train_label = [
+                str(shape.group_id),                   # class id
+                str((left + right) / 2 / img_width),   # center x
+                str((top + bottom) / 2 / img_height),  # center y
+                str((right - left) / img_width),       # width
+                str((bottom - top) / img_height),      # height
+            ]
+            train_label_txt.write(" ".join(train_label) + "\n")
diff --git a/ai/environment.yml b/ai/environment.yml
index ccd77ce..19c564e 100644
--- a/ai/environment.yml
+++ b/ai/environment.yml
@@ -6,10 +6,13 @@ channels:
   - defaults
 dependencies:
   - python=3.10.10
-  - pytorch
-  - torchvision
-  - torchaudio
+  - pytorch=2.3.1
+  - torchvision=0.18.1
+  - torchaudio=2.3.1
   - pytorch-cuda=12.1
   - fastapi
   - uvicorn
-  - ultralytics
\ No newline at end of file
+  - ultralytics
+  - dill
+  - boto3
+  - python-dotenv
\ No newline at end of file