worlabel/ai/app/services/load_model.py

# load_model.py
# Helpers for loading YOLO detection / segmentation / classification models,
# either from the bundled defaults or from a project's own model directory.
from ultralytics import YOLO  # Import the Ultralytics YOLO model class
import os
import torch
def load_detection_model(project_id: int, model_key: str):
    default_model_map = {"yolo8": os.path.join("resources", "models", "yolov8n.pt")}
    # Use the bundled default weights when the key matches a known default model
    if model_key in default_model_map:
        model = YOLO(default_model_map[model_key])
    else:
        # Otherwise load the project-specific checkpoint from disk
        model = load_model(model_path=os.path.join("resources", "projects", str(project_id), "models", model_key))

    # Verify that the loaded model is a detection model
    if model.task != "detect":
        raise TypeError(f"Invalid model type: {model.task}. Expected a DetectionModel.")
    return model

def load_segmentation_model(project_id: int, model_key: str):
    default_model_map = {"yolo8": os.path.join("resources", "models", "yolov8n-seg.pt")}
    # Use the bundled default weights when the key matches a known default model
    if model_key in default_model_map:
        model = YOLO(default_model_map[model_key])
    else:
        # Otherwise load the project-specific checkpoint from disk
        model = load_model(model_path=os.path.join("resources", "projects", str(project_id), "models", model_key))

    # Verify that the loaded model is a segmentation model
    if model.task != "segment":
        raise TypeError(f"Invalid model type: {model.task}. Expected a SegmentationModel.")
    return model

def load_classification_model(project_id: int, model_key: str):
    default_model_map = {"yolo8": os.path.join("resources", "models", "yolov8n-cls.pt")}
    # Use the bundled default weights when the key matches a known default model
    if model_key in default_model_map:
        model = YOLO(default_model_map[model_key])
    else:
        # Otherwise load the project-specific checkpoint from disk
        model = load_model(model_path=os.path.join("resources", "projects", str(project_id), "models", model_key))

    # Verify that the loaded model is a classification model
    if model.task != "classify":
        raise TypeError(f"Invalid model type: {model.task}. Expected a ClassificationModel.")
    return model

def load_model(model_path: str):
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at path: {model_path}")
    try:
        model = YOLO(model_path)
        # Move the model to the GPU when CUDA is available; otherwise stay on the CPU
        if torch.cuda.is_available():
            model.to("cuda")
            print("GPU enabled")
        else:
            model.to("cpu")
        return model
    except Exception as e:
        raise Exception("YOLO model loading failed: unsupported architecture or invalid configuration.") from e
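

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original service code).
# It assumes the bundled default weights exist under resources/models/ and
# that project-specific checkpoints live under
# resources/projects/<project_id>/models/; the project id and checkpoint
# file name below are made-up examples.
if __name__ == "__main__":
    # Load the bundled YOLOv8n detection weights via the default model map.
    detector = load_detection_model(project_id=1, model_key="yolo8")
    print(detector.task)  # expected output: "detect"

    # A project-specific checkpoint would instead be loaded by passing its
    # file name, e.g. load_detection_model(project_id=1, model_key="best.pt"),
    # which load_model() resolves to resources/projects/1/models/best.pt.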