Refactor: 레이블 카테고리를 포함한 학습 구현 및 리팩토링 - S11P21S002-228
This commit is contained in:
parent
ae3506d2eb
commit
1b17cfadad
@ -5,6 +5,7 @@ from schemas.predict_request import PredictRequest
|
|||||||
from schemas.train_request import TrainRequest
|
from schemas.train_request import TrainRequest
|
||||||
from schemas.predict_response import PredictResponse, LabelData
|
from schemas.predict_response import PredictResponse, LabelData
|
||||||
from services.load_model import load_detection_model
|
from services.load_model import load_detection_model
|
||||||
|
from services.create_model import save_model
|
||||||
from utils.dataset_utils import split_data
|
from utils.dataset_utils import split_data
|
||||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
||||||
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
||||||
@ -159,15 +160,6 @@ async def detection_predict(request: PredictRequest):
|
|||||||
|
|
||||||
@router.post("/train")
|
@router.post("/train")
|
||||||
async def detection_train(request: TrainRequest):
|
async def detection_train(request: TrainRequest):
|
||||||
# 데이터셋 루트 경로 얻기
|
|
||||||
dataset_root_path = get_dataset_root_path(request.project_id)
|
|
||||||
|
|
||||||
# 디렉토리 생성 및 초기화
|
|
||||||
process_directories(dataset_root_path)
|
|
||||||
|
|
||||||
# 학습 데이터 분류
|
|
||||||
train_data, val_data = split_data(request.data, request.ratio, request.seed)
|
|
||||||
|
|
||||||
# Spring 서버의 WebSocket URL
|
# Spring 서버의 WebSocket URL
|
||||||
# TODO: 배포시에 변경
|
# TODO: 배포시에 변경
|
||||||
spring_server_ws_url = f"ws://localhost:8080/ws"
|
spring_server_ws_url = f"ws://localhost:8080/ws"
|
||||||
@ -175,9 +167,34 @@ async def detection_train(request: TrainRequest):
|
|||||||
# WebSocketClient 인스턴스 생성
|
# WebSocketClient 인스턴스 생성
|
||||||
ws_client = WebSocketClient(spring_server_ws_url)
|
ws_client = WebSocketClient(spring_server_ws_url)
|
||||||
|
|
||||||
|
# 데이터셋 루트 경로 얻기
|
||||||
|
dataset_root_path = get_dataset_root_path(request.project_id)
|
||||||
|
|
||||||
|
# 모델 로드
|
||||||
|
try:
|
||||||
|
model_path = request.m_key and get_model_path(request.project_id, request.m_key)
|
||||||
|
model = load_detection_model(model_path=model_path)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="load model exception: " + str(e))
|
||||||
|
|
||||||
|
# 학습할 모델 카테고리 정리 카테고리가 추가되는 경우에 추가할 수 있게
|
||||||
|
names = model.names
|
||||||
|
|
||||||
|
# 디렉토리 생성 및 초기화
|
||||||
|
process_directories(dataset_root_path, names)
|
||||||
|
|
||||||
|
# 레이블 맵
|
||||||
|
inverted_label_map = None
|
||||||
|
if request.label_map:
|
||||||
|
inverted_label_map = {value: key for key, value in request.label_map.items()}
|
||||||
|
|
||||||
|
# 학습 데이터 분류
|
||||||
|
train_data, val_data = split_data(request.data, request.ratio, request.seed)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await ws_client.connect()
|
await ws_client.connect()
|
||||||
|
if not ws_client.is_connected():
|
||||||
|
raise WebSocketConnectionException()
|
||||||
|
|
||||||
# 학습 데이터 처리
|
# 학습 데이터 처리
|
||||||
total_data = len(train_data)
|
total_data = len(train_data)
|
||||||
@ -208,11 +225,35 @@ async def detection_train(request: TrainRequest):
|
|||||||
epochs=request.epochs,
|
epochs=request.epochs,
|
||||||
batch=request.batch,
|
batch=request.batch,
|
||||||
)
|
)
|
||||||
|
|
||||||
# return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
|
# return FileResponse(path=join_path(dataset_root_path, "result", "weights", "best.pt"), filename="best.pt", media_type="application/octet-stream")
|
||||||
|
|
||||||
return {"status": "Training completed successfully"}
|
return {"status": "Training completed successfully"}
|
||||||
|
|
||||||
|
except WebSocketConnectionException as e:
|
||||||
|
|
||||||
|
# 학습 데이터 처리
|
||||||
|
total_data = len(train_data)
|
||||||
|
for idx, data in enumerate(train_data):
|
||||||
|
# TODO: 비동기면 await 연결
|
||||||
|
process_image_and_label(data, dataset_root_path, "train", inverted_label_map)
|
||||||
|
|
||||||
|
# 검증 데이터 처리
|
||||||
|
total_val_data = len(val_data)
|
||||||
|
for idx, data in enumerate(val_data):
|
||||||
|
# TODO: 비동기면 await 연결
|
||||||
|
process_image_and_label(data, dataset_root_path, "val", inverted_label_map)
|
||||||
|
|
||||||
|
results = model.train(
|
||||||
|
data=join_path(dataset_root_path, "dataset.yaml"),
|
||||||
|
name=join_path(dataset_root_path, "result"),
|
||||||
|
epochs=request.epochs,
|
||||||
|
batch=request.batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "last.pt"))
|
||||||
|
|
||||||
|
return {"model_key": model_key, "results": results.results_dict}
|
||||||
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Training process failed: {str(e)}")
|
print(f"Training process failed: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail="Training process failed")
|
raise HTTPException(status_code=500, detail="Training process failed")
|
||||||
|
@ -8,9 +8,10 @@ class TrainDataInfo(BaseModel):
|
|||||||
|
|
||||||
class TrainRequest(BaseModel):
|
class TrainRequest(BaseModel):
|
||||||
project_id: int
|
project_id: int
|
||||||
|
m_key: Optional[str] = Field(None, alias="model_key")
|
||||||
|
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 레이블 데이터(프로젝트 레이블)의 idx로 학습")
|
||||||
data: List[TrainDataInfo]
|
data: List[TrainDataInfo]
|
||||||
seed: Optional[int] = None # 랜덤 변수 시드
|
seed: Optional[int] = None # 랜덤 변수 시드
|
||||||
ratio: float = 0.8 # 훈련/검증 분할 비율
|
ratio: float = 0.8 # 훈련/검증 분할 비율
|
||||||
epochs: int = 50 # 훈련 반복 횟수
|
epochs: int = 50 # 훈련 반복 횟수
|
||||||
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
||||||
path: Optional[str] = Field(None, alias="model_path")
|
|
||||||
|
@ -6,7 +6,7 @@ from schemas.train_request import TrainDataInfo
|
|||||||
|
|
||||||
def get_dataset_root_path(project_id):
|
def get_dataset_root_path(project_id):
|
||||||
"""데이터셋 루트 절대 경로 반환"""
|
"""데이터셋 루트 절대 경로 반환"""
|
||||||
return os.path.join(os.getcwd(), 'datasets', 'train')
|
return os.path.join(os.getcwd(), 'resources', 'projects', str(project_id), "train")
|
||||||
|
|
||||||
def make_dir(path:str, init: bool):
|
def make_dir(path:str, init: bool):
|
||||||
"""
|
"""
|
||||||
@ -17,108 +17,26 @@ def make_dir(path:str, init: bool):
|
|||||||
shutil.rmtree(path)
|
shutil.rmtree(path)
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
|
|
||||||
def make_yml(path:str):
|
def make_yml(path:str, names):
|
||||||
data = {
|
data = {
|
||||||
"train": f"{path}/train",
|
"train": f"{path}/train",
|
||||||
"val": f"{path}/val",
|
"val": f"{path}/val",
|
||||||
"nc": 80,
|
"nc": 80,
|
||||||
"names":
|
"names": names
|
||||||
{
|
|
||||||
0: "person",
|
|
||||||
1: "bicycle",
|
|
||||||
2: "car",
|
|
||||||
3: "motorcycle",
|
|
||||||
4: "airplane",
|
|
||||||
5: "bus",
|
|
||||||
6: "train",
|
|
||||||
7: "truck",
|
|
||||||
8: "boat",
|
|
||||||
9: "traffic light",
|
|
||||||
10: "fire hydrant",
|
|
||||||
11: "stop sign",
|
|
||||||
12: "parking meter",
|
|
||||||
13: "bench",
|
|
||||||
14: "bird",
|
|
||||||
15: "cat",
|
|
||||||
16: "dog",
|
|
||||||
17: "horse",
|
|
||||||
18: "sheep",
|
|
||||||
19: "cow",
|
|
||||||
20: "elephant",
|
|
||||||
21: "bear",
|
|
||||||
22: "zebra",
|
|
||||||
23: "giraffe",
|
|
||||||
24: "backpack",
|
|
||||||
25: "umbrella",
|
|
||||||
26: "handbag",
|
|
||||||
27: "tie",
|
|
||||||
28: "suitcase",
|
|
||||||
29: "frisbee",
|
|
||||||
30: "skis",
|
|
||||||
31: "snowboard",
|
|
||||||
32: "sports ball",
|
|
||||||
33: "kite",
|
|
||||||
34: "baseball bat",
|
|
||||||
35: "baseball glove",
|
|
||||||
36: "skateboard",
|
|
||||||
37: "surfboard",
|
|
||||||
38: "tennis racket",
|
|
||||||
39: "bottle",
|
|
||||||
40: "wine glass",
|
|
||||||
41: "cup",
|
|
||||||
42: "fork",
|
|
||||||
43: "knife",
|
|
||||||
44: "spoon",
|
|
||||||
45: "bowl",
|
|
||||||
46: "banana",
|
|
||||||
47: "apple",
|
|
||||||
48: "sandwich",
|
|
||||||
49: "orange",
|
|
||||||
50: "broccoli",
|
|
||||||
51: "carrot",
|
|
||||||
52: "hot dog",
|
|
||||||
53: "pizza",
|
|
||||||
54: "donut",
|
|
||||||
55: "cake",
|
|
||||||
56: "chair",
|
|
||||||
57: "couch",
|
|
||||||
58: "potted plant",
|
|
||||||
59: "bed",
|
|
||||||
60: "dining table",
|
|
||||||
61: "toilet",
|
|
||||||
62: "tv",
|
|
||||||
63: "laptop",
|
|
||||||
64: "mouse",
|
|
||||||
65: "remote",
|
|
||||||
66: "keyboard",
|
|
||||||
67: "cell phone",
|
|
||||||
68: "microwave",
|
|
||||||
69: "oven",
|
|
||||||
70: "toaster",
|
|
||||||
71: "sink",
|
|
||||||
72: "refrigerator",
|
|
||||||
73: "book",
|
|
||||||
74: "clock",
|
|
||||||
75: "vase",
|
|
||||||
76: "scissors",
|
|
||||||
77: "teddy bear",
|
|
||||||
78: "hair drier",
|
|
||||||
79: "toothbrush"
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
with open(os.path.join(path, "dataset.yaml"), 'w') as f:
|
with open(os.path.join(path, "dataset.yaml"), 'w') as f:
|
||||||
yaml.dump(data, f)
|
yaml.dump(data, f)
|
||||||
|
|
||||||
def process_directories(dataset_root_path:str):
|
def process_directories(dataset_root_path:str, names:list[str]):
|
||||||
"""학습을 위한 디렉토리 생성"""
|
"""학습을 위한 디렉토리 생성"""
|
||||||
make_dir(dataset_root_path, init=False)
|
make_dir(dataset_root_path, init=False)
|
||||||
make_dir(os.path.join(dataset_root_path, "train"), init=True)
|
make_dir(os.path.join(dataset_root_path, "train"), init=True)
|
||||||
make_dir(os.path.join(dataset_root_path, "val"), init=True)
|
make_dir(os.path.join(dataset_root_path, "val"), init=True)
|
||||||
if os.path.exists(os.path.join(dataset_root_path, "result")):
|
if os.path.exists(os.path.join(dataset_root_path, "result")):
|
||||||
shutil.rmtree(os.path.join(dataset_root_path, "result"))
|
shutil.rmtree(os.path.join(dataset_root_path, "result"))
|
||||||
make_yml(dataset_root_path)
|
make_yml(dataset_root_path, names)
|
||||||
|
|
||||||
def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str):
|
def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_path:str, label_map:dict[int, int]|None):
|
||||||
|
|
||||||
"""이미지 저장 및 레이블 파일 생성"""
|
"""이미지 저장 및 레이블 파일 생성"""
|
||||||
# 이미지 저장
|
# 이미지 저장
|
||||||
@ -139,7 +57,7 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat
|
|||||||
y1 = shape.points[0][1]
|
y1 = shape.points[0][1]
|
||||||
x2 = shape.points[1][0]
|
x2 = shape.points[1][0]
|
||||||
y2 = shape.points[1][1]
|
y2 = shape.points[1][1]
|
||||||
train_label.append(str(shape.group_id)) # label Id
|
train_label.append(str(label_map[shape.group_id]) if label_map else str(shape.group_id)) # label Id
|
||||||
train_label.append(str((x1 + x2) / 2 / label.imageWidth)) # 중심 x 좌표
|
train_label.append(str((x1 + x2) / 2 / label.imageWidth)) # 중심 x 좌표
|
||||||
train_label.append(str((y1 + y2) / 2 / label.imageHeight)) # 중심 y 좌표
|
train_label.append(str((y1 + y2) / 2 / label.imageHeight)) # 중심 y 좌표
|
||||||
train_label.append(str((x2 - x1) / label.imageWidth)) # 너비
|
train_label.append(str((x2 - x1) / label.imageWidth)) # 너비
|
||||||
|
Loading…
Reference in New Issue
Block a user