From 1196476286990eb7604c0fb52174351c9b6c473c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Thu, 5 Sep 2024 11:02:52 +0900 Subject: [PATCH] =?UTF-8?q?Feat:=20Segmentation=20=EC=98=A4=ED=86=A0=20?= =?UTF-8?q?=EB=A0=88=EC=9D=B4=EB=B8=94=EB=A7=81=20API=20=EA=B5=AC=ED=98=84?= =?UTF-8?q?=20-=20S11P21S002-118?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/segmentation.py | 63 +++++++++++++++++++++++++++++++++ ai/app/main.py | 2 ++ ai/app/services/ai_service.py | 53 +++++++++++++++++---------- 3 files changed, 99 insertions(+), 19 deletions(-) create mode 100644 ai/app/api/yolo/segmentation.py diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py new file mode 100644 index 0000000..7c801cf --- /dev/null +++ b/ai/app/api/yolo/segmentation.py @@ -0,0 +1,63 @@ +from fastapi import APIRouter, HTTPException +from schemas.predict_request import PredictRequest +from schemas.predict_response import PredictResponse, LabelData +from services.ai_service import load_segmentation_model +from typing import List + +router = APIRouter() + +@router.post("/segmentation", response_model=List[PredictResponse]) +def predict(request: PredictRequest): + version = "0.1.0" + + # 모델 로드 + try: + model = load_segmentation_model() + except Exception as e: + raise HTTPException(status_code=500, detail="load model exception: "+str(e)) + + # 추론 + results = [] + try: + for image in request.image_list: + predict_results = model.predict( + source=image.image_url, + iou=request.iou_threshold, + conf=request.conf_threshold, + classes=request.classes + ) + results.append(predict_results[0]) + except Exception as e: + raise HTTPException(status_code=500, detail="model predict exception: "+str(e)) + + # 추론 결과 -> 레이블 객체 파싱 + response = [] + try: + for (image, result) in zip(request.image_list, results): + label_data:LabelData = { + "version": version, + "task_type": "seg", + "shapes": [ + { + "label": summary['name'], + "color": "#ff0000", + "points": list(zip(summary['segments']['x'], summary['segments']['y'])), + "group_id": summary['class'], + "shape_type": "polygon", + "flags": {} + } + for summary in result.summary() + ], + "split": "none", + "imageHeight": result.orig_img.shape[0], + "imageWidth": result.orig_img.shape[1], + "imageDepth": result.orig_img.shape[2] + } + response.append({ + "image_id":image.image_id, + "image_url":image.image_url, + "data":label_data + }) + except Exception as e: + raise HTTPException(status_code=500, detail="label parsing exception: "+str(e)) + return response \ No newline at end of file diff --git a/ai/app/main.py b/ai/app/main.py index 04321cc..df7d1c8 100644 --- a/ai/app/main.py +++ b/ai/app/main.py @@ -1,10 +1,12 @@ from fastapi import FastAPI from api.yolo.detection import router as yolo_detection_router +from api.yolo.segmentation import router as yolo_segmentation_router app = FastAPI() # 각 기능별 라우터를 애플리케이션에 등록 app.include_router(yolo_detection_router, prefix="/api") +app.include_router(yolo_segmentation_router, prefix="/api") # 애플리케이션 실행 if __name__ == "__main__": diff --git a/ai/app/services/ai_service.py b/ai/app/services/ai_service.py index 133c330..790f61d 100644 --- a/ai/app/services/ai_service.py +++ b/ai/app/services/ai_service.py @@ -2,7 +2,7 @@ from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기 from ultralytics.models.yolo.model import YOLO as YOLO_Model -from ultralytics.nn.tasks import DetectionModel +from ultralytics.nn.tasks import DetectionModel, SegmentationModel import os import torch @@ -22,21 +22,36 @@ def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device: if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt": raise FileNotFoundError(f"Model file not found at path: {model_path}") - try: - model = YOLO(model_path) - # Detection 모델인지 검증 - if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)): - raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.") - - # gpu 이용 - if (device == "auto" and torch.cuda.is_available()): - model.to("cuda") - print('gpu 가속 활성화') - elif (device == "auto"): - model.to("cpu") - else: - model.to(device) - return model - except Exception as e: - raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}") - \ No newline at end of file + model = YOLO(model_path) + # Detection 모델인지 검증 + if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)): + raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.") + + # gpu 이용 + if (device == "auto" and torch.cuda.is_available()): + model.to("cuda") + print('gpu 가속 활성화') + elif (device == "auto"): + model.to("cpu") + else: + model.to(device) + return model + +def load_segmentation_model(model_path: str = "test-data/model/yolov8n-seg.pt", device:str ="auto"): + if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n-seg.pt": + raise FileNotFoundError(f"Model file not found at path: {model_path}") + + model = YOLO(model_path) + # Segmentation 모델인지 검증 + if not (isinstance(model, YOLO_Model) and isinstance(model.model, SegmentationModel)): + raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a SegmentationModel.") + + # gpu 이용 + if (device == "auto" and torch.cuda.is_available()): + model.to("cuda") + print('gpu 가속 활성화') + elif (device == "auto"): + model.to("cpu") + else: + model.to(device) + return model \ No newline at end of file