diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py
index 241bf09..beecdd1 100644
--- a/ai/app/api/yolo/detection.py
+++ b/ai/app/api/yolo/detection.py
@@ -1,3 +1,5 @@
+import json
+
 from fastapi import APIRouter, HTTPException
 from schemas.predict_request import PredictRequest
 from schemas.train_request import TrainRequest
@@ -6,105 +8,203 @@ from app.services.load_model import load_detection_model
 from utils.dataset_utils import split_data
 from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
 from typing import List
-from fastapi.responses import FileResponse
+from utils.websocket_utils import WebSocketClient
+import asyncio
+
 router = APIRouter()
+
+
 @router.post("/detection", response_model=List[PredictResponse])
-def predict(request: PredictRequest):
+async def predict(request: PredictRequest):
     version = "0.1.0"
 
-    # Load the model
-    try:
-        model = load_detection_model()
-    except Exception as e:
-        raise HTTPException(status_code=500, detail="load model exception: "+str(e))
+    # WebSocket URL of the Spring server
+    # TODO: change for deployment
+    spring_server_ws_url = "ws://localhost:8080/ws"
 
-    # Inference
-    results = []
-    try:
-        for image in request.image_list:
-            # Load the image from the URL into memory  TODO: decide later whether to keep it in memory or to parallelize
-            # response = requests.get(image.image_url)
+    # Create a WebSocketClient instance
+    ws_client = WebSocketClient(spring_server_ws_url)
 
-            # Load the image data into memory
-            # img = Image.open(io.BytesIO(response.content))
-
-            predict_results = model.predict(
-                source=image.image_url,
-                iou=request.iou_threshold,
-                conf=request.conf_threshold,
-                classes=request.classes
-            )
-            results.append(predict_results[0])
-
-            # Release the image object from memory
-            # img.close()
-            # del img
-    except Exception as e:
-        raise HTTPException(status_code=500, detail="model predict exception: "+str(e))
-
-    # Parse inference results into label objects
-    response = []
     try:
-        for (image, result) in zip(request.image_list, results):
-            label_data:LabelData = {
-                "version": version,
-                "task_type": "det",
-                "shapes": [
-                    {
-                        "label": summary['name'],
-                        "color": "#ff0000",
-                        "points": [
-                            [summary['box']['x1'], summary['box']['y1']],
-                            [summary['box']['x2'], summary['box']['y2']]
-                        ],
-                        "group_id": summary['class'],
-                        "shape_type": "rectangle",
-                        "flags": {}
-                    }
-                    for summary in result.summary()
-                ],
-                "split": "none",
-                "imageHeight": result.orig_img.shape[0],
-                "imageWidth": result.orig_img.shape[1],
-                "imageDepth": result.orig_img.shape[2]
-            }
-            response.append({
-                "image_id":image.image_id,
-                "image_url":image.image_url,
-                "data":label_data
+        await ws_client.connect()
+
+        # Load the model
+        try:
+            model = load_detection_model()
+        except Exception as e:
+            raise HTTPException(status_code=500, detail="load model exception: " + str(e))
+
+        # Inference
+        results = []
+        total_images = len(request.image_list)
+        for idx, image in enumerate(request.image_list):
+            try:
+                # Load the image from the URL  TODO: decide later whether to keep it in memory or to parallelize
+
+                predict_results = model.predict(
+                    source=image.image_url,
+                    iou=request.iou_threshold,
+                    conf=request.conf_threshold,
+                    classes=request.classes
+                )
+                # Process the prediction result
+                result = predict_results[0]
+                results.append(result)  # collected for the response parsing below
+                label_data = LabelData(
+                    version=version,
+                    task_type="det",
+                    shapes=[
+                        {
+                            "label": summary['name'],
+                            "color": "#ff0000",
+                            "points": [
+                                [summary['box']['x1'], summary['box']['y1']],
+                                [summary['box']['x2'], summary['box']['y2']]
+                            ],
+                            "group_id": summary['class'],
+                            "shape_type": "rectangle",
+                            "flags": {}
+                        }
+                        for summary in result.summary()
+                    ],
+                    split="none",
+                    imageHeight=result.orig_img.shape[0],
+                    imageWidth=result.orig_img.shape[1],
+                    imageDepth=result.orig_img.shape[2]
+                )
+
+                response_item = PredictResponse(
+                    image_id=image.image_id,
+                    image_url=image.image_url,
+                    data=label_data
+                )
+
+                # Compute progress
+                progress = (idx + 1) / total_images * 100
+
+                # Send the prediction result and progress over the WebSocket
+                message = {
+                    "project_id": request.project_id,
+                    "progress": progress,
+                    "result": response_item.dict()
+                }
+
+                await ws_client.send_message("/app/ai/predict/progress", json.dumps(message))
+
+            except Exception as e:
+                raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
+
+        # Parse inference results into label objects
+        response = []
+        try:
+            for (image, result) in zip(request.image_list, results):
+                label_data: LabelData = {
+                    "version": version,
+                    "task_type": "det",
+                    "shapes": [
+                        {
+                            "label": summary['name'],
+                            "color": "#ff0000",
+                            "points": [
+                                [summary['box']['x1'], summary['box']['y1']],
+                                [summary['box']['x2'], summary['box']['y2']]
+                            ],
+                            "group_id": summary['class'],
+                            "shape_type": "rectangle",
+                            "flags": {}
+                        }
+                        for summary in result.summary()
+                    ],
+                    "split": "none",
+                    "imageHeight": result.orig_img.shape[0],
+                    "imageWidth": result.orig_img.shape[1],
+                    "imageDepth": result.orig_img.shape[2]
+                }
+                response.append({
+                    "image_id": image.image_id,
+                    "image_url": image.image_url,
+                    "data": label_data
             })
+        except Exception as e:
+            raise HTTPException(status_code=500, detail="label parsing exception: " + str(e))
+
+        return response
+
     except Exception as e:
-        raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
-    return response
+        print(f"Prediction process failed: {str(e)}")
+        raise HTTPException(status_code=500, detail="Prediction process failed")
+
+    finally:
+        if ws_client.is_connected():
+            await ws_client.close()
 
 
 @router.post("/detection/train")
-def train(request: TrainRequest):
+async def train(request: TrainRequest):
     # Get the dataset root path
     dataset_root_path = get_dataset_root_path(request.project_id)
-    
+
     # Create and initialize the directories
     process_directories(dataset_root_path)
-    
+
     # Split the data into training and validation sets
     train_data, val_data = split_data(request.data, request.ratio, request.seed)
+
+    # WebSocket URL of the Spring server
+    # TODO: change for deployment
+    spring_server_ws_url = "ws://localhost:8080/ws"
 
-    # Process training data
-    for data in train_data:
-        process_image_and_label(data, dataset_root_path, "train")
+    # Create a WebSocketClient instance
+    ws_client = WebSocketClient(spring_server_ws_url)
 
-    # Process validation data
-    for data in val_data:
-        process_image_and_label(data, dataset_root_path, "val")
-    model = load_detection_model("test-data/model/best.pt")
+    try:
+        await ws_client.connect()
 
-    model.train(
-        data=join_path(dataset_root_path,"dataset.yaml"),
-        name=join_path(dataset_root_path,"result"),
-        epochs= request.epochs,
-        batch=request.batch,
+        # Process training data
+        total_data = len(train_data)
+        for idx, data in enumerate(train_data):
+            # TODO: await this call if it becomes asynchronous
+            process_image_and_label(data, dataset_root_path, "train")
+
+            # Compute progress
+            progress = (idx + 1) / total_data * 100
+
+            await ws_client.send_message("/app/ai/train/progress", f"Processing training data {request.project_id}: {progress:.2f}% complete")
+
+        # Process validation data
+        total_val_data = len(val_data)
+        for idx, data in enumerate(val_data):
+            # TODO: await this call if it becomes asynchronous
+            process_image_and_label(data, dataset_root_path, "val")
+
+            # Compute progress
+            progress = (idx + 1) / total_val_data * 100
+            # Report validation progress over the WebSocket
+            await ws_client.send_message("/app/ai/val/progress", f"Processing validation data {request.project_id}: {progress:.2f}% complete")
+
+        model = load_detection_model("test-data/model/best.pt")
+        model.train(
+            data=join_path(dataset_root_path, "dataset.yaml"),
+            name=join_path(dataset_root_path, "result"),
+            epochs=request.epochs,
+            batch=request.batch,
     )
+
+        # return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
+
+        return {"status": "Training completed successfully"}
+
+    except Exception as e:
+        print(f"Training process failed: {str(e)}")
+        raise HTTPException(status_code=500, detail="Training process failed")
+
+    finally:
+        if ws_client.is_connected():
+            await ws_client.close()
+
+
+
-    return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
\ No newline at end of file
diff --git a/ai/app/services/load_model.py b/ai/app/services/load_model.py
index 234948b..8693bd8 100644
--- a/ai/app/services/load_model.py
+++ b/ai/app/services/load_model.py
@@ -6,7 +6,7 @@ from ultralytics.nn.tasks import DetectionModel, SegmentationModel
 import os
 import torch
 
-def load_detection_model(model_path: str = os.path.join("test-data","model","initial.pt"), device:str ="auto"):
+def load_detection_model(model_path: str = os.path.join("test-data","model","yolov8n.pt"), device:str ="auto"):
     """
     Loads a YOLO model from the given path.
 
@@ -18,7 +18,7 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","ini
     Returns:
         YOLO: the loaded YOLO model instance
     """
-    
+
     if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
         raise FileNotFoundError(f"Model file not found at path: {model_path}")
 
@@ -26,7 +26,7 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","ini
     # Verify that the loaded model is a detection model
     if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)):
         raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.")
-    
+
     # Use the GPU when available
     if (device == "auto" and torch.cuda.is_available()):
         model.to("cuda")
diff --git a/ai/app/utils/websocket_utils.py b/ai/app/utils/websocket_utils.py
new file mode 100644
index 0000000..f2f0efc
--- /dev/null
+++ b/ai/app/utils/websocket_utils.py
@@ -0,0 +1,36 @@
+import websockets
+
+class WebSocketClient:
+    def __init__(self, url: str):
+        self.url = url
+        self.websocket = None
+
+    async def connect(self):
+        try:
+            self.websocket = await websockets.connect(self.url)
+            print(f"Connected to WebSocket at {self.url}")
+        except Exception as e:
+            print(f"Failed to connect to WebSocket: {str(e)}")
+
+    async def send_message(self, destination: str, message: str):
+        try:
+            if self.websocket is not None:
+                # Send the payload as a STOMP-style SEND frame
+                await self.websocket.send(f"SEND\ndestination:{destination}\n\n{message}\u0000")
+                print(f"Sent message to {destination}: {message}")
+            else:
+                print("WebSocket is not connected. Unable to send message.")
+        except Exception as e:
+            print(f"Failed to send message: {str(e)}")
+            return
+
+    async def close(self):
+        try:
+            if self.websocket is not None:
+                await self.websocket.close()
+                print("WebSocket connection closed.")
+        except Exception as e:
+            print(f"Failed to close WebSocket connection: {str(e)}")
+
+    def is_connected(self):
+        return self.websocket is not None and self.websocket.open
\ No newline at end of file
diff --git a/ai/environment.yml b/ai/environment.yml
index 785f16e..d84a04a 100644
--- a/ai/environment.yml
+++ b/ai/environment.yml
@@ -16,4 +16,5 @@ dependencies:
   - dill
   - boto3
   - python-dotenv
-  - locust
\ No newline at end of file
+  - locust
+  - websockets