Refactor: 레이블 카테고리를 포함한 오토레이블링 구현 및 리팩토링

This commit is contained in:
김진현 2024-09-23 09:46:42 +09:00
parent c7d07e5856
commit 05e9a2c03b
4 changed files with 16 additions and 21 deletions

View File

@ -15,8 +15,6 @@ router = APIRouter()
@router.post("/predict")
async def detection_predict(request: PredictRequest):
version = "0.1.0"
# Spring 서버의 WebSocket URL
# TODO: 배포 시 변경
spring_server_ws_url = f"ws://localhost:8080/ws"
@ -30,6 +28,12 @@ async def detection_predict(request: PredictRequest):
model = load_detection_model(model_path=model_path)
except Exception as e:
raise HTTPException(status_code=500, detail="load model exception: " + str(e))
# 모델 레이블 카테고리 연결
classes = None
if request.label_map:
classes = list(request.label_map)
# 웹소켓 연결
try:
@ -46,12 +50,12 @@ async def detection_predict(request: PredictRequest):
source=image.image_url,
iou=request.iou_threshold,
conf=request.conf_threshold,
classes=request.classes
classes=classes
)
# 예측 결과 처리
result = predict_results[0]
label_data = LabelData(
version=version,
version="0.0.0",
task_type="det",
shapes=[
{
@ -61,7 +65,7 @@ async def detection_predict(request: PredictRequest):
[summary['box']['x1'], summary['box']['y1']],
[summary['box']['x2'], summary['box']['y2']]
],
"group_id": summary['class'],
"group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
"shape_type": "rectangle",
"flags": {}
}
@ -105,13 +109,12 @@ async def detection_predict(request: PredictRequest):
source=image.image_url,
iou=request.iou_threshold,
conf=request.conf_threshold,
classes=request.classes
classes=classes
)
# 예측 결과 처리
result = predict_results[0]
label_data = LabelData(
version=version,
version="0.0.0",
task_type="det",
shapes=[
{
@ -121,7 +124,7 @@ async def detection_predict(request: PredictRequest):
[summary['box']['x1'], summary['box']['y1']],
[summary['box']['x2'], summary['box']['y2']]
],
"group_id": summary['class'],
"group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
"shape_type": "rectangle",
"flags": {}
}

View File

@ -9,7 +9,6 @@ router = APIRouter()
@router.post("/predict", response_model=List[PredictResponse])
def predict(request: PredictRequest):
version = "0.1.0"
# 모델 로드
try:
@ -37,7 +36,7 @@ def predict(request: PredictRequest):
try:
for (image, result) in zip(request.image_list, results):
label_data:LabelData = {
"version": version,
"version": "0.0.0",
"task_type": "seg",
"shapes": [
{

View File

@ -1,20 +1,15 @@
from pydantic import BaseModel, Field
from typing import List, Optional
from typing import Optional, Union
class ImageInfo(BaseModel):
image_id: int
image_url: str
class LabelCategory(BaseModel):
label_id: int
label_name: str
class PredictRequest(BaseModel):
project_id: int
m_key: Optional[str] = Field(None, alias="model_key")
image_list: List[ImageInfo]
version: str = "latest"
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 모델 레이블 카테고리 idx로 레이블링")
image_list: list[ImageInfo]
conf_threshold: float = 0.25
iou_threshold: float = 0.45
classes: Optional[List[int]] = None
label_categories: Optional[List[LabelCategory]] = None

View File

@ -1,7 +1,6 @@
from pydantic import BaseModel, Field
from typing import List, Optional, Union
from schemas.predict_response import LabelData
from schemas.predict_request import LabelCategory
class TrainDataInfo(BaseModel):
image_url: str
@ -15,4 +14,3 @@ class TrainRequest(BaseModel):
epochs: int = 50 # 훈련 반복 횟수
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
path: Optional[str] = Field(None, alias="model_path")
label_categories: Optional[List[LabelCategory]] = None # 새로운 레이블 카테고리 확인용