Merge branch 'ai/feat/detection-autolabel' into 'ai/develop'
Feat: Detection모델 로드 함수 보완 See merge request s11-s-project/S11P21S002!47
This commit is contained in:
commit
4b56061f94
@ -61,9 +61,9 @@ def predict(request: PredictRequest):
|
||||
for summary in result.summary()
|
||||
],
|
||||
"split": "none",
|
||||
"imageHeight": result.orig_shape[0],
|
||||
"imageWidth": result.orig_shape[1],
|
||||
"imageDepth": 1
|
||||
"imageHeight": result.orig_img.shape[0],
|
||||
"imageWidth": result.orig_img.shape[1],
|
||||
"imageDepth": result.orig_img.shape[2]
|
||||
}
|
||||
response.append({
|
||||
"image_id":image.image_id,
|
||||
|
@ -1,10 +1,12 @@
|
||||
# ai_service.py
|
||||
|
||||
from ultralytics import YOLO # Ultralytics YOLO 모델을 가져오기
|
||||
from typing import List
|
||||
from ultralytics.models.yolo.model import YOLO as YOLO_Model
|
||||
from ultralytics.nn.tasks import DetectionModel
|
||||
import os
|
||||
import torch
|
||||
|
||||
def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:str ="cpu"):
|
||||
def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:str ="auto"):
|
||||
"""
|
||||
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||
|
||||
@ -22,9 +24,18 @@ def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:
|
||||
|
||||
try:
|
||||
model = YOLO(model_path)
|
||||
model.to(device)
|
||||
# Detection 모델인지 검증
|
||||
# 코드 추가
|
||||
if not (isinstance(model, YOLO_Model) and isinstance(model.model, DetectionModel)):
|
||||
raise TypeError(f"Invalid model type: {type(model)} (contained model type: {type(model.model)}). Expected a DetectionModel.")
|
||||
|
||||
# gpu 이용
|
||||
if (device == "auto" and torch.cuda.is_available()):
|
||||
model.to("cuda")
|
||||
print('gpu 가속 활성화')
|
||||
elif (device == "auto"):
|
||||
model.to("cpu")
|
||||
else:
|
||||
model.to(device)
|
||||
return model
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load the model from {model_path}. Error: {str(e)}")
|
||||
|
Loading…
Reference in New Issue
Block a user