diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index 0fc4c62..2fcf0dc 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -73,13 +73,16 @@ def get_classes(label_map:dict[str: int], model_names: dict[int, str]): def run_predictions(model, image, request, classes): try: with torch.no_grad(): - result = model.predict( - source=image, - iou=request.iou_threshold, - conf=request.conf_threshold, - classes=classes - ) - return result + results = [] + for img in image: + result = model.predict( + source=[img], + iou=request.iou_threshold, + conf=request.conf_threshold, + classes=classes + ) + results += result + return results except Exception as e: raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e)) @@ -122,7 +125,7 @@ def get_random_color(): random_number = random.randint(0, 0xFFFFFF) return f"#{random_number:06X}" -@router.post("/train") +@router.post("/train", response_model=TrainResponse) async def detection_train(request: TrainRequest): send_slack_message(f"Detection train 요청 projectId : {request.project_id}, 이미지 개수:{len(request.data)}", status="success") diff --git a/ai/app/main.py b/ai/app/main.py index 7dc432a..b119930 100644 --- a/ai/app/main.py +++ b/ai/app/main.py @@ -24,15 +24,16 @@ async def resource_cleaner_middleware(request: Request, call_next): start_time = time.time() try: response = await call_next(request) + return response except Exception as exc: raise exc finally: process_time = time.time() - start_time if request.method != "GET": send_slack_message(f"처리 시간: {process_time}초") - gc.collect() + # gc.collect() torch.cuda.empty_cache() - return response + # 예외 처리기 @app.exception_handler(HTTPException)