From 7016d3a91ec8245b68822d7e1993cb1ed5af5f88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=A7=84=ED=98=84?= Date: Fri, 27 Sep 2024 16:02:02 +0900 Subject: [PATCH] =?UTF-8?q?Fix:=20classification=20process=5Fprediction=5F?= =?UTF-8?q?result()=20=EC=97=90=EB=9F=AC=20=EC=88=98=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ai/app/api/yolo/classfication.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/ai/app/api/yolo/classfication.py b/ai/app/api/yolo/classfication.py index 0accc43..e870a82 100644 --- a/ai/app/api/yolo/classfication.py +++ b/ai/app/api/yolo/classfication.py @@ -42,39 +42,38 @@ def get_model(project_id:int, model_key:str): # 추론 결과 처리 함수 def process_prediction_result(result, image, label_map): try: - label_name = None + shapes = [] # top 5에 해당하는 class id 순회 for class_id in result.probs.top5: - name = result.names[class_id] # class id에 해당하는 label_name - if name in label_map: # name이 사용자 레이블 카테고리에 있을 경우 - label_name = name # label_name 설정 + label_name = result.names[class_id] # class id에 해당하는 label_name + if label_name in label_map: # name이 사용자 레이블 카테고리에 있을 경우 + shapes = [ + Shape( + label=label_name, + color=get_random_color(), + points=[[0.0, 0.0]], + group_id=label_map[label_name], + shape_type='point', + flags={} + ) + ] # label_name 설정 break label_data = LabelData( version="0.0.0", task_type="cls", - shapes=[], + shapes=shapes, split="none", imageHeight=result.orig_img.shape[0], imageWidth=result.orig_img.shape[1], imageDepth=result.orig_img.shape[2] ) - if label_name: # label_name을 설정한게 있다면 추가 - shape = Shape( - label= label_name, - color= get_random_color(), - points= [[0.0, 0.0]], - group_id= label_map[label_name], - shape_type= 'point', - flags= {} - ) - LabelData.shapes.append(shape) - return PredictResponse( image_id=image.image_id, data=label_data.model_dump_json() ) + except KeyError as e: raise HTTPException(status_code=500, detail="KeyError: " + str(e)) except Exception as e: