Feat: detection/predict 프로젝트 명세 변경에 따른 수정
This commit is contained in:
parent
cae5fb5ae4
commit
2d8609077c
@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from schemas.predict_request import PredictRequest
|
||||
from schemas.train_request import TrainRequest
|
||||
from schemas.predict_response import PredictResponse, LabelData
|
||||
from schemas.predict_response import PredictResponse, LabelData, Shape
|
||||
from schemas.train_report_data import ReportData
|
||||
from schemas.train_response import TrainResponse
|
||||
from services.load_model import load_detection_model
|
||||
@ -21,14 +21,14 @@ async def detection_predict(request: PredictRequest):
|
||||
send_slack_message(f"predict 요청: {request}", status="success")
|
||||
|
||||
# 모델 로드
|
||||
model = get_model(request)
|
||||
|
||||
# 모델 레이블 카테고리 연결
|
||||
classes = list(request.label_map) if request.label_map else None
|
||||
model = get_model(request.project_id, request.m_key)
|
||||
|
||||
# 이미지 데이터 정리
|
||||
url_list = list(map(lambda x:x.image_url, request.image_list))
|
||||
|
||||
# 이 값을 모델에 입력하면 해당하는 클래스 id만 출력됨
|
||||
classes = get_classes(request.label_map, model.names)
|
||||
|
||||
# 추론
|
||||
results = run_predictions(model, url_list, request, classes)
|
||||
|
||||
@ -38,11 +38,18 @@ async def detection_predict(request: PredictRequest):
|
||||
return response
|
||||
|
||||
# 모델 로드
|
||||
def get_model(request: PredictRequest):
|
||||
def get_model(project_id, model_key):
|
||||
try:
|
||||
return load_detection_model(request.project_id, request.m_key)
|
||||
return load_detection_model(project_id, model_key)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="load model exception: " + str(e))
|
||||
raise HTTPException(status_code=500, detail="exception in get_model(): " + str(e))
|
||||
|
||||
# 모델의 레이블로부터 label_map의 key에 존재하는 값의 id만 가져오기
|
||||
def get_classes(label_map:dict[str: int], model_names: dict[int, str]):
|
||||
try:
|
||||
return [id for id, name in model_names.items() if name in label_map]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="exception in get_classes(): " + str(e))
|
||||
|
||||
# 추론 실행 함수
|
||||
def run_predictions(model, image, request, classes):
|
||||
@ -54,7 +61,7 @@ def run_predictions(model, image, request, classes):
|
||||
classes=classes
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
|
||||
raise HTTPException(status_code=500, detail="exception in run_predictions: " + str(e))
|
||||
|
||||
|
||||
# 추론 결과 처리 함수
|
||||
@ -64,17 +71,17 @@ def process_prediction_result(result, image, label_map):
|
||||
version="0.0.0",
|
||||
task_type="det",
|
||||
shapes=[
|
||||
{
|
||||
"label": summary['name'],
|
||||
"color": get_random_color(),
|
||||
"points": [
|
||||
Shape(
|
||||
label= summary['name'],
|
||||
color= get_random_color(),
|
||||
points= [
|
||||
[summary['box']['x1'], summary['box']['y1']],
|
||||
[summary['box']['x2'], summary['box']['y2']]
|
||||
],
|
||||
"group_id": label_map[summary['class']] if label_map else summary['class'],
|
||||
"shape_type": "rectangle",
|
||||
"flags": {}
|
||||
}
|
||||
group_id= label_map[summary['name']],
|
||||
shape_type= "rectangle",
|
||||
flags= {}
|
||||
)
|
||||
for summary in result.summary()
|
||||
],
|
||||
split="none",
|
||||
@ -82,8 +89,10 @@ def process_prediction_result(result, image, label_map):
|
||||
imageWidth=result.orig_img.shape[1],
|
||||
imageDepth=result.orig_img.shape[2]
|
||||
)
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=500, detail="KeyError: " + str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="model predict exception: " + str(e))
|
||||
raise HTTPException(status_code=500, detail="exception in process_prediction_result(): " + str(e))
|
||||
|
||||
return PredictResponse(
|
||||
image_id=image.image_id,
|
||||
@ -94,8 +103,6 @@ def get_random_color():
|
||||
random_number = random.randint(0, 0xFFFFFF)
|
||||
return f"#{random_number:06X}"
|
||||
|
||||
|
||||
|
||||
@router.post("/train")
|
||||
async def detection_train(request: TrainRequest):
|
||||
|
||||
|
@ -8,8 +8,8 @@ class ImageInfo(BaseModel):
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
project_id: int
|
||||
m_key: str = Field("yolo8", alias="model_key")
|
||||
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 모델 레이블 카테고리 idx로 레이블링")
|
||||
image_list: list[ImageInfo]
|
||||
conf_threshold: float = 0.25
|
||||
m_key: str = Field("yolo8", alias="model_key") # model_ 로 시작하는 변수를 BaseModel의 변수로 만들경우 Warning 떠서 m_key로 대체
|
||||
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk , None일 경우 모델 레이블 카테고리 idx로 레이블링")
|
||||
image_list: list[ImageInfo] # 이미지 리스트
|
||||
conf_threshold: float = 0.25 #
|
||||
iou_threshold: float = 0.45
|
||||
|
Loading…
Reference in New Issue
Block a user