Feat: detection/predict 프로젝트 명세 변경에 따른 수정

This commit is contained in:
김진현 2024-09-27 10:38:01 +09:00
parent cae5fb5ae4
commit 2d8609077c
2 changed files with 31 additions and 24 deletions

View File

@ -1,7 +1,7 @@
from fastapi import APIRouter, HTTPException, Request
from schemas.predict_request import PredictRequest
from schemas.train_request import TrainRequest
from schemas.predict_response import PredictResponse, LabelData
from schemas.predict_response import PredictResponse, LabelData, Shape
from schemas.train_report_data import ReportData
from schemas.train_response import TrainResponse
from services.load_model import load_detection_model
@ -21,14 +21,14 @@ async def detection_predict(request: PredictRequest):
send_slack_message(f"predict 요청: {request}", status="success")
# 모델 로드
model = get_model(request)
# 모델 레이블 카테고리 연결
classes = list(request.label_map) if request.label_map else None
model = get_model(request.project_id, request.m_key)
# 이미지 데이터 정리
url_list = list(map(lambda x:x.image_url, request.image_list))
# 이 값을 모델에 입력하면 해당하는 클래스 id만 출력됨
classes = get_classes(request.label_map, model.names)
# 추론
results = run_predictions(model, url_list, request, classes)
@ -38,11 +38,18 @@ async def detection_predict(request: PredictRequest):
return response
# 모델 로드
def get_model(request: PredictRequest):
def get_model(project_id, model_key):
try:
return load_detection_model(request.project_id, request.m_key)
return load_detection_model(project_id, model_key)
except Exception as e:
raise HTTPException(status_code=500, detail="load model exception: " + str(e))
raise HTTPException(status_code=500, detail="exception in get_model(): " + str(e))
# 모델의 레이블로부터 label_map의 key에 존재하는 값의 id만 가져오기
def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
try:
return [id for id, name in model_names.items() if name in label_map]
except Exception as e:
raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e))
# 추론 실행 함수
def run_predictions(model, image, request, classes):
@ -54,7 +61,7 @@ def run_predictions(model, image, request, classes):
classes=classes
)
except Exception as e:
raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
# 추론 결과 처리 함수
@ -64,17 +71,17 @@ def process_prediction_result(result, image, label_map):
version="0.0.0",
task_type="det",
shapes=[
{
"label": summary['name'],
"color": get_random_color(),
"points": [
Shape(
label= summary['name'],
color= get_random_color(),
points= [
[summary['box']['x1'], summary['box']['y1']],
[summary['box']['x2'], summary['box']['y2']]
],
"group_id": label_map[summary['class']] if label_map else summary['class'],
"shape_type": "rectangle",
"flags": {}
}
group_id= label_map[summary['name']],
shape_type= "rectangle",
flags= {}
)
for summary in result.summary()
],
split="none",
@ -82,8 +89,10 @@ def process_prediction_result(result, image, label_map):
imageWidth=result.orig_img.shape[1],
imageDepth=result.orig_img.shape[2]
)
except KeyError as e:
raise HTTPException(status_code=500, detail="KeyError: " + str(e))
except Exception as e:
raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
raise HTTPException(status_code=500, detail="exception in process_prediction_result(): " + str(e))
return PredictResponse(
image_id=image.image_id,
@ -94,8 +103,6 @@ def get_random_color():
random_number = random.randint(0, 0xFFFFFF)
return f"#{random_number:06X}"
@router.post("/train")
async def detection_train(request: TrainRequest):

View File

@ -8,8 +8,8 @@ class ImageInfo(BaseModel):
class PredictRequest(BaseModel):
project_id: int
m_key: str = Field("yolo8", alias="model_key")
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 모델 레이블 카테고리 idx로 레이블링")
image_list: list[ImageInfo]
conf_threshold: float = 0.25
m_key: str = Field("yolo8", alias="model_key") # model_ 로 시작하는 변수를 BaseModel의 변수로 만들경우 Warning 떠서 m_key로 대체
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk , None일 경우 모델 레이블 카테고리 idx로 레이블링")
image_list: list[ImageInfo] # 이미지 리스트
conf_threshold: float = 0.25 #
iou_threshold: float = 0.45