From 7dd09182b856562e59e3d092cebe573423493290 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Fri, 27 Sep 2024 19:03:18 +0900 Subject: [PATCH 1/2] =?UTF-8?q?Feat:=20=EC=98=A4=ED=86=A0=EB=A0=88?= =?UTF-8?q?=EC=9D=B4=EB=B8=94=EB=A7=81,=20=ED=95=99=EC=8A=B5=20=EB=B9=84?= =?UTF-8?q?=EB=8F=99=EA=B8=B0=20=EC=B2=98=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/classfication.py | 9 +++++---- ai/app/api/yolo/detection.py | 16 ++++++++++------ ai/app/api/yolo/segmentation.py | 10 ++++++---- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/ai/app/api/yolo/classfication.py b/ai/app/api/yolo/classfication.py index e870a82..cd32ce8 100644 --- a/ai/app/api/yolo/classfication.py +++ b/ai/app/api/yolo/classfication.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, HTTPException +from fastapi.concurrency import run_in_threadpool from api.yolo.detection import run_predictions, get_random_color, split_data from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest, TrainDataInfo @@ -25,7 +26,7 @@ async def classification_predict(request: PredictRequest): url_list = list(map(lambda x:x.image_url, request.image_list)) # 추론 - results = run_predictions(model, url_list, request, classes=[]) # classification은 classes를 무시함 + results = await run_predictions(model, url_list, request, classes=[]) # classification은 classes를 무시함 # 추론 결과 변환 response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)] @@ -104,7 +105,7 @@ async def classification_train(request: TrainRequest): download_data(train_data, test_data, dataset_root_path) # 학습 - results = run_train(request, model,dataset_root_path) + results = await run_train(request, model,dataset_root_path) # best 모델 저장 model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) @@ -136,7 +137,7 @@ def download_data(train_data:list[TrainDataInfo], test_data:list[TrainDataInfo], raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e)) -def run_train(request, model, dataset_root_path): +async def run_train(request, model, dataset_root_path): try: # 데이터 전송 콜백함수 def send_data(trainer): @@ -171,7 +172,7 @@ def run_train(request, model, dataset_root_path): model.add_callback("on_train_epoch_start", send_data) # 학습 실행 - results = model.train( + results = await run_in_threadpool(model.train, data=dataset_root_path, name=join_path(dataset_root_path, "result"), epochs=request.epochs, diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index b3ca8f7..e3abc00 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, HTTPException +from fastapi.concurrency import run_in_threadpool from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest, TrainDataInfo from schemas.predict_response import PredictResponse, LabelData, Shape @@ -29,7 +30,7 @@ async def detection_predict(request: PredictRequest): classes = get_classes(request.label_map, model.names) # 추론 - results = run_predictions(model, url_list, request, classes) + results = await run_predictions(model, url_list, request, classes) # 추론 결과 변환 response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)] @@ -51,14 +52,16 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]): raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e)) # 추론 실행 함수 -def run_predictions(model, image, request, classes): +async def run_predictions(model, image, request, classes): try: - return model.predict( + result = await run_in_threadpool( + model.predict, source=image, iou=request.iou_threshold, conf=request.conf_threshold, classes=classes ) + return result except Exception as e: raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e)) @@ -126,12 +129,13 @@ async def detection_train(request: TrainRequest): # 데이터 전처리: 데이터를 학습데이터와 검증데이터로 분류 train_data, val_data = split_data(request.data, request.ratio) + # 데이터 전처리: 데이터 이미지 및 레이블 다운로드 download_data(train_data, val_data, dataset_root_path, label_converter) # 학습 - results = run_train(request, model,dataset_root_path) + results = await run_train(request, model,dataset_root_path) # best 모델 저장 model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) @@ -171,7 +175,7 @@ def download_data(train_data:list[TrainDataInfo], val_data:list[TrainDataInfo], except Exception as e: raise HTTPException(status_code=500, detail="exception in download_data(): " + str(e)) -def run_train(request, model, dataset_root_path): +async def run_train(request, model, dataset_root_path): try: # 데이터 전송 콜백함수 def send_data(trainer): @@ -206,7 +210,7 @@ def run_train(request, model, dataset_root_path): model.add_callback("on_train_epoch_start", send_data) # 학습 실행 - results = model.train( + results = await run_in_threadpool(model.train, data=join_path(dataset_root_path, "dataset.yaml"), name=join_path(dataset_root_path, "result"), epochs=request.epochs, diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py index 4ef0078..1752b8a 100644 --- a/ai/app/api/yolo/segmentation.py +++ b/ai/app/api/yolo/segmentation.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, HTTPException +from fastapi.concurrency import run_in_threadpool from api.yolo.detection import get_classes, run_predictions, get_random_color, split_data, download_data from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest @@ -27,7 +28,7 @@ async def segmentation_predict(request: PredictRequest): classes = get_classes(request.label_map, model.names) # 추론 - results = run_predictions(model, url_list, request, classes) + results = await run_predictions(model, url_list, request, classes) # 추론 결과 변환 response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)] @@ -101,7 +102,7 @@ async def segmentation_train(request: TrainRequest): download_data(train_data, val_data, dataset_root_path, label_converter) # 학습 - results = run_train(request, model,dataset_root_path) + results = await run_train(request, model,dataset_root_path) # best 모델 저장 model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) @@ -121,7 +122,7 @@ async def segmentation_train(request: TrainRequest): return response -def run_train(request, model, dataset_root_path): +async def run_train(request, model, dataset_root_path): try: # 데이터 전송 콜백함수 def send_data(trainer): @@ -155,8 +156,9 @@ def run_train(request, model, dataset_root_path): # 콜백 등록 model.add_callback("on_train_epoch_start", send_data) + # 학습 실행 - results = model.train( + results = await run_in_threadpool(model.train, data=join_path(dataset_root_path, "dataset.yaml"), name=join_path(dataset_root_path, "result"), epochs=request.epochs, From 7e75f41c645ee08e5b93c4a9aac10cc4c0ece668 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Fri, 27 Sep 2024 19:06:23 +0900 Subject: [PATCH 2/2] =?UTF-8?q?Refactor:=20split=5Fdata()=20=EC=98=88?= =?UTF-8?q?=EC=99=B8=EC=B2=98=EB=A6=AC=20=EC=B6=94=EA=B0=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index e3abc00..454bc68 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -161,6 +161,9 @@ def split_data(data:list[TrainDataInfo], ratio:float): random.shuffle(data) train_data = data[:train_size] val_data = data[train_size:] + + if not train_data or not val_data: + raise Exception("data size is too small") return train_data, val_data except Exception as e: raise HTTPException(status_code=500, detail="exception in split_data(): " + str(e))