From 6e87beab551b9f18fabc014291a904061f7270ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Wed, 4 Sep 2024 17:36:40 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20API=20endPoint=20=EB=B3=80=EA=B2=BD?= =?UTF-8?q?=20=EB=B0=8F=20=EC=9D=B4=EB=AF=B8=EC=A7=80=20=EB=A9=94=EB=AA=A8?= =?UTF-8?q?=EB=A6=AC=20=EC=B2=98=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 19 +++++++++++++++---- ai/app/main.py | 2 +- ai/app/schemas/predict_response.py | 1 + 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index 8d3e7fd..2d01a9f 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -5,8 +5,7 @@ from services.ai_service import load_detection_model from typing import List router = APIRouter() - -@router.post("/predict", response_model=List[PredictResponse]) +@router.post("/detection", response_model=List[PredictResponse]) def predict(request: PredictRequest): version = "0.1.0" @@ -20,12 +19,23 @@ def predict(request: PredictRequest): results = [] try: for image in request.image_list: + # URL에서 이미지를 메모리로 로드 TODO: 추후 메모리에 할지 어떻게 해야할지 or 병렬 처리 고민 + # response = requests.get(image.image_url) + + # 이미지 데이터를 메모리로 로드 + # img = Image.open(io.BytesIO(response.content)) + predict_results = model.predict( - source=image.image_url, + source=image.image_url, iou=request.iou_threshold, conf=request.conf_threshold, - classes=request.classes) + classes=request.classes + ) results.append(predict_results[0]) + + # 메모리에서 이미지 객체 해제 + # img.close() + # del img except Exception as e: raise HTTPException(status_code=500, detail="model predict exception: "+str(e)) @@ -57,6 +67,7 @@ def predict(request: PredictRequest): } response.append({ "image_id":image.image_id, + "image_url":image.image_url, "data":label_data }) except Exception as e: diff --git a/ai/app/main.py b/ai/app/main.py index cc2fd10..04321cc 100644 --- a/ai/app/main.py +++ b/ai/app/main.py @@ -4,7 +4,7 @@ from api.yolo.detection import router as yolo_detection_router app = FastAPI() # 각 기능별 라우터를 애플리케이션에 등록 -app.include_router(yolo_detection_router, prefix="/api/yolo/detection") +app.include_router(yolo_detection_router, prefix="/api") # 애플리케이션 실행 if __name__ == "__main__": diff --git a/ai/app/schemas/predict_response.py b/ai/app/schemas/predict_response.py index 706d896..e34825c 100644 --- a/ai/app/schemas/predict_response.py +++ b/ai/app/schemas/predict_response.py @@ -20,4 +20,5 @@ class LabelData(BaseModel): class PredictResponse(BaseModel): image_id: int + image_url: str data: LabelData \ No newline at end of file