diff --git a/ai/app/api/yolo/classfication.py b/ai/app/api/yolo/classfication.py index 0accc43..e870a82 100644 --- a/ai/app/api/yolo/classfication.py +++ b/ai/app/api/yolo/classfication.py @@ -42,39 +42,38 @@ def get_model(project_id:int, model_key:str): # 추론 결과 처리 함수 def process_prediction_result(result, image, label_map): try: - label_name = None + shapes = [] # top 5에 해당하는 class id 순회 for class_id in result.probs.top5: - name = result.names[class_id] # class id에 해당하는 label_name - if name in label_map: # name이 사용자 레이블 카테고리에 있을 경우 - label_name = name # label_name 설정 + label_name = result.names[class_id] # class id에 해당하는 label_name + if label_name in label_map: # name이 사용자 레이블 카테고리에 있을 경우 + shapes = [ + Shape( + label=label_name, + color=get_random_color(), + points=[[0.0, 0.0]], + group_id=label_map[label_name], + shape_type='point', + flags={} + ) + ] # label_name 설정 break label_data = LabelData( version="0.0.0", task_type="cls", - shapes=[], + shapes=shapes, split="none", imageHeight=result.orig_img.shape[0], imageWidth=result.orig_img.shape[1], imageDepth=result.orig_img.shape[2] ) - if label_name: # label_name을 설정한게 있다면 추가 - shape = Shape( - label= label_name, - color= get_random_color(), - points= [[0.0, 0.0]], - group_id= label_map[label_name], - shape_type= 'point', - flags= {} - ) - LabelData.shapes.append(shape) - return PredictResponse( image_id=image.image_id, data=label_data.model_dump_json() ) + except KeyError as e: raise HTTPException(status_code=500, detail="KeyError: " + str(e)) except Exception as e: