diff --git a/ai/app/api/yolo/classfication.py b/ai/app/api/yolo/classfication.py
index cd32ce8..592fd98 100644
--- a/ai/app/api/yolo/classfication.py
+++ b/ai/app/api/yolo/classfication.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, HTTPException
-from fastapi.concurrency import run_in_threadpool
 from api.yolo.detection import run_predictions, get_random_color, split_data
 from schemas.predict_request import PredictRequest
 from schemas.train_request import TrainRequest, TrainDataInfo
@@ -26,7 +25,7 @@ async def classification_predict(request: PredictRequest):
     url_list = list(map(lambda x:x.image_url, request.image_list))
 
     # Inference
-    results = await run_predictions(model, url_list, request, classes=[]) # classification ignores classes
+    results = run_predictions(model, url_list, request, classes=[]) # classification ignores classes
 
     # Convert inference results
     response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
@@ -105,7 +104,7 @@ async def classification_train(request: TrainRequest):
     download_data(train_data, test_data, dataset_root_path)
 
     # Training
-    results = await run_train(request, model,dataset_root_path)
+    results = run_train(request, model,dataset_root_path)
 
     # Save the best model
     model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@@ -137,7 +136,7 @@ def download_data(train_data:list[TrainDataInfo], test_data:list[TrainDataInfo],
         raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
 
 
-async def run_train(request, model, dataset_root_path):
+def run_train(request, model, dataset_root_path):
     try:
         # Data-sending callback function
         def send_data(trainer):
@@ -166,21 +165,25 @@ async def run_train(request, model, dataset_root_path):
                 # Send the data
                 send_data_call_api(request.project_id, request.m_id, data)
             except Exception as e:
-                raise HTTPException(status_code=500, detail="exception in send_data: "+ str(e))
+                print(f"Exception in send_data(): {e}")
 
         # Register the callback
         model.add_callback("on_train_epoch_start", send_data)
 
         # Run training
-        results = await run_in_threadpool(model.train,
-                            data=dataset_root_path,
-                            name=join_path(dataset_root_path, "result"),
-                            epochs=request.epochs,
-                            batch=request.batch,
-                            lr0=request.lr0,
-                            lrf=request.lrf,
-                            optimizer=request.optimizer
-                            )
+        try:
+            results = model.train(
+                data=dataset_root_path,
+                name=join_path(dataset_root_path, "result"),
+                epochs=request.epochs,
+                batch=request.batch,
+                lr0=request.lr0,
+                lrf=request.lrf,
+                optimizer=request.optimizer
+            )
+        finally:
+            # Release callbacks and resources
+            model.reset_callbacks()
         # Send the final epoch
         model.trainer.epoch += 1
         send_data(model.trainer)
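With `run_in_threadpool` gone, `run_train` now calls `model.train()` directly and relies on `try/finally` so the progress callback is always detached, even if training raises. Below is a minimal sketch of that pattern outside FastAPI; the weights file, dataset path, and epoch count are placeholders, not values taken from this patch.

```python
from ultralytics import YOLO


def train_with_progress(dataset_path: str):
    model = YOLO("yolov8n-cls.pt")  # placeholder weights

    def send_progress(trainer):
        # Report progress; a reporting failure should not abort the training run.
        try:
            print(f"epoch {trainer.epoch + 1}/{trainer.epochs} starting")
        except Exception as e:
            print(f"Exception in send_progress(): {e}")

    model.add_callback("on_train_epoch_start", send_progress)
    try:
        results = model.train(data=dataset_path, epochs=3)  # placeholder epochs
    finally:
        # Detach the callback whether training succeeded or failed.
        model.reset_callbacks()
    return results
```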
diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py
index 454bc68..ab496a9 100644
--- a/ai/app/api/yolo/detection.py
+++ b/ai/app/api/yolo/detection.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, HTTPException
-from fastapi.concurrency import run_in_threadpool
 from schemas.predict_request import PredictRequest
 from schemas.train_request import TrainRequest, TrainDataInfo
 from schemas.predict_response import PredictResponse, LabelData, Shape
@@ -10,14 +9,13 @@
 from services.create_model import save_model
 from utils.file_utils import get_dataset_root_path, process_directories, join_path, process_image_and_label
 from utils.slackMessage import send_slack_message
 from utils.api_utils import send_data_call_api
-import random
+import random, torch
 
 router = APIRouter()
 
 @router.post("/predict")
 async def detection_predict(request: PredictRequest):
-
     send_slack_message(f"predict request: {request}", status="success")
 
     # Load the model
@@ -30,7 +28,7 @@ async def detection_predict(request: PredictRequest):
     classes = get_classes(request.label_map, model.names)
 
     # Inference
-    results = await run_predictions(model, url_list, request, classes)
+    results = run_predictions(model, url_list, request, classes)
 
     # Convert inference results
     response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
@@ -52,19 +50,18 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
         raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e))
 
 # Inference runner function
-async def run_predictions(model, image, request, classes):
+def run_predictions(model, image, request, classes):
     try:
-        result = await run_in_threadpool(
-            model.predict,
-            source=image,
-            iou=request.iou_threshold,
-            conf=request.conf_threshold,
-            classes=classes
-        )
-        return result
+        with torch.no_grad():
+            result = model.predict(
+                source=image,
+                iou=request.iou_threshold,
+                conf=request.conf_threshold,
+                classes=classes
+            )
+        return result
     except Exception as e:
         raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
-
 
 # Inference result processing function
 def process_prediction_result(result, image, label_map):
@@ -135,7 +132,7 @@ async def detection_train(request: TrainRequest):
     download_data(train_data, val_data, dataset_root_path, label_converter)
 
     # Training
-    results = await run_train(request, model,dataset_root_path)
+    results = run_train(request, model,dataset_root_path)
 
     # Save the best model
     model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@@ -178,9 +175,9 @@ def download_data(train_data:list[TrainDataInfo], val_data:list[TrainDataInfo],
     except Exception as e:
         raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e))
 
-async def run_train(request, model, dataset_root_path):
+def run_train(request, model, dataset_root_path):
     try:
-        # Data-sending callback function
+        # Define the callback function
         def send_data(trainer):
             try:
                 # Skip the first epoch
@@ -207,31 +204,33 @@ async def run_train(request, model, dataset_root_path):
                 # Send the data
                 send_data_call_api(request.project_id, request.m_id, data)
             except Exception as e:
-                raise HTTPException(status_code=500, detail=f"exception in send_data(): {e}")
+                # Handle the exception
+                print(f"Exception in send_data(): {e}")
 
         # Register the callback
         model.add_callback("on_train_epoch_start", send_data)
 
-        # Run training
-        results = await run_in_threadpool(model.train,
-                            data=join_path(dataset_root_path, "dataset.yaml"),
-                            name=join_path(dataset_root_path, "result"),
-                            epochs=request.epochs,
-                            batch=request.batch,
-                            lr0=request.lr0,
-                            lrf=request.lrf,
-                            optimizer=request.optimizer
-                            )
+        try:
+            # Run training synchronously
+            results = model.train(
+                data=join_path(dataset_root_path, "dataset.yaml"),
+                name=join_path(dataset_root_path, "result"),
+                epochs=request.epochs,
+                batch=request.batch,
+                lr0=request.lr0,
+                lrf=request.lrf,
+                optimizer=request.optimizer
+            )
+        finally:
+            # Release callbacks and resources
+            model.reset_callbacks()
+            torch.cuda.empty_cache()
 
         # Send the final epoch
         model.trainer.epoch += 1
         send_data(model.trainer)
         return results
-
-
     except HTTPException as e:
-        raise e # Re-raise the HTTP exception
+        raise e
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"exception in run_train(): {e}")
-
-
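The inference path now wraps `model.predict()` in `torch.no_grad()` so request-time predictions do not build an autograd graph. A minimal sketch of the same idea; the weights file and thresholds are placeholders.

```python
import torch
from ultralytics import YOLO


def predict_without_grad(image_urls: list[str], iou: float = 0.45, conf: float = 0.25):
    model = YOLO("yolov8n.pt")  # placeholder weights
    # Disable gradient tracking for inference; only the detection results are kept.
    with torch.no_grad():
        results = model.predict(source=image_urls, iou=iou, conf=conf)
    return results
```

The trade-off of dropping `run_in_threadpool` is that `model.predict()` and `model.train()` now execute on the event-loop thread of the async endpoint, so a long-running call blocks other requests handled by the same worker.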
diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py
index 1752b8a..4e8886f 100644
--- a/ai/app/api/yolo/segmentation.py
+++ b/ai/app/api/yolo/segmentation.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, HTTPException
-from fastapi.concurrency import run_in_threadpool
 from api.yolo.detection import get_classes, run_predictions, get_random_color, split_data, download_data
 from schemas.predict_request import PredictRequest
 from schemas.train_request import TrainRequest
@@ -28,7 +27,7 @@ async def segmentation_predict(request: PredictRequest):
     classes = get_classes(request.label_map, model.names)
 
     # Inference
-    results = await run_predictions(model, url_list, request, classes)
+    results = run_predictions(model, url_list, request, classes)
 
     # Convert inference results
     response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)]
@@ -102,7 +101,7 @@ async def segmentation_train(request: TrainRequest):
     download_data(train_data, val_data, dataset_root_path, label_converter)
 
     # Training
-    results = await run_train(request, model,dataset_root_path)
+    results = run_train(request, model,dataset_root_path)
 
     # Save the best model
     model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt"))
@@ -122,7 +121,7 @@ async def segmentation_train(request: TrainRequest):
 
     return response
 
-async def run_train(request, model, dataset_root_path):
+def run_train(request, model, dataset_root_path):
     try:
         # Data-sending callback function
         def send_data(trainer):
@@ -151,23 +150,26 @@ async def run_train(request, model, dataset_root_path):
                 # Send the data
                 send_data_call_api(request.project_id, request.m_id, data)
             except Exception as e:
-                raise HTTPException(status_code=500, detail=f"send_data exception: {e}")
+                print(f"Exception in send_data(): {e}")
 
         # Register the callback
         model.add_callback("on_train_epoch_start", send_data)
 
+        try:
+            # Run training synchronously
+            results = model.train(
+                data=join_path(dataset_root_path, "dataset.yaml"),
+                name=join_path(dataset_root_path, "result"),
+                epochs=request.epochs,
+                batch=request.batch,
+                lr0=request.lr0,
+                lrf=request.lrf,
+                optimizer=request.optimizer
+            )
+        finally:
+            # Release callbacks and resources
+            model.reset_callbacks()
 
-        # Run training
-        results = await run_in_threadpool(model.train,
-                            data=join_path(dataset_root_path, "dataset.yaml"),
-                            name=join_path(dataset_root_path, "result"),
-                            epochs=request.epochs,
-                            batch=request.batch,
-                            lr0=request.lr0,
-                            lrf=request.lrf,
-                            optimizer=request.optimizer
-                            )
-
         # Send the final epoch
         model.trainer.epoch += 1
         send_data(model.trainer)
diff --git a/ai/app/main.py b/ai/app/main.py
index e1b507e..f8fe00e 100644
--- a/ai/app/main.py
+++ b/ai/app/main.py
@@ -7,6 +7,7 @@ from api.yolo.segmentation import router as yolo_segmentation_router
 from api.yolo.classfication import router as yolo_classification_router
 from api.yolo.model import router as yolo_model_router
 from utils.slackMessage import send_slack_message
+import time, torch, gc
 
 app = FastAPI()
 
@@ -17,6 +18,24 @@ app.include_router(yolo_classification_router, prefix="/api/classification", tag
 app.include_router(yolo_model_router, prefix="/api/model", tags=["Model"])
 
 
+
+@app.middleware("http")
+async def resource_cleaner_middleware(request: Request, call_next):
+    start_time = time.time()
+    try:
+        response = await call_next(request)
+    except Exception as exc:
+        raise exc
+    finally:
+        process_time = time.time() - start_time
+        send_slack_message(f"Processing time: {process_time} seconds")
+        for obj in gc.get_objects():
+            if torch.is_tensor(obj):
+                del obj
+        gc.collect()
+        torch.cuda.empty_cache()
+    return response
+
 # Exception handler
 @app.exception_handler(HTTPException)
 async def custom_http_exception_handler(request:Request, exc):
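The new `resource_cleaner_middleware` times every request, reports the duration, and frees GPU memory once the response is on its way. Below is a self-contained sketch of the same middleware under a few assumptions that are not part of the patch: the Slack notification is replaced with a `print`, the CUDA call is guarded by an availability check, and the `gc.get_objects()` loop is omitted because `del obj` only removes the loop's local reference and does not free tensors that are still referenced elsewhere.

```python
import gc
import time

import torch
from fastapi import FastAPI, Request

app = FastAPI()


@app.middleware("http")
async def resource_cleaner_middleware(request: Request, call_next):
    start_time = time.time()
    try:
        response = await call_next(request)
    finally:
        # Report how long the request took (stand-in for the Slack notification).
        print(f"Processing time: {time.time() - start_time:.2f}s")
        # Collect unreachable objects, then return cached GPU memory to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return response
```

If `call_next` raises, the `finally` block still runs and the exception propagates past the `return`, so failed requests are timed and cleaned up as well.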