Refactor: services/load_model 리팩토링
This commit is contained in:
parent
bcae5d5255
commit
47288e6bab
@ -5,8 +5,9 @@ from ultralytics.models.yolo.model import YOLO as YOLO_Model
|
||||
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
||||
import os
|
||||
import torch
|
||||
import re
|
||||
|
||||
def load_detection_model(model_path: str = os.path.join("test-data","model","yolov8n.pt"), device:str ="auto"):
|
||||
def load_detection_model(model_path:str):
|
||||
"""
|
||||
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||
|
||||
@ -19,44 +20,32 @@ def load_detection_model(model_path: str = os.path.join("test-data","model","yol
|
||||
YOLO: 로드된 YOLO 모델 인스턴스
|
||||
"""
|
||||
|
||||
if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n.pt":
|
||||
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||
|
||||
model = YOLO(model_path)
|
||||
# Detection 모델인지 검증
|
||||
if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)):
|
||||
raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.")
|
||||
|
||||
# gpu 이용
|
||||
if (device == "auto" and torch.cuda.is_available()):
|
||||
model.to("cuda")
|
||||
print('gpu 가속 활성화')
|
||||
elif (device == "auto"):
|
||||
model.to("cpu")
|
||||
if model_path:
|
||||
model = load_model(model_path)
|
||||
else:
|
||||
model.to(device)
|
||||
model = YOLO(os.path.join("resources","models","yolov8n.pt"))
|
||||
# Detection 모델인지 검증
|
||||
if model.task != "detect":
|
||||
raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.")
|
||||
return model
|
||||
|
||||
def load_segmentation_model(model_path: str = "test-data/model/yolov8n-seg.pt", device:str ="auto"):
|
||||
if not os.path.exists(model_path) and model_path != "test-data/model/yolov8n-seg.pt":
|
||||
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||
|
||||
model = YOLO(model_path)
|
||||
# Segmentation 모델인지 검증
|
||||
if not (isinstance(model, YOLO_Model) and isinstance(model.model, SegmentationModel)):
|
||||
raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a SegmentationModel.")
|
||||
|
||||
# gpu 이용
|
||||
if (device == "auto" and torch.cuda.is_available()):
|
||||
model.to("cuda")
|
||||
print('gpu 가속 활성화')
|
||||
elif (device == "auto"):
|
||||
model.to("cpu")
|
||||
def load_segmentation_model(model_path: str):
|
||||
if model_path:
|
||||
model = YOLO(model_path)
|
||||
else:
|
||||
model.to(device)
|
||||
model = YOLO(os.path.join("resources","models","yolov8n-seg.pt"))
|
||||
|
||||
# Segmentation 모델인지 검증
|
||||
if model.task != "segment":
|
||||
raise TypeError(f"Invalid model type: {model.task}. Expected a SegmentationModel.")
|
||||
return model
|
||||
|
||||
def load_model(model_path: str):
|
||||
# model_path 검증
|
||||
pattern = r'^resources[/\\]projects[/\\](\d+)[/\\]models[/\\]([a-f0-9\-]+)\.pt$'
|
||||
if not re.match(pattern, model_path):
|
||||
raise Exception("Invalid path format")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"Model file not found at path: {model_path}")
|
||||
|
||||
@ -64,9 +53,9 @@ def load_model(model_path: str):
|
||||
model = YOLO(model_path)
|
||||
if (torch.cuda.is_available()):
|
||||
model.to("cuda")
|
||||
print("gpu 사용")
|
||||
print("gpu 활성화")
|
||||
else:
|
||||
model.to("cpu")
|
||||
return model
|
||||
except:
|
||||
raise Exception("YOLO model conversion failed: Unsupported architecture or invalid configuration.")
|
||||
raise Exception("YOLO model conversion failed: Unsupported architecture or invalid configuration.")
|
||||
|
Loading…
Reference in New Issue
Block a user