From 4fc13fa60c28009cd7d98259866361ba5b32104d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Thu, 12 Sep 2024 00:53:45 +0900 Subject: [PATCH 1/2] =?UTF-8?q?Feat:=20Fast=20API=20=EC=86=8C=EC=BC=93=20?= =?UTF-8?q?=EC=97=B4=EA=B8=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 77 ++++++++++++++++++++--------- ai/app/handler/__init__.py | 0 ai/app/handler/websocket_handler.py | 35 +++++++++++++ ai/environment.yml | 3 +- 4 files changed, 91 insertions(+), 24 deletions(-) create mode 100644 ai/app/handler/__init__.py create mode 100644 ai/app/handler/websocket_handler.py diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index 82ac1ac..ad6fd0d 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -7,6 +7,8 @@ from utils.dataset_utils import split_data from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path from typing import List from fastapi.responses import FileResponse +from handler.websocket_handler import WebSocketClient +import asyncio router = APIRouter() @router.post("/detection", response_model=List[PredictResponse]) @@ -80,31 +82,60 @@ def predict(request: PredictRequest): @router.post("/detection/train") -def train(request: TrainRequest): - # 데이터셋 루트 경로 얻기 - dataset_root_path = get_dataset_root_path(request.project_id) +async def train(request: TrainRequest): + # Spring 서버의 WebSocket URL + # spring_server_ws_url = f"ws://localhost:8080/ws/ai/train/progress/{request.project_id}" + spring_server_ws_url = f"ws://localhost:8080/ws" - # 디렉토리 생성 및 초기화 - process_directories(dataset_root_path) - - # 학습 데이터 분류 - train_data, val_data = split_data(request.data, request.ratio, request.seed) - - # 학습 데이터 처리 - for data in train_data: - process_image_and_label(data, dataset_root_path, "train") + # WebSocketClient 인스턴스 생성 + print("연결 요청 - " + spring_server_ws_url) + ws_client = WebSocketClient(spring_server_ws_url) - # 검증 데이터 처리 - for data in val_data: - process_image_and_label(data, dataset_root_path, "val") + try: + await ws_client.connect() - model = load_detection_model("test-data/model/best.pt") + await ws_client.send_message("/app/ai/train/progress", f"Training started for project {request.project_id}") - model.train( - data=join_path(dataset_root_path,"dataset.yaml"), - name=join_path(dataset_root_path,"result"), - epochs= request.epochs, - batch=request.batch, - ) + for i in range(1, 31): + await ws_client.send_message("/app/ai/train/progress", f"Training progress: {i}/30") + await asyncio.sleep(1) + + await ws_client.send_message("/app/ai/train/progress", "Training complete") + + return {"status": "Training completed successfully"} + + except Exception as e: + logging.error(f"Training process failed: {str(e)}") + raise HTTPException(status_code=500, detail="Training process failed") + + finally: + await ws_client.close() - return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream") \ No newline at end of file + + # # 데이터셋 루트 경로 얻기 + # dataset_root_path = get_dataset_root_path(request.project_id) + + # # 디렉토리 생성 및 초기화 + # process_directories(dataset_root_path) + + # # 학습 데이터 분류 + # train_data, val_data = split_data(request.data, request.ratio, request.seed) + + # # 학습 데이터 처리 + # for data in train_data: + # process_image_and_label(data, dataset_root_path, "train") + + # # 검증 데이터 처리 + # for data in val_data: + # process_image_and_label(data, dataset_root_path, "val") + + # model = load_detection_model("test-data/model/best.pt") + + # model.train( + # data=join_path(dataset_root_path,"dataset.yaml"), + # name=join_path(dataset_root_path,"result"), + # epochs= request.epochs, + # batch=request.batch, + # ) + + # return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream") \ No newline at end of file diff --git a/ai/app/handler/__init__.py b/ai/app/handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ai/app/handler/websocket_handler.py b/ai/app/handler/websocket_handler.py new file mode 100644 index 0000000..20eb5a8 --- /dev/null +++ b/ai/app/handler/websocket_handler.py @@ -0,0 +1,35 @@ +import asyncio +import websockets +import logging + +class WebSocketClient: + def __init__(self, url: str): + self.url = url + self.websocket = None + + async def connect(self): + try: + self.websocket = await websockets.connect(self.url) + print(f"Connected to WebSocket at {self.url}") + except Exception as e: + print(f"Failed to connect to WebSocket: {str(e)}") + + async def send_message(self, destination: str, message: str): + try: + if self.websocket is not None: + # STOMP 형식의 메시지를 전송 + await self.websocket.send(f"SEND\ndestination:{destination}\n\n{message}\u0000") + print(f"Sent message to {destination}: {message}") + else: + print("WebSocket is not connected. Unable to send message.") + except Exception as e: + print(f"Failed to send message: {str(e)}") + return + + async def close(self): + try: + if self.websocket is not None: + await self.websocket.close() + print("WebSocket connection closed.") + except Exception as e: + print(f"Failed to close WebSocket connection: {str(e)}") \ No newline at end of file diff --git a/ai/environment.yml b/ai/environment.yml index 785f16e..d84a04a 100644 --- a/ai/environment.yml +++ b/ai/environment.yml @@ -16,4 +16,5 @@ dependencies: - dill - boto3 - python-dotenv - - locust \ No newline at end of file + - locust + - websockets From fea960cc0d59cf7434cbbe8ab5714f41d9bb46d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Thu, 12 Sep 2024 17:05:32 +0900 Subject: [PATCH 2/2] =?UTF-8?q?Feat:=20WebSocket=20=EC=9D=B8=EC=8A=A4?= =?UTF-8?q?=ED=84=B4=EC=8A=A4=20=EC=83=9D=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 281 +++++++++++------- ai/app/handler/__init__.py | 0 ai/app/services/ai_service.py | 6 +- .../websocket_utils.py} | 7 +- 4 files changed, 181 insertions(+), 113 deletions(-) delete mode 100644 ai/app/handler/__init__.py rename ai/app/{handler/websocket_handler.py => utils/websocket_utils.py} (92%) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index ad6fd0d..6a8a9fd 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -1,141 +1,208 @@ +import json + from fastapi import APIRouter, HTTPException from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest from schemas.predict_response import PredictResponse, LabelData from services.ai_service import load_detection_model -from utils.dataset_utils import split_data -from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path from typing import List -from fastapi.responses import FileResponse -from handler.websocket_handler import WebSocketClient +from utils.websocket_utils import WebSocketClient import asyncio + router = APIRouter() + + @router.post("/detection", response_model=List[PredictResponse]) -def predict(request: PredictRequest): +async def predict(request: PredictRequest): version = "0.1.0" + print("여기") - # 모델 로드 - try: - model = load_detection_model() - except Exception as e: - raise HTTPException(status_code=500, detail="load model exception: "+str(e)) - - # 추론 - results = [] - try: - for image in request.image_list: - # URL에서 이미지를 메모리로 로드 TODO: 추후 메모리에 할지 어떻게 해야할지 or 병렬 처리 고민 - # response = requests.get(image.image_url) - - # 이미지 데이터를 메모리로 로드 - # img = Image.open(io.BytesIO(response.content)) - - predict_results = model.predict( - source=image.image_url, - iou=request.iou_threshold, - conf=request.conf_threshold, - classes=request.classes - ) - results.append(predict_results[0]) - - # 메모리에서 이미지 객체 해제 - # img.close() - # del img - except Exception as e: - raise HTTPException(status_code=500, detail="model predict exception: "+str(e)) - - # 추론 결과 -> 레이블 객체 파싱 - response = [] - try: - for (image, result) in zip(request.image_list, results): - label_data:LabelData = { - "version": version, - "task_type": "det", - "shapes": [ - { - "label": summary['name'], - "color": "#ff0000", - "points": [ - [summary['box']['x1'], summary['box']['y1']], - [summary['box']['x2'], summary['box']['y2']] - ], - "group_id": summary['class'], - "shape_type": "rectangle", - "flags": {} - } - for summary in result.summary() - ], - "split": "none", - "imageHeight": result.orig_img.shape[0], - "imageWidth": result.orig_img.shape[1], - "imageDepth": result.orig_img.shape[2] - } - response.append({ - "image_id":image.image_id, - "image_url":image.image_url, - "data":label_data - }) - except Exception as e: - raise HTTPException(status_code=500, detail="label parsing exception: "+str(e)) - return response - - -@router.post("/detection/train") -async def train(request: TrainRequest): # Spring 서버의 WebSocket URL - # spring_server_ws_url = f"ws://localhost:8080/ws/ai/train/progress/{request.project_id}" + # TODO: 배포 시 변경 spring_server_ws_url = f"ws://localhost:8080/ws" + print("여기") # WebSocketClient 인스턴스 생성 - print("연결 요청 - " + spring_server_ws_url) ws_client = WebSocketClient(spring_server_ws_url) try: await ws_client.connect() - await ws_client.send_message("/app/ai/train/progress", f"Training started for project {request.project_id}") + # 모델 로드 + try: + model = load_detection_model() + except Exception as e: + raise HTTPException(status_code=500, detail="load model exception: " + str(e)) - for i in range(1, 31): - await ws_client.send_message("/app/ai/train/progress", f"Training progress: {i}/30") - await asyncio.sleep(1) + # 추론 + results = [] + total_images = len(request.image_list) + for idx, image in enumerate(request.image_list): + try: + # URL에서 이미지를 메모리로 로드 TODO: 추후 메모리에 할지 어떻게 해야할지 or 병렬 처리 고민 - await ws_client.send_message("/app/ai/train/progress", "Training complete") + predict_results = model.predict( + source=image.image_url, + iou=request.iou_threshold, + conf=request.conf_threshold, + classes=request.classes + ) + # 예측 결과 처리 + result = predict_results[0] + label_data = LabelData( + version=version, + task_type="det", + shapes=[ + { + "label": summary['name'], + "color": "#ff0000", + "points": [ + [summary['box']['x1'], summary['box']['y1']], + [summary['box']['x2'], summary['box']['y2']] + ], + "group_id": summary['class'], + "shape_type": "rectangle", + "flags": {} + } + for summary in result.summary() + ], + split="none", + imageHeight=result.orig_img.shape[0], + imageWidth=result.orig_img.shape[1], + imageDepth=result.orig_img.shape[2] + ) + + response_item = PredictResponse( + image_id=image.image_id, + image_url=image.image_url, + data=label_data + ) + + # 진행률 계산 + progress = (idx + 1) / total_images * 100 + + # 웹소켓으로 예측 결과와 진행률 전송 + message = { + "project_id": request.project_id, + "progress": progress, + "result": response_item.dict() + } + + await ws_client.send_message("/app/ai/predict/progress", json.dumps(message)) + + except Exception as e: + raise HTTPException(status_code=500, detail="model predict exception: " + str(e)) + + # 추론 결과 -> 레이블 객체 파싱 + response = [] + try: + for (image, result) in zip(request.image_list, results): + label_data: LabelData = { + "version": version, + "task_type": "det", + "shapes": [ + { + "label": summary['name'], + "color": "#ff0000", + "points": [ + [summary['box']['x1'], summary['box']['y1']], + [summary['box']['x2'], summary['box']['y2']] + ], + "group_id": summary['class'], + "shape_type": "rectangle", + "flags": {} + } + for summary in result.summary() + ], + "split": "none", + "imageHeight": result.orig_img.shape[0], + "imageWidth": result.orig_img.shape[1], + "imageDepth": result.orig_img.shape[2] + } + response.append({ + "image_id": image.image_id, + "image_url": image.image_url, + "data": label_data + }) + except Exception as e: + raise HTTPException(status_code=500, detail="label parsing exception: " + str(e)) + + return response + + except Exception as e: + print(f"Prediction process failed: {str(e)}") + raise HTTPException(status_code=500, detail="Prediction process failed") + + finally: + if ws_client.is_connected(): + await ws_client.close() + + +@router.post("/detection/train") +async def train(request: TrainRequest): + # 데이터셋 루트 경로 얻기 + dataset_root_path = get_dataset_root_path(request.project_id) + + # 디렉토리 생성 및 초기화 + process_directories(dataset_root_path) + + # 학습 데이터 분류 + train_data, val_data = split_data(request.data, request.ratio, request.seed) + + # Spring 서버의 WebSocket URL + # TODO: 배포시에 변경 + spring_server_ws_url = f"ws://localhost:8080/ws" + + # WebSocketClient 인스턴스 생성 + ws_client = WebSocketClient(spring_server_ws_url) + + + try: + await ws_client.connect() + + # 학습 데이터 처리 + total_data = len(train_data) + for idx, data in enumerate(train_data): + # TODO: 비동기면 await 연결 + # process_image_and_label(data, dataset_root_path, "train") + + # 진행률 계산 + progress = (idx + 1) / total_data * 100 + + await ws_client.send_message("/app/ai/train/progress", f"학습 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료") + + # 검증 데이터 처리 + total_val_data = len(val_data) + for idx, data in enumerate(val_data): + # TODO: 비동기면 await 연결 + # process_image_and_label(data, dataset_root_path, "val") + + # 진행률 계산 + progress = (idx + 1) / total_val_data * 100 + # 웹소켓으로 메시지 전송 (필요할 경우 추가) + await ws_client.send_message("/app/ai/val/progress", f"검증 데이터 처리 중 {request.project_id}: {progress:.2f}% 완료") + + model = load_detection_model("test-data/model/best.pt") + model.train( + data=join_path(dataset_root_path, "dataset.yaml"), + name=join_path(dataset_root_path, "result"), + epochs=request.epochs, + batch=request.batch, + ) + + # return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream") return {"status": "Training completed successfully"} except Exception as e: - logging.error(f"Training process failed: {str(e)}") + print(f"Training process failed: {str(e)}") raise HTTPException(status_code=500, detail="Training process failed") finally: - await ws_client.close() + if ws_client.is_connected(): + await ws_client.close() - # # 데이터셋 루트 경로 얻기 - # dataset_root_path = get_dataset_root_path(request.project_id) - - # # 디렉토리 생성 및 초기화 - # process_directories(dataset_root_path) - - # # 학습 데이터 분류 - # train_data, val_data = split_data(request.data, request.ratio, request.seed) - - # # 학습 데이터 처리 - # for data in train_data: - # process_image_and_label(data, dataset_root_path, "train") - # # 검증 데이터 처리 - # for data in val_data: - # process_image_and_label(data, dataset_root_path, "val") - - # model = load_detection_model("test-data/model/best.pt") - - # model.train( - # data=join_path(dataset_root_path,"dataset.yaml"), - # name=join_path(dataset_root_path,"result"), - # epochs= request.epochs, - # batch=request.batch, - # ) - # return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream") \ No newline at end of file diff --git a/ai/app/handler/__init__.py b/ai/app/handler/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/ai/app/services/ai_service.py b/ai/app/services/ai_service.py index 234948b..8693bd8 100644 --- a/ai/app/services/ai_service.py +++ b/ai/app/services/ai_service.py @@ -6,7 +6,7 @@ from ultralytics.nn.tasks import DetectionModel, SegmentationModel import os import torch -def load_detection_model(model_path: str = os.path.join("test-data","model","initial.pt"), device:str ="auto"): +def load_detection_model(model_path: str = os.path.join("test-data","model","yolov8n.pt"), device:str ="auto"): """ 지정된 경로에서 YOLO 모델을 로드합니다. @@ -18,7 +18,7 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","ini Returns: YOLO: 로드된 YOLO 모델 인스턴스 """ - + if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt": raise FileNotFoundError(f"Model file not found at path: {model_path}") @@ -26,7 +26,7 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","ini # Detection 모델인지 검증 if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)): raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.") - + # gpu 이용 if (device == "auto" and torch.cuda.is_available()): model.to("cuda") diff --git a/ai/app/handler/websocket_handler.py b/ai/app/utils/websocket_utils.py similarity index 92% rename from ai/app/handler/websocket_handler.py rename to ai/app/utils/websocket_utils.py index 20eb5a8..f2f0efc 100644 --- a/ai/app/handler/websocket_handler.py +++ b/ai/app/utils/websocket_utils.py @@ -1,6 +1,4 @@ -import asyncio import websockets -import logging class WebSocketClient: def __init__(self, url: str): @@ -32,4 +30,7 @@ class WebSocketClient: await self.websocket.close() print("WebSocket connection closed.") except Exception as e: - print(f"Failed to close WebSocket connection: {str(e)}") \ No newline at end of file + print(f"Failed to close WebSocket connection: {str(e)}") + + def is_connected(self): + return self.websocket is not None and self.websocket.open \ No newline at end of file