From 82381a52929c3bfb3d5c39335e663830b9eb1808 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Wed, 25 Sep 2024 17:59:02 +0900 Subject: [PATCH] =?UTF-8?q?Feat:=20slack=20=EC=95=8C=EB=9E=8C=EC=84=A4?= =?UTF-8?q?=EC=A0=95,=20=ED=95=99=EC=8A=B5=20=EC=A4=91=20api=20=EB=B3=B4?= =?UTF-8?q?=EB=82=B4=EA=B8=B0=20=EA=B5=AC=ED=98=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/detection.py | 27 +++++++++++++++++++-------- ai/app/schemas/predict_request.py | 2 +- ai/app/utils/api_utils.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 ai/app/utils/api_utils.py diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py index 5cacd20..afd0d39 100644 --- a/ai/app/api/yolo/detection.py +++ b/ai/app/api/yolo/detection.py @@ -1,6 +1,6 @@ import json -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from schemas.predict_request import PredictRequest from schemas.train_request import TrainRequest from schemas.predict_response import PredictResponse, LabelData @@ -10,6 +10,8 @@ from services.create_model import save_model from utils.dataset_utils import split_data from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path from utils.slackMessage import send_slack_message +from utils.api_utils import report_data +import random import asyncio, httpx @@ -33,9 +35,8 @@ async def detection_predict(request: PredictRequest): # 추론 try: results = run_predictions(model, url_list, request, classes) - print(len(results)) - print(len(request.image_list)) response = [process_prediction_result(result, image, request.label_map) for result, image in zip(results,request.image_list)] + send_slack_message(f"predict 성공{response}", status="success") return response except Exception as e: @@ -48,6 +49,7 @@ def get_model(request: PredictRequest): try: return load_detection_model(request.project_id, request.m_key) except Exception as e: + send_slack_message(f"프로젝트 ID: {request.project_id} - 실패! 에러: {str(e)}",status="error") raise HTTPException(status_code=500, detail="load model exception: " + str(e)) # 추론 실행 함수 @@ -60,18 +62,22 @@ def run_predictions(model, image, request, classes): classes=classes ) except Exception as e: + send_slack_message(f"프로젝트 ID: {request.project_id} - 실패! 에러: {str(e)}",status="error") raise HTTPException(status_code=500, detail="model predict exception: " + str(e)) # 추론 결과 처리 함수 def process_prediction_result(result, image, label_map): + random_number = random.randint(0, 0xFFFFFF) + color = f"{random_number:06X}" + label_data = LabelData( version="0.0.0", task_type="det", shapes=[ { "label": summary['name'], - "color": "#ff0000", + "color": color, "points": [ [summary['box']['x1'], summary['box']['y1']], [summary['box']['x2'], summary['box']['y2']] @@ -96,7 +102,11 @@ def process_prediction_result(result, image, label_map): @router.post("/train") -async def detection_train(request: TrainRequest): +async def detection_train(request: TrainRequest, http_request: Request): + # Authorization 헤더에서 Bearer 토큰 추출 + auth_header = http_request.headers.get("Authorization") + + token = auth_header.split(" ")[1] if auth_header and auth_header.startswith("Bearer ") else None send_slack_message(f"train 요청{request}", status="success") @@ -148,6 +158,7 @@ async def detection_train(request: TrainRequest): left_seconds= left_seconds # 남은 시간(초) ) # 데이터 전송 + report_data(request.project_id, request.m_id, data, token) model.add_callback("on_train_epoch_start", send_data) @@ -162,10 +173,10 @@ async def detection_train(request: TrainRequest): ) model_key = save_model(project_id=request.project_id, path=join_path(dataset_root_path, "result", "weights", "best.pt")) - - return {"model_key": model_key, "results": results.results_dict} + response = {"model_key": model_key, "results": results.results_dict} + send_slack_message(f"train 성공{response}", status="success") + return response except Exception as e: raise HTTPException(status_code=500, detail="model train exception: " + str(e)) - diff --git a/ai/app/schemas/predict_request.py b/ai/app/schemas/predict_request.py index c40ef34..15cbc5c 100644 --- a/ai/app/schemas/predict_request.py +++ b/ai/app/schemas/predict_request.py @@ -8,7 +8,7 @@ class ImageInfo(BaseModel): class PredictRequest(BaseModel): project_id: int - m_key: Optional[str] = Field(None, alias="model_key") + m_key: str = Field("yolo8", alias="model_key") label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 모델 레이블 카테고리 idx로 레이블링") image_list: list[ImageInfo] conf_threshold: float = 0.25 diff --git a/ai/app/utils/api_utils.py b/ai/app/utils/api_utils.py new file mode 100644 index 0000000..5f11476 --- /dev/null +++ b/ai/app/utils/api_utils.py @@ -0,0 +1,28 @@ +from schemas.train_report_data import ReportData +from dotenv import load_dotenv +import os, httpx + + +def report_data(project_id:int, model_id:int, data:ReportData, token): + try: + load_dotenv() + # main.py와 같은 디렉토리에 .env 파일 생성해서 따옴표 없이 입력 + # API_BASE_URL = {url} + # API_KEY = {key} + base_url = os.getenv("API_BASE_URL") + headers = { + "Content-Type": "application/json" + } + if token: + headers["Authorization"] = f"Bearer {token}" + + response = httpx.request( + method="POST", + url=base_url+f"/api/projects/{project_id}/reports/models/{model_id}", + json=data.model_dump(), + headers=headers + ) + # status에 따라 예외 발생 + response.raise_for_status() + except Exception as e: + print("report data failed: "+str(e))