From 7b10eeccaae093d78b0f610b8c0ca2a1eed4400c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Wed, 25 Sep 2024 16:06:45 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=ED=95=99=EC=8A=B5=20=EB=A0=88?= =?UTF-8?q?=EC=9D=B4=EB=B8=94=20data=5Furl=EB=A1=9C=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 4 ++-- ai/app/schemas/train_request.py | 2 +- ai/app/utils/file_utils.py | 14 ++------------ 3 files changed, 5 insertions(+), 15 deletions(-) diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index ff542d3..5cacd20 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -10,7 +10,7 @@ from services.create_model import save_model from utils.dataset_utils import split_data from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path from utils.slackMessage import send_slack_message -import asyncio +import asyncio, httpx router = APIRouter() @@ -18,7 +18,7 @@ router = APIRouter() @router.post("/predict") async def detection_predict(request: PredictRequest): - # send_slack_message(f"predict 요청{request}", status="success") + send_slack_message(f"predict 요청{request}", status="success") # 모델 로드 model = get_model(request) diff --git a/ai/app/schemas/train_request.py b/ai/app/schemas/train_request.py index f60a538..0c84ac5 100644 --- a/ai/app/schemas/train_request.py +++ b/ai/app/schemas/train_request.py @@ -4,7 +4,7 @@ from schemas.predict_response import LabelData class TrainDataInfo(BaseModel): image_url: str - label: str + data_url: str class TrainRequest(BaseModel): project_id: int diff --git a/ai/app/utils/file_utils.py b/ai/app/utils/file_utils.py index a68a2d5..9047d85 100644 --- a/ai/app/utils/file_utils.py +++ b/ai/app/utils/file_utils.py @@ -55,8 +55,8 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat # 레이블 파일 경로 label_path = os.path.join(dataset_root_path, child_path, f"{img_title}.txt") - # 레이블 역직렬화 - label = json_to_object(data.label) + # 레이블 객체 불러오기 + label = json.loads(urllib.request.urlopen(data.data_url).read()) # 레이블 -> 학습용 레이블 데이터 파싱 후 생성 create_detection_train_label(label, label_path, label_map) @@ -104,13 +104,3 @@ def get_file_name(path): if not os.path.exists(path): raise FileNotFoundError() return os.path.basename(path) - -def json_to_object(json_string): - try: - # JSON 문자열을 Python 객체로 변환 - python_object = json.loads(json_string) - return python_object - except json.JSONDecodeError as e: - raise json.JSONDecodeError("json_decode_error:"+str(e)) - except Exception as e: - raise Exception("exception at json_to_object:"+str(e)) \ No newline at end of file