From 47288e6bab6317c3af28e0883e9833da727df4ed Mon Sep 17 00:00:00 2001
From: 김진현
Date: Thu, 19 Sep 2024 16:48:19 +0900
Subject: [PATCH] Refactor: services/load_model refactoring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ai/app/services/load_model.py | 57 ++++++++++++++---------------------
 1 file changed, 23 insertions(+), 34 deletions(-)

diff --git a/ai/app/services/load_model.py b/ai/app/services/load_model.py
index 3359ba2..fed0d27 100644
--- a/ai/app/services/load_model.py
+++ b/ai/app/services/load_model.py
@@ -5,8 +5,9 @@ from ultralytics.models.yolo.model import YOLO as YOLO_Model
 from ultralytics.nn.tasks import DetectionModel, SegmentationModel
 import os
 import torch
+import re
 
-def load_detection_model(model_path: str = os.path.join("test-data","model","yolov8n.pt"), device:str ="auto"):
+def load_detection_model(model_path:str):
     """
     Load a YOLO model from the given path.
 
@@ -19,44 +20,32 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","yol
         YOLO: the loaded YOLO model instance
     """
 
-    if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
-        raise FileNotFoundError(f"Model file not found at path: {model_path}")
-
-    model = YOLO(model_path)
-    # Verify the model is a detection model
-    if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)):
-        raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.")
-
-    # Use GPU
-    if (device == "auto" and torch.cuda.is_available()):
-        model.to("cuda")
-        print('gpu acceleration enabled')
-    elif (device == "auto"):
-        model.to("cpu")
+    if model_path:
+        model = load_model(model_path)
     else:
-        model.to(device)
+        model = YOLO(os.path.join("resources","models","yolov8n.pt"))
 
+    # Verify the model is a detection model
+    if model.task != "detect":
+        raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.")
     return model
 
-def load_segmentation_model(model_path: str = "test-data/model/yolov8n-seg.pt", device:str ="auto"):
-    if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n-seg.pt":
-        raise FileNotFoundError(f"Model file not found at path: {model_path}")
-
-    model = YOLO(model_path)
-    # Verify the model is a segmentation model
-    if not (isinstance(model, YOLO_Model) and isinstance(model.model, SegmentationModel)):
-        raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a SegmentationModel.")
-
-    # Use GPU
-    if (device == "auto" and torch.cuda.is_available()):
-        model.to("cuda")
-        print('gpu acceleration enabled')
-    elif (device == "auto"):
-        model.to("cpu")
+def load_segmentation_model(model_path: str):
+    if model_path:
+        model = YOLO(model_path)
     else:
-        model.to(device)
+        model = YOLO(os.path.join("resources","models","yolov8n-seg.pt"))
+
+    # Verify the model is a segmentation model
+    if model.task != "segment":
+        raise TypeError(f"Invalid model type: {model.task}. Expected a SegmentationModel.")
     return model
 
 def load_model(model_path: str):
+    # Validate model_path
+    pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
+    if not re.match(pattern, model_path):
+        raise Exception("Invalid path format")
+
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file not found at path: {model_path}")
@@ -64,9 +53,9 @@ def load_model(model_path: str):
         model = YOLO(model_path)
         if (torch.cuda.is_available()):
             model.to("cuda")
-            print("using gpu")
+            print("gpu enabled")
         else:
             model.to("cpu")
         return model
     except:
-        raise Exception("YOLO model conversion failed: Unsupported architecture or invalid configuration.")
\ No newline at end of file
+        raise Exception("YOLO model conversion failed: Unsupported architecture or invalid configuration.")
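
Usage sketch of the refactored loaders, for reference. Assumptions: the module is importable as app.services.load_model, and the project ID and UUID-style filename below are hypothetical placeholders that merely satisfy the path pattern enforced by load_model(); the actual paths come from wherever the service stores project models.

    from app.services.load_model import load_detection_model, load_segmentation_model

    # An empty path falls back to the bundled default weights under resources/models/.
    detector = load_detection_model("")       # loads resources/models/yolov8n.pt
    segmenter = load_segmentation_model("")   # loads resources/models/yolov8n-seg.pt

    # A project-specific model must match resources/projects/<id>/models/<uuid>.pt,
    # otherwise load_model() raises "Invalid path format".
    # (hypothetical project ID and filename, and the file must exist on disk)
    detector = load_detection_model(
        "resources/projects/3/models/0f8fad5b-d9cb-469f-a165-70867728950e.pt"
    )
    print(detector.task)  # "detect"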