From 05e9a2c03b4e9940d5718297f546fa3801bd1fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Mon, 23 Sep 2024 09:46:42 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=EB=A0=88=EC=9D=B4=EB=B8=94=20?= =?UTF-8?q?=EC=B9=B4=ED=85=8C=EA=B3=A0=EB=A6=AC=EB=A5=BC=20=ED=8F=AC?= =?UTF-8?q?=ED=95=A8=ED=95=9C=20=EC=98=A4=ED=86=A0=EB=A0=88=EC=9D=B4?= =?UTF-8?q?=EB=B8=94=EB=A7=81=20=EA=B5=AC=ED=98=84=20=EB=B0=8F=20=EB=A6=AC?= =?UTF-8?q?=ED=8C=A9=ED=86=A0=EB=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 21 ++++++++++++--------- ai/app/api/yolo/segmentation.py | 3 +-- ai/app/schemas/predict_request.py | 11 +++-------- ai/app/schemas/train_request.py | 2 -- 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index e556d3e..bb9234c 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -15,8 +15,6 @@ router = APIRouter() @router.post("/predict") async def detection_predict(request: PredictRequest): - version = "0.1.0" - # Spring 서버의 WebSocket URL # TODO: 배포 시 변경 spring_server_ws_url = f"ws://localhost:8080/ws" @@ -30,6 +28,12 @@ async def detection_predict(request: PredictRequest): model = load_detection_model(model_path=model_path) except Exception as e: raise HTTPException(status_code=500, detail="load model exception: " + str(e)) + + # 모델 레이블 카테고리 연결 + classes = None + if request.label_map: + classes = list(request.label_map) + # 웹소켓 연결 try: @@ -46,12 +50,12 @@ async def detection_predict(request: PredictRequest): source=image.image_url, iou=request.iou_threshold, conf=request.conf_threshold, - classes=request.classes + classes=classes ) # 예측 결과 처리 result = predict_results[0] label_data = LabelData( - version=version, + version="0.0.0", task_type="det", shapes=[ { @@ -61,7 +65,7 @@ async def detection_predict(request: PredictRequest): [summary['box']['x1'], summary['box']['y1']], [summary['box']['x2'], summary['box']['y2']] ], - "group_id": summary['class'], + "group_id": request.label_map[summary['class']] if request.label_map else summary['class'], "shape_type": "rectangle", "flags": {} } @@ -105,13 +109,12 @@ async def detection_predict(request: PredictRequest): source=image.image_url, iou=request.iou_threshold, conf=request.conf_threshold, - classes=request.classes + classes=classes ) - # 예측 결과 처리 result = predict_results[0] label_data = LabelData( - version=version, + version="0.0.0", task_type="det", shapes=[ { @@ -121,7 +124,7 @@ async def detection_predict(request: PredictRequest): [summary['box']['x1'], summary['box']['y1']], [summary['box']['x2'], summary['box']['y2']] ], - "group_id": summary['class'], + "group_id": request.label_map[summary['class']] if request.label_map else summary['class'], "shape_type": "rectangle", "flags": {} } diff --git a/ai/app/api/yolo/segmentation.py b/ai/app/api/yolo/segmentation.py index 5bc8f8f..4ab9a7b 100644 --- a/ai/app/api/yolo/segmentation.py +++ b/ai/app/api/yolo/segmentation.py @@ -9,7 +9,6 @@ router = APIRouter() @router.post("/predict", response_model=List[PredictResponse]) def predict(request: PredictRequest): - version = "0.1.0" # 모델 로드 try: @@ -37,7 +36,7 @@ def predict(request: PredictRequest): try: for (image, result) in zip(request.image_list, results): label_data:LabelData = { - "version": version, + "version": "0.0.0", "task_type": "seg", "shapes": [ { diff --git a/ai/app/schemas/predict_request.py b/ai/app/schemas/predict_request.py index e65166a..c40ef34 100644 --- a/ai/app/schemas/predict_request.py +++ b/ai/app/schemas/predict_request.py @@ -1,20 +1,15 @@ from pydantic import BaseModel, Field -from typing import List, Optional +from typing import Optional, Union class ImageInfo(BaseModel): image_id: int image_url: str -class LabelCategory(BaseModel): - label_id: int - label_name: str class PredictRequest(BaseModel): project_id: int m_key: Optional[str] = Field(None, alias="model_key") - image_list: List[ImageInfo] - version: str = "latest" + label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 모델 레이블 카테고리 idx로 레이블링") + image_list: list[ImageInfo] conf_threshold: float = 0.25 iou_threshold: float = 0.45 - classes: Optional[List[int]] = None - label_categories: Optional[List[LabelCategory]] = None diff --git a/ai/app/schemas/train_request.py b/ai/app/schemas/train_request.py index 3e9a849..97ad1f7 100644 --- a/ai/app/schemas/train_request.py +++ b/ai/app/schemas/train_request.py @@ -1,7 +1,6 @@ from pydantic import BaseModel, Field from typing import List, Optional, Union from schemas.predict_response import LabelData -from schemas.predict_request import LabelCategory class TrainDataInfo(BaseModel): image_url: str @@ -15,4 +14,3 @@ class TrainRequest(BaseModel): epochs: int = 50 # 훈련 반복 횟수 batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지 path: Optional[str] = Field(None, alias="model_path") - label_categories: Optional[List[LabelCategory]] = None # 새로운 레이블 카테고리 확인용