Merge branch 'ai/refactor/resource-cleanup' into 'ai/develop'

Feat: 리소스 해제 관련 미들웨어 구현

See merge request s11-s-project/S11P21S002!273
This commit is contained in:
김용수 2024-10-03 00:41:09 +09:00
commit 6fa912df5a
4 changed files with 86 additions and 63 deletions

View File

@ -1,5 +1,4 @@
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from api.yolo.detection import run_predictions, get_random_color, split_data
from schemas.predict_request import PredictRequest
from schemas.train_request import TrainRequest, TrainDataInfo
@ -26,7 +25,7 @@ async def classification_predict(request: PredictRequest):
url_list = list(map(lambda x:x.image_url, request.image_list))
# 추론
results = await run_predictions(model, url_list, request, classes=[]) # classification은 classes를 무시함
results = run_predictions(model, url_list, request, classes=[]) # classification은 classes를 무시함
# 추론 결과 변환
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
@ -105,7 +104,7 @@ async def classification_train(request: TrainRequest):
download_data(train_data, test_data, dataset_root_path)
# 학습
results = await run_train(request, model,dataset_root_path)
results = run_train(request, model,dataset_root_path)
# best 모델 저장
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@ -137,7 +136,7 @@ def download_data(train_data:list[TrainDataInfo], test_data:list[TrainDataInfo],
raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
async def run_train(request, model, dataset_root_path):
def run_train(request, model, dataset_root_path):
try:
# 데이터 전송 콜백함수
def send_data(trainer):
@ -166,21 +165,25 @@ async def run_train(request, model, dataset_root_path):
# 데이터 전송
send_data_call_api(request.project_id, request.m_id, data)
except Exception as e:
raise HTTPException(status_code=500, detail="exception in send_data: "+ str(e))
print(f"Exception in send_data(): {e}")
# 콜백 등록
model.add_callback("on_train_epoch_start", send_data)
# 학습 실행
results = await run_in_threadpool(model.train,
data=dataset_root_path,
name=join_path(dataset_root_path, "result"),
epochs=request.epochs,
batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
)
try:
results = model.train(
data=dataset_root_path,
name=join_path(dataset_root_path, "result"),
epochs=request.epochs,
batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
)
finally:
# 콜백 해제 및 자원 해제
model.reset_callbacks()
# 마지막 에포크 전송
model.trainer.epoch += 1
send_data(model.trainer)

View File

@ -1,5 +1,4 @@
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from schemas.predict_request import PredictRequest
from schemas.train_request import TrainRequest, TrainDataInfo
from schemas.predict_response import PredictResponse, LabelData, Shape
@ -10,14 +9,13 @@ from services.create_model import save_model
from utils.file_utils import get_dataset_root_path, process_directories, join_path, process_image_and_label
from utils.slackMessage import send_slack_message
from utils.api_utils import send_data_call_api
import random
import random, torch
router = APIRouter()
@router.post("/predict")
async def detection_predict(request: PredictRequest):
send_slack_message(f"predict 요청: {request}", status="success")
# 모델 로드
@ -30,7 +28,7 @@ async def detection_predict(request: PredictRequest):
classes = get_classes(request.label_map, model.names)
# 추론
results = await run_predictions(model, url_list, request, classes)
results = run_predictions(model, url_list, request, classes)
# 추론 결과 변환
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
@ -52,20 +50,19 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e))
# 추론 실행 함수
def run_predictions(model, image, request, classes):
    """Run ``model.predict`` on the given source with autograd disabled.

    Args:
        model: Object exposing ``predict(source=..., iou=..., conf=..., classes=...)``.
        image: Value forwarded as ``source``; the route handlers in this
            file pass a list of image URLs.
        request: Object providing ``iou_threshold`` and ``conf_threshold``.
        classes: Class filter forwarded to ``predict`` unchanged.

    Returns:
        Whatever ``model.predict`` returns.

    Raises:
        HTTPException: Status 500 wrapping any exception raised while predicting.
    """
    # NOTE(review): this call is synchronous and is invoked directly from
    # ``async def`` route handlers, so it presumably occupies the event loop
    # for the whole inference -- confirm that this is the intended trade-off.
    try:
        # Inference does not need gradients; disabling them avoids building
        # an autograd graph for the forward pass.
        with torch.no_grad():
            result = model.predict(
                source=image,
                iou=request.iou_threshold,
                conf=request.conf_threshold,
                classes=classes,
            )
        return result
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e)) from e
# 추론 결과 처리 함수
def process_prediction_result(result, image, label_map):
try:
@ -135,7 +132,7 @@ async def detection_train(request: TrainRequest):
download_data(train_data, val_data, dataset_root_path, label_converter)
# 학습
results = await run_train(request, model,dataset_root_path)
results = run_train(request, model,dataset_root_path)
# best 모델 저장
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@ -178,9 +175,9 @@ def download_data(train_data:list[TrainDataInfo], val_data:list[TrainDataInfo],
except Exception as e:
raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
async def run_train(request, model, dataset_root_path):
def run_train(request, model, dataset_root_path):
try:
# 데이터 전송 콜백함수
# 콜백 함수 정의
def send_data(trainer):
try:
# 첫번째 epoch는 스킵
@ -207,31 +204,33 @@ async def run_train(request, model, dataset_root_path):
# 데이터 전송
send_data_call_api(request.project_id, request.m_id, data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"exception in send_data(): {e}")
# 예외 처리
print(f"Exception in send_data(): {e}")
# 콜백 등록
model.add_callback("on_train_epoch_start", send_data)
# 학습 실행
results = await run_in_threadpool(model.train,
data=join_path(dataset_root_path, "dataset.yaml"),
name=join_path(dataset_root_path, "result"),
epochs=request.epochs,
batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
)
try:
# 비동기 함수로 학습 실행
results = model.train(
data=join_path(dataset_root_path, "dataset.yaml"),
name=join_path(dataset_root_path, "result"),
epochs=request.epochs,
batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
)
finally:
# 콜백 해제 및 자원 해제
model.reset_callbacks()
torch.cuda.empty_cache()
# 마지막 에포크 전송
model.trainer.epoch += 1
send_data(model.trainer)
return results
except HTTPException as e:
raise e # HTTP 예외를 다시 발생
raise e
except Exception as e:
raise HTTPException(status_code=500, detail=f"exception in run_train(): {e}")

View File

@ -1,5 +1,4 @@
from fastapi import APIRouter, HTTPException
from fastapi.concurrency import run_in_threadpool
from api.yolo.detection import get_classes, run_predictions, get_random_color, split_data, download_data
from schemas.predict_request import PredictRequest
from schemas.train_request import TrainRequest
@ -28,7 +27,7 @@ async def segmentation_predict(request: PredictRequest):
classes = get_classes(request.label_map, model.names)
# 추론
results = await run_predictions(model, url_list, request, classes)
results = run_predictions(model, url_list, request, classes)
# 추론 결과 변환
response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
@ -102,7 +101,7 @@ async def segmentation_train(request: TrainRequest):
download_data(train_data, val_data, dataset_root_path, label_converter)
# 학습
results = await run_train(request, model,dataset_root_path)
results = run_train(request, model,dataset_root_path)
# best 모델 저장
model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@ -122,7 +121,7 @@ async def segmentation_train(request: TrainRequest):
return response
async def run_train(request, model, dataset_root_path):
def run_train(request, model, dataset_root_path):
try:
# 데이터 전송 콜백함수
def send_data(trainer):
@ -151,22 +150,25 @@ async def run_train(request, model, dataset_root_path):
# 데이터 전송
send_data_call_api(request.project_id, request.m_id, data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"send_data exception: {e}")
print(f"Exception in send_data(): {e}")
# 콜백 등록
model.add_callback("on_train_epoch_start", send_data)
# 학습 실행
results = await run_in_threadpool(model.train,
data=join_path(dataset_root_path, "dataset.yaml"),
name=join_path(dataset_root_path, "result"),
epochs=request.epochs,
batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
)
try:
# 비동기 함수로 학습 실행
results = model.train(
data=join_path(dataset_root_path, "dataset.yaml"),
name=join_path(dataset_root_path, "result"),
epochs=request.epochs,
batch=request.batch,
lr0=request.lr0,
lrf=request.lrf,
optimizer=request.optimizer
)
finally:
# 콜백 해제 및 자원 해제
model.reset_callbacks()
# 마지막 에포크 전송
model.trainer.epoch += 1

View File

@ -7,6 +7,7 @@ from api.yolo.segmentation import router as yolo_segmentation_router
from api.yolo.classfication import router as yolo_classification_router
from api.yolo.model import router as yolo_model_router
from utils.slackMessage import send_slack_message
import time, torch, gc
app = FastAPI()
@ -17,6 +18,24 @@ app.include_router(yolo_classification_router, prefix="/api/classification", tag
app.include_router(yolo_model_router, prefix="/api/model", tags=["Model"])
@app.middleware("http")
async def resource_cleaner_middleware(request: Request, call_next):
    """Time each HTTP request and release memory once it has been handled.

    Args:
        request: The incoming request, passed through to ``call_next``.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        The response produced by ``call_next``.

    Raises:
        Exception: Anything raised by ``call_next`` propagates unchanged;
            the timing report and cleanup still run first.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    try:
        return await call_next(request)
    finally:
        process_time = time.perf_counter() - start_time
        try:
            send_slack_message(f"처리 시간: {process_time}")
        finally:
            # Cleanup must run even if the Slack notification fails.
            # There is deliberately no `del obj` pass over gc.get_objects():
            # deleting the loop variable only unbinds a local name and frees
            # nothing. Unreachable objects are reclaimed by the collector,
            # after which the CUDA cache is emptied.
            gc.collect()
            torch.cuda.empty_cache()
# 예외 처리기
@app.exception_handler(HTTPException)
async def custom_http_exception_handler(request:Request, exc):