diff --git a/ai/.gitignore b/ai/.gitignore
index 7c7931f..b7baa62 100644
--- a/ai/.gitignore
+++ b/ai/.gitignore
@@ -37,4 +37,6 @@ test-data/
 # Resources
 resources/
 datasets/
-*.pt
\ No newline at end of file
+*.pt
+
+*.jpg
\ No newline at end of file
diff --git a/ai/app/api/yolo/detection.py b/ai/app/api/yolo/detection.py
index d0038b1..041fec0 100644
--- a/ai/app/api/yolo/detection.py
+++ b/ai/app/api/yolo/detection.py
@@ -25,139 +25,89 @@ async def detection_predict(request: PredictRequest):
     ws_client = WebSocketClient(spring_server_ws_url)
 
     # Load the model
-    try:
-        model_path = request.m_key and get_model_path(request.project_id, request.m_key)
-        model = load_detection_model(model_path=model_path)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail="load model exception: " + str(e))
+    model = load_model(request)
 
     # Map the model's label categories
-    classes = None
-    if request.label_map:
-        classes = list(request.label_map)
+    classes = list(request.label_map) if request.label_map else None
+
+    # List to collect the results
+    response = []
 
     # Connect to the WebSocket
+    await connect_to_websocket(ws_client)
+
+    # Inference
     try:
-        await ws_client.connect()
-        if not ws_client.is_connected():
-            raise WebSocketConnectionException()
-
-        # Inference
-        total_images = len(request.image_list)
         for idx, image in enumerate(request.image_list):
-            try:
-                # Load the image from the URL into memory TODO: decide later how to handle memory usage or parallel processing
-                predict_results = model.predict(
-                    source=image.image_url,
-                    iou=request.iou_threshold,
-                    conf=request.conf_threshold,
-                    classes=classes
-                )
-                # Process the prediction result
-                result = predict_results[0]
-                label_data = LabelData(
-                    version="0.0.0",
-                    task_type="det",
-                    shapes=[
-                        {
-                            "label": summary['name'],
-                            "color": "#ff0000",
-                            "points": [
-                                [summary['box']['x1'], summary['box']['y1']],
-                                [summary['box']['x2'], summary['box']['y2']]
-                            ],
-                            "group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
-                            "shape_type": "rectangle",
-                            "flags": {}
-                        }
-                        for summary in result.summary()
-                    ],
-                    split="none",
-                    imageHeight=result.orig_img.shape[0],
-                    imageWidth=result.orig_img.shape[1],
-                    imageDepth=result.orig_img.shape[2]
-                )
+            result = run_predictions(model, image, request, classes)
+            response_item = process_prediction_result(result, request, image)
+            response.append(response_item)
 
-                response_item = PredictResponse(
-                    image_id=image.image_id,
-                    image_url=image.image_url,
-                    data=label_data
-                )
-
-                # Calculate progress
-                progress = (idx + 1) / total_images * 100
-
-                # Send the prediction result and progress over the WebSocket
-                message = {
-                    "project_id": request.project_id,
-                    "progress": progress,
-                    "result": response_item.model_dump()
-                }
-
-                await ws_client.send_message("/app/ai/predict/progress", json.dumps(message))
-
-            except Exception as e:
-                raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
-        return Response(status_code=204)
-
-    # If the WebSocket connection is not available
-    except WebSocketConnectionException as e:
-        # Inference
-        response = []
-        for image in request.image_list:
-            try:
-                predict_results = model.predict(
-                    source=image.image_url,
-                    iou=request.iou_threshold,
-                    conf=request.conf_threshold,
-                    classes=classes
-                )
-                # Process the prediction result
-                result = predict_results[0]
-                label_data = LabelData(
-                    version="0.0.0",
-                    task_type="det",
-                    shapes=[
-                        {
-                            "label": summary['name'],
-                            "color": "#ff0000",
-                            "points": [
-                                [summary['box']['x1'], summary['box']['y1']],
-                                [summary['box']['x2'], summary['box']['y2']]
-                            ],
-                            "group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
-                            "shape_type": "rectangle",
-                            "flags": {}
-                        }
-                        for summary in result.summary()
-                    ],
-                    split="none",
-                    imageHeight=result.orig_img.shape[0],
-                    imageWidth=result.orig_img.shape[1],
-                    imageDepth=result.orig_img.shape[2]
-                )
-
-                response_item = PredictResponse(
-                    image_id=image.image_id,
-                    image_url=image.image_url,
-                    data=label_data
-                )
-
-                response.append(response_item)
-
-            except Exception as e:
-                raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
-        return response
-    except Exception as e:
-        print(f"Prediction process failed: {str(e)}")
-        raise HTTPException(status_code=500, detail="Prediction process failed")
     finally:
         if ws_client.is_connected():
             await ws_client.close()
 
+# Load the model
+def load_model(request: PredictRequest):
+    try:
+        model_path = request.m_key and get_model_path(request.project_id, request.m_key)
+        return load_detection_model(model_path=model_path)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="load model exception: " + str(e))
+
+# Connect to the WebSocket
+async def connect_to_websocket(ws_client):
+    try:
+        await ws_client.connect()
+        if not ws_client.is_connected():
+            raise WebSocketConnectionException("Failed to connect to the WebSocket")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="websocket connect failed: " + str(e))
+
+# Run inference
+def run_predictions(model, image, request, classes):
+    try:
+        predict_results = model.predict(
+            source=image.image_url,
+            iou=request.iou_threshold,
+            conf=request.conf_threshold,
+            classes=classes
+        )
+        return predict_results[0]
+    except Exception as e:
+        raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
+
+# Process the prediction result
+def process_prediction_result(result, request, image):
+    label_data = LabelData(
+        version="0.0.0",
+        task_type="det",
+        shapes=[
+            {
+                "label": summary['name'],
+                "color": "#ff0000",
+                "points": [
+                    [summary['box']['x1'], summary['box']['y1']],
+                    [summary['box']['x2'], summary['box']['y2']]
+                ],
+                "group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
+                "shape_type": "rectangle",
+                "flags": {}
+            }
+            for summary in result.summary()
+        ],
+        split="none",
+        imageHeight=result.orig_img.shape[0],
+        imageWidth=result.orig_img.shape[1],
+        imageDepth=result.orig_img.shape[2]
+    )
+
+    return PredictResponse(
+        image_id=image.image_id,
+        data=json.dumps(label_data.dict())
+    )
 
 @router.post("/train")
 async def detection_train(request: TrainRequest):
diff --git a/ai/app/schemas/predict_response.py b/ai/app/schemas/predict_response.py
index e34825c..3570b32 100644
--- a/ai/app/schemas/predict_response.py
+++ b/ai/app/schemas/predict_response.py
@@ -20,5 +20,4 @@ class LabelData(BaseModel):
 
 class PredictResponse(BaseModel):
     image_id: int
-    image_url: str
-    data: LabelData
\ No newline at end of file
+    data: str
\ No newline at end of file