Refactor: services/load_model 리팩토링

This commit is contained in:
김진현 2024-09-19 16:48:19 +09:00
parent bcae5d5255
commit 47288e6bab

View File

@ -5,8 +5,9 @@ from ultralytics.models.yolo.model import YOLO as YOLO_Model
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
import os
import torch
import re
def load_detection_model(model_path: str = os.path.join("test-data","model","yolov8n.pt"), device:str ="auto"):
def load_detection_model(model_path:str):
"""
지정된 경로에서 YOLO 모델을 로드합니다.
@ -19,44 +20,32 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","yol
YOLO: 로드된 YOLO 모델 인스턴스
"""
if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
raise FileNotFoundError(f"Model file not found at path: {model_path}")
model = YOLO(model_path)
# Detection 모델인지 검증
if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)):
raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.")
# gpu 이용
if (device == "auto" and torch.cuda.is_available()):
model.to("cuda")
print('gpu 가속 활성화')
elif (device == "auto"):
model.to("cpu")
if model_path:
model = load_model(model_path)
else:
model.to(device)
model = YOLO(os.path.join("resources","models","yolov8n.pt"))
# Detection 모델인지 검증
if model.task != "detect":
raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.")
return model
def load_segmentation_model(model_path: str = "test-data/model/yolov8n-seg.pt", device:str ="auto"):
if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n-seg.pt":
raise FileNotFoundError(f"Model file not found at path: {model_path}")
model = YOLO(model_path)
# Segmentation 모델인지 검증
if not (isinstance(model, YOLO_Model) and isinstance(model.model, SegmentationModel)):
raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a SegmentationModel.")
# gpu 이용
if (device == "auto" and torch.cuda.is_available()):
model.to("cuda")
print('gpu 가속 활성화')
elif (device == "auto"):
model.to("cpu")
def load_segmentation_model(model_path: str):
if model_path:
model = YOLO(model_path)
else:
model.to(device)
model = YOLO(os.path.join("resources","models","yolov8n-seg.pt"))
# Segmentation 모델인지 검증
if model.task != "segment":
raise TypeError(f"Invalid model type: {model.task}. Expected a SegmentationModel.")
return model
def load_model(model_path: str):
# model_path 검증
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
if not re.match(pattern, model_path):
raise Exception("Invalid path format")
if not os.path.exists(model_path):
raise FileNotFoundError(f"Model file not found at path: {model_path}")
@ -64,9 +53,9 @@ def load_model(model_path: str):
model = YOLO(model_path)
if (torch.cuda.is_available()):
model.to("cuda")
print("gpu 사용")
print("gpu 활성화")
else:
model.to("cpu")
return model
except:
raise Exception("YOLO model conversion failed: Unsupported architecture or invalid configuration.")
raise Exception("YOLO model conversion failed: Unsupported architecture or invalid configuration.")