Merge branch 'ai/feat/117-detection-model-train' into 'ai/develop'
Feat: Detection 모델 학습 API 구현 - S11P21S002-117 See merge request s11-s-project/S11P21S002!56
This commit is contained in:
commit
10f60972b8
7
ai/.gitignore
vendored
7
ai/.gitignore
vendored
@ -32,4 +32,9 @@ dist/
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
# 테스트 파일
|
# 테스트 파일
|
||||||
test-data/
|
test-data/
|
||||||
|
|
||||||
|
# 리소스
|
||||||
|
resources/
|
||||||
|
datasets/
|
||||||
|
*.pt
|
@ -4,9 +4,9 @@ from schemas.train_request import TrainRequest
|
|||||||
from schemas.predict_response import PredictResponse, LabelData
|
from schemas.predict_response import PredictResponse, LabelData
|
||||||
from services.ai_service import load_detection_model
|
from services.ai_service import load_detection_model
|
||||||
from utils.dataset_utils import split_data
|
from utils.dataset_utils import split_data
|
||||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label
|
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
|
||||||
from typing import List
|
from typing import List
|
||||||
from PIL import Image
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@router.post("/detection", response_model=List[PredictResponse])
|
@router.post("/detection", response_model=List[PredictResponse])
|
||||||
@ -96,4 +96,15 @@ def train(request: TrainRequest):
|
|||||||
|
|
||||||
# 검증 데이터 처리
|
# 검증 데이터 처리
|
||||||
for data in val_data:
|
for data in val_data:
|
||||||
process_image_and_label(data, dataset_root_path, "val")
|
process_image_and_label(data, dataset_root_path, "val")
|
||||||
|
|
||||||
|
model = load_detection_model("test-data/model/best.pt")
|
||||||
|
|
||||||
|
model.train(
|
||||||
|
data=join_path(dataset_root_path,"dataset.yaml"),
|
||||||
|
name=join_path(dataset_root_path,"result"),
|
||||||
|
epochs= request.epochs,
|
||||||
|
batch=request.batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
|
@ -1,5 +1,5 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Union
|
||||||
from schemas.predict_response import LabelData
|
from schemas.predict_response import LabelData
|
||||||
|
|
||||||
class TrainDataInfo(BaseModel):
|
class TrainDataInfo(BaseModel):
|
||||||
@ -11,4 +11,6 @@ class TrainRequest(BaseModel):
|
|||||||
project_id: int
|
project_id: int
|
||||||
data: List[TrainDataInfo]
|
data: List[TrainDataInfo]
|
||||||
seed: Optional[int] = None # 랜덤 변수 시드
|
seed: Optional[int] = None # 랜덤 변수 시드
|
||||||
ratio: Optional[float] = 0.8 # 훈련/검증 분할 비율
|
ratio: float = 0.8 # 훈련/검증 분할 비율
|
||||||
|
epochs: int = 50 # 훈련 반복 횟수
|
||||||
|
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
||||||
|
@ -6,7 +6,7 @@ from ultralytics.nn.tasks import DetectionModel, SegmentationModel
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
def load_detection_model(model_path: str = "test-data/model/yolov8n.pt", device:str ="auto"):
|
def load_detection_model(model_path: str = os.path.join("test-data","model","initial.pt"), device:str ="auto"):
|
||||||
"""
|
"""
|
||||||
지정된 경로에서 YOLO 모델을 로드합니다.
|
지정된 경로에서 YOLO 모델을 로드합니다.
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@ from PIL import Image
|
|||||||
from schemas.train_request import TrainDataInfo
|
from schemas.train_request import TrainDataInfo
|
||||||
|
|
||||||
def get_dataset_root_path(project_id):
|
def get_dataset_root_path(project_id):
|
||||||
"""프로젝트 ID를 기반으로 데이터셋 루트 경로 반환"""
|
"""데이터셋 루트 절대 경로 반환"""
|
||||||
return os.path.join('test-data', 'projects', str(project_id), 'train_model')
|
return os.path.join(os.getcwd(), 'datasets', 'train')
|
||||||
|
|
||||||
def make_dir(path:str, init: bool):
|
def make_dir(path:str, init: bool):
|
||||||
"""
|
"""
|
||||||
@ -19,8 +19,8 @@ def make_dir(path:str, init: bool):
|
|||||||
|
|
||||||
def make_yml(path:str):
|
def make_yml(path:str):
|
||||||
data = {
|
data = {
|
||||||
"train": "train",
|
"train": f"{path}/train",
|
||||||
"val": "val",
|
"val": f"{path}/val",
|
||||||
"nc": 80,
|
"nc": 80,
|
||||||
"names":
|
"names":
|
||||||
{
|
{
|
||||||
@ -106,7 +106,7 @@ def make_yml(path:str):
|
|||||||
79: "toothbrush"
|
79: "toothbrush"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
with open(path, 'w') as f:
|
with open(os.path.join(path, "dataset.yaml"), 'w') as f:
|
||||||
yaml.dump(data, f)
|
yaml.dump(data, f)
|
||||||
|
|
||||||
def process_directories(dataset_root_path:str):
|
def process_directories(dataset_root_path:str):
|
||||||
@ -114,7 +114,9 @@ def process_directories(dataset_root_path:str):
|
|||||||
make_dir(dataset_root_path, init=False)
|
make_dir(dataset_root_path, init=False)
|
||||||
make_dir(os.path.join(dataset_root_path, "train"), init=True)
|
make_dir(os.path.join(dataset_root_path, "train"), init=True)
|
||||||
make_dir(os.path.join(dataset_root_path, "val"), init=True)
|
make_dir(os.path.join(dataset_root_path, "val"), init=True)
|
||||||
make_yml(os.path.join(dataset_root_path, "dataset.yaml"))
|
if os.path.exists(os.path.join(dataset_root_path, "result")):
|
||||||
|
shutil.rmtree(os.path.join(dataset_root_path, "result"))
|
||||||
|
make_yml(dataset_root_path)
|
||||||
|
|
||||||
def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str):
|
def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str):
|
||||||
|
|
||||||
@ -138,8 +140,12 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat
|
|||||||
x2 = shape.points[1][0]
|
x2 = shape.points[1][0]
|
||||||
y2 = shape.points[1][1]
|
y2 = shape.points[1][1]
|
||||||
train_label.append(str(shape.group_id)) # label Id
|
train_label.append(str(shape.group_id)) # label Id
|
||||||
train_label.append(str((x1 + x2) / 2)) # 중심 x 좌표
|
train_label.append(str((x1 + x2) / 2 / label.imageWidth)) # 중심 x 좌표
|
||||||
train_label.append(str((y1 + y2) / 2)) # 중심 y 좌표
|
train_label.append(str((y1 + y2) / 2 / label.imageHeight)) # 중심 y 좌표
|
||||||
train_label.append(str(x2 - x1)) # 너비
|
train_label.append(str((x2 - x1) / label.imageWidth)) # 너비
|
||||||
train_label.append(str(y2 - y1)) # 높이
|
train_label.append(str((y2 - y1) / label.imageHeight )) # 높이
|
||||||
train_label_txt.write(" ".join(train_label)+"\n")
|
train_label_txt.write(" ".join(train_label)+"\n")
|
||||||
|
|
||||||
|
def join_path(path, *paths):
|
||||||
|
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
|
||||||
|
return os.path.join(path, *paths)
|
Loading…
Reference in New Issue
Block a user