Feat: 오토레이블링, 학습 비동기 처리
This commit is contained in:
parent
7016d3a91e
commit
7dd09182b8
@ -1,4 +1,5 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
from api.yolo.detection import run_predictions, get_random_color, split_data
|
from api.yolo.detection import run_predictions, get_random_color, split_data
|
||||||
from schemas.predict_request import PredictRequest
|
from schemas.predict_request import PredictRequest
|
||||||
from schemas.train_request import TrainRequest, TrainDataInfo
|
from schemas.train_request import TrainRequest, TrainDataInfo
|
||||||
@ -25,7 +26,7 @@ async def classification_predict(request: PredictRequest):
|
|||||||
url_list = list(map(lambda x:x.image_url, request.image_list))
|
url_list = list(map(lambda x:x.image_url, request.image_list))
|
||||||
|
|
||||||
# 추론
|
# 추론
|
||||||
results = run_predictions(model, url_list, request, classes=[]) # classification은 classes를 무시함
|
results = await run_predictions(model, url_list, request, classes=[]) # classification은 classes를 무시함
|
||||||
|
|
||||||
# 추론 결과 변환
|
# 추론 결과 변환
|
||||||
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
|
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
|
||||||
@ -104,7 +105,7 @@ async def classification_train(request: TrainRequest):
|
|||||||
download_data(train_data, test_data, dataset_root_path)
|
download_data(train_data, test_data, dataset_root_path)
|
||||||
|
|
||||||
# 학습
|
# 학습
|
||||||
results = run_train(request, model,dataset_root_path)
|
results = await run_train(request, model,dataset_root_path)
|
||||||
|
|
||||||
# best 모델 저장
|
# best 모델 저장
|
||||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||||
@ -136,7 +137,7 @@ def download_data(train_data:list[TrainDataInfo], test_data:list[TrainDataInfo],
|
|||||||
raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
|
raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
|
||||||
|
|
||||||
|
|
||||||
def run_train(request, model, dataset_root_path):
|
async def run_train(request, model, dataset_root_path):
|
||||||
try:
|
try:
|
||||||
# 데이터 전송 콜백함수
|
# 데이터 전송 콜백함수
|
||||||
def send_data(trainer):
|
def send_data(trainer):
|
||||||
@ -171,7 +172,7 @@ def run_train(request, model, dataset_root_path):
|
|||||||
model.add_callback("on_train_epoch_start", send_data)
|
model.add_callback("on_train_epoch_start", send_data)
|
||||||
|
|
||||||
# 학습 실행
|
# 학습 실행
|
||||||
results = model.train(
|
results = await run_in_threadpool(model.train,
|
||||||
data=dataset_root_path,
|
data=dataset_root_path,
|
||||||
name=join_path(dataset_root_path, "result"),
|
name=join_path(dataset_root_path, "result"),
|
||||||
epochs=request.epochs,
|
epochs=request.epochs,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
from schemas.predict_request import PredictRequest
|
from schemas.predict_request import PredictRequest
|
||||||
from schemas.train_request import TrainRequest, TrainDataInfo
|
from schemas.train_request import TrainRequest, TrainDataInfo
|
||||||
from schemas.predict_response import PredictResponse, LabelData, Shape
|
from schemas.predict_response import PredictResponse, LabelData, Shape
|
||||||
@ -29,7 +30,7 @@ async def detection_predict(request: PredictRequest):
|
|||||||
classes = get_classes(request.label_map, model.names)
|
classes = get_classes(request.label_map, model.names)
|
||||||
|
|
||||||
# 추론
|
# 추론
|
||||||
results = run_predictions(model, url_list, request, classes)
|
results = await run_predictions(model, url_list, request, classes)
|
||||||
|
|
||||||
# 추론 결과 변환
|
# 추론 결과 변환
|
||||||
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
|
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
|
||||||
@ -51,14 +52,16 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
|
|||||||
raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e))
|
raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e))
|
||||||
|
|
||||||
# 추론 실행 함수
|
# 추론 실행 함수
|
||||||
def run_predictions(model, image, request, classes):
|
async def run_predictions(model, image, request, classes):
|
||||||
try:
|
try:
|
||||||
return model.predict(
|
result = await run_in_threadpool(
|
||||||
|
model.predict,
|
||||||
source=image,
|
source=image,
|
||||||
iou=request.iou_threshold,
|
iou=request.iou_threshold,
|
||||||
conf=request.conf_threshold,
|
conf=request.conf_threshold,
|
||||||
classes=classes
|
classes=classes
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
|
raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
|
||||||
|
|
||||||
@ -127,11 +130,12 @@ async def detection_train(request: TrainRequest):
|
|||||||
# 데이터 전처리: 데이터를 학습데이터와 검증데이터로 분류
|
# 데이터 전처리: 데이터를 학습데이터와 검증데이터로 분류
|
||||||
train_data, val_data = split_data(request.data, request.ratio)
|
train_data, val_data = split_data(request.data, request.ratio)
|
||||||
|
|
||||||
|
|
||||||
# 데이터 전처리: 데이터 이미지 및 레이블 다운로드
|
# 데이터 전처리: 데이터 이미지 및 레이블 다운로드
|
||||||
download_data(train_data, val_data, dataset_root_path, label_converter)
|
download_data(train_data, val_data, dataset_root_path, label_converter)
|
||||||
|
|
||||||
# 학습
|
# 학습
|
||||||
results = run_train(request, model,dataset_root_path)
|
results = await run_train(request, model,dataset_root_path)
|
||||||
|
|
||||||
# best 모델 저장
|
# best 모델 저장
|
||||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||||
@ -171,7 +175,7 @@ def download_data(train_data:list[TrainDataInfo], val_data:list[TrainDataInfo],
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
|
raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
|
||||||
|
|
||||||
def run_train(request, model, dataset_root_path):
|
async def run_train(request, model, dataset_root_path):
|
||||||
try:
|
try:
|
||||||
# 데이터 전송 콜백함수
|
# 데이터 전송 콜백함수
|
||||||
def send_data(trainer):
|
def send_data(trainer):
|
||||||
@ -206,7 +210,7 @@ def run_train(request, model, dataset_root_path):
|
|||||||
model.add_callback("on_train_epoch_start", send_data)
|
model.add_callback("on_train_epoch_start", send_data)
|
||||||
|
|
||||||
# 학습 실행
|
# 학습 실행
|
||||||
results = model.train(
|
results = await run_in_threadpool(model.train,
|
||||||
data=join_path(dataset_root_path, "dataset.yaml"),
|
data=join_path(dataset_root_path, "dataset.yaml"),
|
||||||
name=join_path(dataset_root_path, "result"),
|
name=join_path(dataset_root_path, "result"),
|
||||||
epochs=request.epochs,
|
epochs=request.epochs,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
from api.yolo.detection import get_classes, run_predictions, get_random_color, split_data, download_data
|
from api.yolo.detection import get_classes, run_predictions, get_random_color, split_data, download_data
|
||||||
from schemas.predict_request import PredictRequest
|
from schemas.predict_request import PredictRequest
|
||||||
from schemas.train_request import TrainRequest
|
from schemas.train_request import TrainRequest
|
||||||
@ -27,7 +28,7 @@ async def segmentation_predict(request: PredictRequest):
|
|||||||
classes = get_classes(request.label_map, model.names)
|
classes = get_classes(request.label_map, model.names)
|
||||||
|
|
||||||
# 추론
|
# 추론
|
||||||
results = run_predictions(model, url_list, request, classes)
|
results = await run_predictions(model, url_list, request, classes)
|
||||||
|
|
||||||
# 추론 결과 변환
|
# 추론 결과 변환
|
||||||
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
|
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
|
||||||
@ -101,7 +102,7 @@ async def segmentation_train(request: TrainRequest):
|
|||||||
download_data(train_data, val_data, dataset_root_path, label_converter)
|
download_data(train_data, val_data, dataset_root_path, label_converter)
|
||||||
|
|
||||||
# 학습
|
# 학습
|
||||||
results = run_train(request, model,dataset_root_path)
|
results = await run_train(request, model,dataset_root_path)
|
||||||
|
|
||||||
# best 모델 저장
|
# best 모델 저장
|
||||||
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
|
||||||
@ -121,7 +122,7 @@ async def segmentation_train(request: TrainRequest):
|
|||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def run_train(request, model, dataset_root_path):
|
async def run_train(request, model, dataset_root_path):
|
||||||
try:
|
try:
|
||||||
# 데이터 전송 콜백함수
|
# 데이터 전송 콜백함수
|
||||||
def send_data(trainer):
|
def send_data(trainer):
|
||||||
@ -155,8 +156,9 @@ def run_train(request, model, dataset_root_path):
|
|||||||
# 콜백 등록
|
# 콜백 등록
|
||||||
model.add_callback("on_train_epoch_start", send_data)
|
model.add_callback("on_train_epoch_start", send_data)
|
||||||
|
|
||||||
|
|
||||||
# 학습 실행
|
# 학습 실행
|
||||||
results = model.train(
|
results = await run_in_threadpool(model.train,
|
||||||
data=join_path(dataset_root_path, "dataset.yaml"),
|
data=join_path(dataset_root_path, "dataset.yaml"),
|
||||||
name=join_path(dataset_root_path, "result"),
|
name=join_path(dataset_root_path, "result"),
|
||||||
epochs=request.epochs,
|
epochs=request.epochs,
|
||||||
|
Loading…
Reference in New Issue
Block a user