diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index bb9234c..870c61d 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -5,6 +5,7 @@ from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest from schemas.predict_response import PredictResponse, LabelData from services.load_model import load_detection_model +from services.create_model import save_model from utils.dataset_utils import split_data from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path from utils.websocket_utils import WebSocketClient, WebSocketConnectionException @@ -159,15 +160,6 @@ async def detection_predict(request: PredictRequest): @router.post("/train") async def detection_train(request: TrainRequest): - # 데이터셋 루트 경로 얻기 - dataset_root_path = get_dataset_root_path(request.project_id) - - # 디렉토리 생성 및 초기화 - process_directories(dataset_root_path) - - # 학습 데이터 분류 - train_data, val_data = split_data(request.data, request.ratio, request.seed) - # Spring 서버의 WebSocket URL # TODO: 배포시에 변경 spring_server_ws_url = f"ws://localhost:8080/ws" @@ -175,9 +167,34 @@ async def detection_train(request: TrainRequest): # WebSocketClient 인스턴스 생성 ws_client = WebSocketClient(spring_server_ws_url) + # 데이터셋 루트 경로 얻기 + dataset_root_path = get_dataset_root_path(request.project_id) + + # 모델 로드 + try: + model_path = request.m_key and get_model_path(request.project_id, request.m_key) + model = load_detection_model(model_path=model_path) + except Exception as e: + raise HTTPException(status_code=500, detail="load model exception: " + str(e)) + + # 학습할 모델 카테고리 정리 카테고리가 추가되는 경우에 추가할 수 있게 + names = model.names + + # 디렉토리 생성 및 초기화 + process_directories(dataset_root_path, names) + + # 레이블 맵 + inverted_label_map = None + if request.label_map: + inverted_label_map = {value: key for key, value in request.label_map.items()} + + # 학습 데이터 분류 + train_data, val_data = split_data(request.data, request.ratio, request.seed) try: await ws_client.connect() + if not ws_client.is_connected(): + raise WebSocketConnectionException() # 학습 데이터 처리 total_data = len(train_data) @@ -208,10 +225,34 @@ async def detection_train(request: TrainRequest): epochs=request.epochs, batch=request.batch, ) - # return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream") - return {"status": "Training completed successfully"} + + except WebSocketConnectionException as e: + + # 학습 데이터 처리 + total_data = len(train_data) + for idx, data in enumerate(train_data): + # TODO: 비동기면 await 연결 + process_image_and_label(data, dataset_root_path, "train", inverted_label_map) + + # 검증 데이터 처리 + total_val_data = len(val_data) + for idx, data in enumerate(val_data): + # TODO: 비동기면 await 연결 + process_image_and_label(data, dataset_root_path, "val", inverted_label_map) + + results = model.train( + data=join_path(dataset_root_path, "dataset.yaml"), + name=join_path(dataset_root_path, "result"), + epochs=request.epochs, + batch=request.batch, + ) + + model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "last.pt")) + + return {"model_key": model_key, "results": results.results_dict} + except Exception as e: print(f"Training process failed: {str(e)}") diff --git a/ai/app/schemas/train_request.py b/ai/app/schemas/train_request.py index 97ad1f7..3ec5015 100644 --- a/ai/app/schemas/train_request.py +++ b/ai/app/schemas/train_request.py @@ -8,9 +8,10 @@ class TrainDataInfo(BaseModel): class TrainRequest(BaseModel): project_id: int + m_key: Optional[str] = Field(None, alias="model_key") + label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습") data: List[TrainDataInfo] seed: Optional[int] = None # 랜덤 변수 시드 ratio: float = 0.8 # 훈련/검증 분할 비율 epochs: int = 50 # 훈련 반복 횟수 batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지 - path: Optional[str] = Field(None, alias="model_path") diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py index f1f0f3f..9ecb851 100644 --- a/ai/app/utils/file_utils.py +++ b/ai/app/utils/file_utils.py @@ -6,7 +6,7 @@ from schemas.train_request import TrainDataInfo def get_dataset_root_path(project_id): """데이터셋 루트 절대 경로 반환""" - return os.path.join(os.getcwd(), 'datasets', 'train') + return os.path.join(os.getcwd(), 'resources', 'projects', str(project_id), "train") def make_dir(path:str, init: bool): """ @@ -17,108 +17,26 @@ def make_dir(path:str, init: bool): shutil.rmtree(path) os.makedirs(path, exist_ok=True) -def make_yml(path:str): +def make_yml(path:str, names): data = { "train": f"{path}/train", "val": f"{path}/val", "nc": 80, - "names": - { - 0: "person", - 1: "bicycle", - 2: "car", - 3: "motorcycle", - 4: "airplane", - 5: "bus", - 6: "train", - 7: "truck", - 8: "boat", - 9: "traffic light", - 10: "fire hydrant", - 11: "stop sign", - 12: "parking meter", - 13: "bench", - 14: "bird", - 15: "cat", - 16: "dog", - 17: "horse", - 18: "sheep", - 19: "cow", - 20: "elephant", - 21: "bear", - 22: "zebra", - 23: "giraffe", - 24: "backpack", - 25: "umbrella", - 26: "handbag", - 27: "tie", - 28: "suitcase", - 29: "frisbee", - 30: "skis", - 31: "snowboard", - 32: "sports ball", - 33: "kite", - 34: "baseball bat", - 35: "baseball glove", - 36: "skateboard", - 37: "surfboard", - 38: "tennis racket", - 39: "bottle", - 40: "wine glass", - 41: "cup", - 42: "fork", - 43: "knife", - 44: "spoon", - 45: "bowl", - 46: "banana", - 47: "apple", - 48: "sandwich", - 49: "orange", - 50: "broccoli", - 51: "carrot", - 52: "hot dog", - 53: "pizza", - 54: "donut", - 55: "cake", - 56: "chair", - 57: "couch", - 58: "potted plant", - 59: "bed", - 60: "dining table", - 61: "toilet", - 62: "tv", - 63: "laptop", - 64: "mouse", - 65: "remote", - 66: "keyboard", - 67: "cell phone", - 68: "microwave", - 69: "oven", - 70: "toaster", - 71: "sink", - 72: "refrigerator", - 73: "book", - 74: "clock", - 75: "vase", - 76: "scissors", - 77: "teddy bear", - 78: "hair drier", - 79: "toothbrush" - } + "names": names } with open(os.path.join(path, "dataset.yaml"), 'w') as f: yaml.dump(data, f) -def process_directories(dataset_root_path:str): +def process_directories(dataset_root_path:str, names:list[str]): """학습을 위한 디렉토리 생성""" make_dir(dataset_root_path, init=False) make_dir(os.path.join(dataset_root_path, "train"), init=True) make_dir(os.path.join(dataset_root_path, "val"), init=True) if os.path.exists(os.path.join(dataset_root_path, "result")): shutil.rmtree(os.path.join(dataset_root_path, "result")) - make_yml(dataset_root_path) + make_yml(dataset_root_path, names) -def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str): +def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str, label_map:dict[int, int]|None): """이미지 저장 및 레이블 파일 생성""" # 이미지 저장 @@ -139,7 +57,7 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat y1 = shape.points[0][1] x2 = shape.points[1][0] y2 = shape.points[1][1] - train_label.append(str(shape.group_id)) # label Id + train_label.append(str(label_map[shape.group_id]) if label_map else str(shape.group_id)) # label Id train_label.append(str((x1 + x2) / 2 / label.imageWidth)) # 중심 x 좌표 train_label.append(str((y1 + y2) / 2 / label.imageHeight)) # 중심 y 좌표 train_label.append(str((x2 - x1) / label.imageWidth)) # 너비