Feat: Detection AI 모델 훈련 데이터 세팅 - S11P21S002-106
This commit is contained in:
parent
9abf99b992
commit
144d837361
@ -1,8 +1,12 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from schemas.predict_request import PredictRequest
|
||||
from schemas.train_request import TrainRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from services.ai_service import load_detection_model
|
||||
from utils.dataset_utils import split_data
|
||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label
|
||||
from typing import List
|
||||
from PIL import Image
|
||||
|
||||
router = APIRouter()
|
||||
@router.post("/detection", response_model=List[PredictResponse])
|
||||
@ -72,4 +76,24 @@ def predict(request: PredictRequest):
|
||||
})
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="label parsing exception: "+str(e))
|
||||
return response
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/detection/train")
|
||||
def train(request: TrainRequest):
|
||||
# 데이터셋 루트 경로 얻기
|
||||
dataset_root_path = get_dataset_root_path(request.project_id)
|
||||
|
||||
# 디렉토리 생성 및 초기화
|
||||
process_directories(dataset_root_path)
|
||||
|
||||
# 학습 데이터 분류
|
||||
train_data, val_data = split_data(request.data, request.ratio, request.seed)
|
||||
|
||||
# 학습 데이터 처리
|
||||
for data in train_data:
|
||||
process_image_and_label(data, dataset_root_path, "train")
|
||||
|
||||
# 검증 데이터 처리
|
||||
for data in val_data:
|
||||
process_image_and_label(data, dataset_root_path, "val")
|
14
ai/app/schemas/train_request.py
Normal file
14
ai/app/schemas/train_request.py
Normal file
@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from schemas.predict_response import LabelData
|
||||
|
||||
class TrainDataInfo(BaseModel):
|
||||
image_url: str
|
||||
label: LabelData
|
||||
|
||||
|
||||
class TrainRequest(BaseModel):
|
||||
project_id: int
|
||||
data: List[TrainDataInfo]
|
||||
seed: Optional[int] = None # 랜덤 변수 시드
|
||||
ratio: Optional[float] = 0.8 # 훈련/검증 분할 비율
|
11
ai/app/utils/dataset_utils.py
Normal file
11
ai/app/utils/dataset_utils.py
Normal file
@ -0,0 +1,11 @@
|
||||
import random
|
||||
from typing import List, Any, Optional
|
||||
|
||||
def split_data(data:List[Any], ratio:float, seed:Optional[int] = None):
|
||||
random.seed(seed)
|
||||
train_size = int(ratio * len(data))
|
||||
random.shuffle(data)
|
||||
random.seed(None)
|
||||
train_data = data[:train_size]
|
||||
val_data = data[train_size:]
|
||||
return train_data, val_data
|
145
ai/app/utils/file_utils.py
Normal file
145
ai/app/utils/file_utils.py
Normal file
@ -0,0 +1,145 @@
|
||||
import os
|
||||
import shutil
|
||||
import yaml
|
||||
from PIL import Image
|
||||
from schemas.train_request import TrainDataInfo
|
||||
|
||||
def get_dataset_root_path(project_id):
|
||||
"""프로젝트 ID를 기반으로 데이터셋 루트 경로 반환"""
|
||||
return os.path.join('test-data', 'projects', str(project_id), 'train_model')
|
||||
|
||||
def make_dir(path:str, init: bool):
|
||||
"""
|
||||
path : 디렉토리 경로
|
||||
init : 폴더를 초기화 할지 여부
|
||||
"""
|
||||
if (os.path.exists(path) and init):
|
||||
shutil.rmtree(path)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
def make_yml(path:str):
|
||||
data = {
|
||||
"train": "train",
|
||||
"val": "val",
|
||||
"nc": 80,
|
||||
"names":
|
||||
{
|
||||
0: "person",
|
||||
1: "bicycle",
|
||||
2: "car",
|
||||
3: "motorcycle",
|
||||
4: "airplane",
|
||||
5: "bus",
|
||||
6: "train",
|
||||
7: "truck",
|
||||
8: "boat",
|
||||
9: "traffic light",
|
||||
10: "fire hydrant",
|
||||
11: "stop sign",
|
||||
12: "parking meter",
|
||||
13: "bench",
|
||||
14: "bird",
|
||||
15: "cat",
|
||||
16: "dog",
|
||||
17: "horse",
|
||||
18: "sheep",
|
||||
19: "cow",
|
||||
20: "elephant",
|
||||
21: "bear",
|
||||
22: "zebra",
|
||||
23: "giraffe",
|
||||
24: "backpack",
|
||||
25: "umbrella",
|
||||
26: "handbag",
|
||||
27: "tie",
|
||||
28: "suitcase",
|
||||
29: "frisbee",
|
||||
30: "skis",
|
||||
31: "snowboard",
|
||||
32: "sports ball",
|
||||
33: "kite",
|
||||
34: "baseball bat",
|
||||
35: "baseball glove",
|
||||
36: "skateboard",
|
||||
37: "surfboard",
|
||||
38: "tennis racket",
|
||||
39: "bottle",
|
||||
40: "wine glass",
|
||||
41: "cup",
|
||||
42: "fork",
|
||||
43: "knife",
|
||||
44: "spoon",
|
||||
45: "bowl",
|
||||
46: "banana",
|
||||
47: "apple",
|
||||
48: "sandwich",
|
||||
49: "orange",
|
||||
50: "broccoli",
|
||||
51: "carrot",
|
||||
52: "hot dog",
|
||||
53: "pizza",
|
||||
54: "donut",
|
||||
55: "cake",
|
||||
56: "chair",
|
||||
57: "couch",
|
||||
58: "potted plant",
|
||||
59: "bed",
|
||||
60: "dining table",
|
||||
61: "toilet",
|
||||
62: "tv",
|
||||
63: "laptop",
|
||||
64: "mouse",
|
||||
65: "remote",
|
||||
66: "keyboard",
|
||||
67: "cell phone",
|
||||
68: "microwave",
|
||||
69: "oven",
|
||||
70: "toaster",
|
||||
71: "sink",
|
||||
72: "refrigerator",
|
||||
73: "book",
|
||||
74: "clock",
|
||||
75: "vase",
|
||||
76: "scissors",
|
||||
77: "teddy bear",
|
||||
78: "hair drier",
|
||||
79: "toothbrush"
|
||||
}
|
||||
}
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(data, f)
|
||||
|
||||
def process_directories(dataset_root_path:str):
|
||||
"""학습을 위한 디렉토리 생성"""
|
||||
make_dir(dataset_root_path, init=False)
|
||||
make_dir(os.path.join(dataset_root_path, "train"), init=True)
|
||||
make_dir(os.path.join(dataset_root_path, "val"), init=True)
|
||||
make_yml(os.path.join(dataset_root_path, "dataset.yaml"))
|
||||
|
||||
def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str):
|
||||
|
||||
"""이미지 저장 및 레이블 파일 생성"""
|
||||
# 이미지 저장
|
||||
img = Image.open(data.image_url)
|
||||
|
||||
# 파일명에서 확장자를 제거하여 img_title과 img_ext 생성
|
||||
img_title, img_ext = os.path.splitext(os.path.basename(data.image_url))
|
||||
|
||||
# 이미지 파일 저장 (확장자를 그대로 사용)
|
||||
img.save(os.path.join(dataset_root_path, child_path, img_title + img_ext))
|
||||
|
||||
# 레이블 -> 학습용 레이블 데이터 파싱(detection)
|
||||
label = data.label
|
||||
with open(os.path.join(dataset_root_path, child_path, f"{img_title}.txt"), "w") as train_label_txt:
|
||||
for shape in label.shapes:
|
||||
train_label = []
|
||||
x1 = shape.points[0][0]
|
||||
y1 = shape.points[0][1]
|
||||
x2 = shape.points[1][0]
|
||||
y2 = shape.points[1][1]
|
||||
train_label.append(str(shape.group_id)) # label Id
|
||||
train_label.append(str((x1 + x2) / 2)) # 중심 x 좌표
|
||||
train_label.append(str((y1 + y2) / 2)) # 중심 y 좌표
|
||||
train_label.append(str(x2 - x1)) # 너비
|
||||
train_label.append(str(y2 - y1)) # 높이
|
||||
train_label_txt.write(" ".join(train_label)+"\n")
|
Loading…
Reference in New Issue
Block a user