Merge branch 'ai/develop' of https://lab.ssafy.com/s11-s-project/S11P21S002 into ai/refactor/predict
This commit is contained in:
commit
20f8d9730a
@ -9,6 +9,7 @@ from services.create_model import save_model
|
||||
from utils.dataset_utils import split_data
|
||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path, get_model_path
|
||||
from utils.websocket_utils import WebSocketClient, WebSocketConnectionException
|
||||
from utils.slackMessage import send_slack_message
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
@ -17,12 +18,15 @@ router = APIRouter()
|
||||
|
||||
@router.post("/predict")
|
||||
async def detection_predict(request: PredictRequest):
|
||||
|
||||
send_slack_message(f"predict 요청{request}", status="success")
|
||||
|
||||
# Spring 서버의 WebSocket URL
|
||||
# TODO: 배포 시 변경
|
||||
spring_server_ws_url = f"ws://localhost:8080/ws"
|
||||
# spring_server_ws_url = f"ws://localhost:8080/ws"
|
||||
|
||||
# WebSocketClient 인스턴스 생성
|
||||
ws_client = WebSocketClient(spring_server_ws_url)
|
||||
# ws_client = WebSocketClient(spring_server_ws_url)
|
||||
|
||||
# 모델 로드
|
||||
model = load_model(request)
|
||||
@ -34,7 +38,7 @@ async def detection_predict(request: PredictRequest):
|
||||
response = []
|
||||
|
||||
# 웹소켓 연결
|
||||
await connect_to_websocket(ws_client)
|
||||
# await connect_to_websocket(ws_client)
|
||||
|
||||
# 추론
|
||||
try:
|
||||
@ -45,9 +49,15 @@ async def detection_predict(request: PredictRequest):
|
||||
|
||||
return response
|
||||
|
||||
finally:
|
||||
if ws_client.is_connected():
|
||||
await ws_client.close()
|
||||
except Exception as e:
|
||||
# 실패했을 때 Slack 알림
|
||||
send_slack_message(f"프로젝트 ID: {request.project_id} - 실패! 에러: {str(e)}",status="error")
|
||||
raise HTTPException(status_code=500, detail="Prediction process failed")
|
||||
|
||||
# finally:
|
||||
# send_slack_message("종료")
|
||||
# if ws_client.is_connected():
|
||||
# await ws_client.close()
|
||||
|
||||
# 모델 로드
|
||||
def load_model(request: PredictRequest):
|
||||
|
27
ai/app/utils/slackMessage.py
Normal file
27
ai/app/utils/slackMessage.py
Normal file
@ -0,0 +1,27 @@
|
||||
import httpx
|
||||
import os
|
||||
|
||||
SLACK_WEBHOOK_URL = "https://hooks.slack.com/services/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"
|
||||
|
||||
def send_slack_message(message: str, status: str = "info"):
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# 상태에 따라 다른 메시지 형식 적용 (성공, 에러)
|
||||
if status == "error":
|
||||
formatted_message = f":x: 에러 발생: {message}"
|
||||
elif status == "success":
|
||||
formatted_message = f":white_check_mark: {message}"
|
||||
else:
|
||||
formatted_message = message
|
||||
|
||||
# Slack에 전송할 페이로드
|
||||
payload = {
|
||||
"text": formatted_message
|
||||
}
|
||||
|
||||
response = httpx.post(SLACK_WEBHOOK_URL, json=payload, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
return "Message sent successfully"
|
||||
else:
|
||||
return f"Failed to send message. Status code: {response.status_code}"
|
@ -18,3 +18,4 @@ dependencies:
|
||||
- python-dotenv
|
||||
- locust
|
||||
- websockets
|
||||
- httpx
|
||||
|
Loading…
Reference in New Issue
Block a user