Refactor: 레이블 카테고리를 포함한 오토레이블링 구현 및 리팩토링
This commit is contained in:
parent
c7d07e5856
commit
05e9a2c03b
@ -15,8 +15,6 @@ router = APIRouter()
|
||||
|
||||
@router.post("/predict")
|
||||
async def detection_predict(request: PredictRequest):
|
||||
version = "0.1.0"
|
||||
|
||||
# Spring 서버의 WebSocket URL
|
||||
# TODO: 배포 시 변경
|
||||
spring_server_ws_url = f"ws://localhost:8080/ws"
|
||||
@ -30,6 +28,12 @@ async def detection_predict(request: PredictRequest):
|
||||
model = load_detection_model(model_path=model_path)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="load model exception: " + str(e))
|
||||
|
||||
# 모델 레이블 카테고리 연결
|
||||
classes = None
|
||||
if request.label_map:
|
||||
classes = list(request.label_map)
|
||||
|
||||
|
||||
# 웹소켓 연결
|
||||
try:
|
||||
@ -46,12 +50,12 @@ async def detection_predict(request: PredictRequest):
|
||||
source=image.image_url,
|
||||
iou=request.iou_threshold,
|
||||
conf=request.conf_threshold,
|
||||
classes=request.classes
|
||||
classes=classes
|
||||
)
|
||||
# 예측 결과 처리
|
||||
result = predict_results[0]
|
||||
label_data = LabelData(
|
||||
version=version,
|
||||
version="0.0.0",
|
||||
task_type="det",
|
||||
shapes=[
|
||||
{
|
||||
@ -61,7 +65,7 @@ async def detection_predict(request: PredictRequest):
|
||||
[summary['box']['x1'], summary['box']['y1']],
|
||||
[summary['box']['x2'], summary['box']['y2']]
|
||||
],
|
||||
"group_id": summary['class'],
|
||||
"group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
|
||||
"shape_type": "rectangle",
|
||||
"flags": {}
|
||||
}
|
||||
@ -105,13 +109,12 @@ async def detection_predict(request: PredictRequest):
|
||||
source=image.image_url,
|
||||
iou=request.iou_threshold,
|
||||
conf=request.conf_threshold,
|
||||
classes=request.classes
|
||||
classes=classes
|
||||
)
|
||||
|
||||
# 예측 결과 처리
|
||||
result = predict_results[0]
|
||||
label_data = LabelData(
|
||||
version=version,
|
||||
version="0.0.0",
|
||||
task_type="det",
|
||||
shapes=[
|
||||
{
|
||||
@ -121,7 +124,7 @@ async def detection_predict(request: PredictRequest):
|
||||
[summary['box']['x1'], summary['box']['y1']],
|
||||
[summary['box']['x2'], summary['box']['y2']]
|
||||
],
|
||||
"group_id": summary['class'],
|
||||
"group_id": request.label_map[summary['class']] if request.label_map else summary['class'],
|
||||
"shape_type": "rectangle",
|
||||
"flags": {}
|
||||
}
|
||||
|
@ -9,7 +9,6 @@ router = APIRouter()
|
||||
|
||||
@router.post("/predict", response_model=List[PredictResponse])
|
||||
def predict(request: PredictRequest):
|
||||
version = "0.1.0"
|
||||
|
||||
# 모델 로드
|
||||
try:
|
||||
@ -37,7 +36,7 @@ def predict(request: PredictRequest):
|
||||
try:
|
||||
for (image, result) in zip(request.image_list, results):
|
||||
label_data:LabelData = {
|
||||
"version": version,
|
||||
"version": "0.0.0",
|
||||
"task_type": "seg",
|
||||
"shapes": [
|
||||
{
|
||||
|
@ -1,20 +1,15 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
class ImageInfo(BaseModel):
|
||||
image_id: int
|
||||
image_url: str
|
||||
|
||||
class LabelCategory(BaseModel):
|
||||
label_id: int
|
||||
label_name: str
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
project_id: int
|
||||
m_key: Optional[str] = Field(None, alias="model_key")
|
||||
image_list: List[ImageInfo]
|
||||
version: str = "latest"
|
||||
label_map: dict[int, int] = Field(None, description="모델 레이블 카테고리 idx: 프로젝트 레이블 카테고리 idx , None 일경우 모델 레이블 카테고리 idx로 레이블링")
|
||||
image_list: list[ImageInfo]
|
||||
conf_threshold: float = 0.25
|
||||
iou_threshold: float = 0.45
|
||||
classes: Optional[List[int]] = None
|
||||
label_categories: Optional[List[LabelCategory]] = None
|
||||
|
@ -1,7 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Union
|
||||
from schemas.predict_response import LabelData
|
||||
from schemas.predict_request import LabelCategory
|
||||
|
||||
class TrainDataInfo(BaseModel):
|
||||
image_url: str
|
||||
@ -15,4 +14,3 @@ class TrainRequest(BaseModel):
|
||||
epochs: int = 50 # 훈련 반복 횟수
|
||||
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
||||
path: Optional[str] = Field(None, alias="model_path")
|
||||
label_categories: Optional[List[LabelCategory]] = None # 새로운 레이블 카테고리 확인용
|
||||
|
Loading…
Reference in New Issue
Block a user