Merge branch 'ai/refactor/exception-handle' into 'ai/develop'
Refactor: request 검증 classification 레이블 없을때 에러 처리 See merge request s11-s-project/S11P21S002!248
This commit is contained in:
commit
354676d867
@ -32,9 +32,6 @@ def get_model_list(project_id:int):
|
|||||||
|
|
||||||
@router.post("/projects/{project_id}", status_code=201)
|
@router.post("/projects/{project_id}", status_code=201)
|
||||||
def create_model(project_id: int, request: ModelCreateRequest):
|
def create_model(project_id: int, request: ModelCreateRequest):
|
||||||
if request.project_type not in ["segmentation", "detection", "classification"]:
|
|
||||||
raise HTTPException(status_code=400,
|
|
||||||
detail= f"Invalid type '{request.type}'. Must be one of \"segmentation\", \"detection\", \"classification\".")
|
|
||||||
model_key = create_new_model(project_id, request.project_type, request.pretrained)
|
model_key = create_new_model(project_id, request.project_type, request.pretrained)
|
||||||
return {"model_key": model_key}
|
return {"model_key": model_key}
|
||||||
|
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
class ModelCreateRequest(BaseModel):
|
class ModelCreateRequest(BaseModel):
|
||||||
project_type: str
|
project_type: Literal["segmentation", "detection", "classification"]
|
||||||
pretrained:bool = True
|
pretrained:bool = True
|
@ -1,5 +1,4 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
class ImageInfo(BaseModel):
|
class ImageInfo(BaseModel):
|
||||||
image_id: int
|
image_id: int
|
||||||
@ -9,7 +8,7 @@ class ImageInfo(BaseModel):
|
|||||||
class PredictRequest(BaseModel):
|
class PredictRequest(BaseModel):
|
||||||
project_id: int
|
project_id: int
|
||||||
m_key: str = Field("yolo8", alias="model_key") # model_ 로 시작하는 변수를 BaseModel의 변수로 만들경우 Warning 떠서 m_key로 대체
|
m_key: str = Field("yolo8", alias="model_key") # model_ 로 시작하는 변수를 BaseModel의 변수로 만들경우 Warning 떠서 m_key로 대체
|
||||||
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk , None일 경우 모델 레이블 카테고리 idx로 레이블링")
|
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk")
|
||||||
image_list: list[ImageInfo] # 이미지 리스트
|
image_list: list[ImageInfo] # 이미지 리스트
|
||||||
conf_threshold: float = 0.25 #
|
conf_threshold: float = Field(0.25, gt=0, lt= 1)
|
||||||
iou_threshold: float = 0.45
|
iou_threshold: float = Field(0.45, gt=0, lt= 1)
|
||||||
|
28
ai/app/schemas/train_label_data.py
Normal file
28
ai/app/schemas/train_label_data.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Segment(BaseModel):
|
||||||
|
x: float = Field(..., ge=0, le=1)
|
||||||
|
y: float = Field(..., ge=0, le=1)
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
return f"{self.x} {self.y}"
|
||||||
|
|
||||||
|
class DetectionLabelData(BaseModel):
|
||||||
|
label_id: int = Field(..., ge=0)
|
||||||
|
center_x: float = Field(..., ge=0, le=1)
|
||||||
|
center_y: float = Field(..., ge=0, le=1)
|
||||||
|
width: float = Field(..., ge=0, le=1)
|
||||||
|
height: float = Field(..., ge=0, le=1)
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
return f"{self.label_id} {self.center_x} {self.center_y} {self.width} {self.height}"
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationLabelData(BaseModel):
|
||||||
|
label_id: int
|
||||||
|
segments: list[Segment]
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
points_str = " ".join([segment.to_string() for segment in self.segments])
|
||||||
|
return f"{self.label_id} {points_str}"
|
@ -1,7 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
class ReportData(BaseModel):
|
class ReportData(BaseModel):
|
||||||
|
|
||||||
epoch: int # 현재 에포크
|
epoch: int # 현재 에포크
|
||||||
total_epochs: int # 전체 에포크
|
total_epochs: int # 전체 에포크
|
||||||
seg_loss: float # seg_loss
|
seg_loss: float # seg_loss
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import List, Optional, Union, Literal
|
from typing import Literal
|
||||||
from schemas.predict_response import LabelData
|
|
||||||
|
|
||||||
class TrainDataInfo(BaseModel):
|
class TrainDataInfo(BaseModel):
|
||||||
image_url: str
|
image_url: str
|
||||||
data_url: str
|
data_url: str
|
||||||
|
|
||||||
class TrainRequest(BaseModel):
|
class TrainRequest(BaseModel):
|
||||||
project_id: int
|
project_id: int = Field(..., gt= 0)
|
||||||
m_key: str = Field("yolo8", alias="model_key")
|
m_key: str = Field("yolo8", alias="model_key")
|
||||||
m_id: int = Field(..., alias="model_id") # 학습 중 에포크 결과를 보낼때 model_id를 보냄
|
m_id: int = Field(..., alias="model_id", gt= 0) # 학습 중 에포크 결과를 보낼때 model_id를 보냄
|
||||||
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk , None일 경우 모델 레이블 카테고리 idx로 레이블링")
|
label_map: dict[str, int] = Field(..., description="프로젝트 레이블 이름: 프로젝트 레이블 pk")
|
||||||
data: List[TrainDataInfo]
|
data: list[TrainDataInfo]
|
||||||
ratio: float = 0.8 # 훈련/검증 분할 비율
|
ratio: float = Field(0.8, gt=0, lt=1) # 훈련/검증 분할 비율
|
||||||
|
|
||||||
# 학습 파라미터
|
# 학습 파라미터
|
||||||
epochs: int = 50 # 훈련 반복 횟수
|
epochs: int = Field(50, gt= 0, lt = 1000) # 훈련 반복 횟수
|
||||||
batch: Union[float, int] = -1 # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
batch: int = Field(16, gt=0, le = 10000) # 훈련 batch 수[int] or GPU의 사용률 자동[float] default(-1): gpu의 60% 사용 유지
|
||||||
lr0: float = 0.01 # 초기 학습 가중치
|
lr0: float = Field(0.01, gt= 0, lt= 1) # 초기 학습 가중치
|
||||||
lrf: float = 0.01 # lr0 기준으로 학습 가중치의 최종 수렴치 (ex lr0의 0.01배)
|
lrf: float = Field(0.01, gt= 0, lt= 1) # lr0 기준으로 학습 가중치의 최종 수렴치 (ex lr0의 0.01배)
|
||||||
optimizer: Literal['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'] = 'auto'
|
optimizer: Literal['auto', 'SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp'] = 'auto'
|
||||||
|
|
||||||
|
@ -5,13 +5,14 @@ import os, httpx
|
|||||||
|
|
||||||
def send_data_call_api(project_id:int, model_id:int, data:ReportData):
|
def send_data_call_api(project_id:int, model_id:int, data:ReportData):
|
||||||
try:
|
try:
|
||||||
# load_dotenv()
|
load_dotenv()
|
||||||
# base_url = os.getenv("API_BASE_URL")
|
base_url = os.getenv("API_BASE_URL")
|
||||||
# main.py와 같은 디렉토리에 .env 파일 생성해서 따옴표 없이 아래 데이터를 입력
|
# main.py와 같은 디렉토리에 .env 파일 생성해서 따옴표 없이 아래 데이터를 입력
|
||||||
# API_BASE_URL = {url}
|
# API_BASE_URL = {url}
|
||||||
# API_KEY = {key}
|
# API_KEY = {key}
|
||||||
|
|
||||||
# 하드코딩으로 대체
|
# 하드코딩으로 대체
|
||||||
|
if not base_url:
|
||||||
base_url = "http://127.0.0.1:8080"
|
base_url = "http://127.0.0.1:8080"
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
@ -22,7 +23,8 @@ def send_data_call_api(project_id:int, model_id:int, data:ReportData):
|
|||||||
method="POST",
|
method="POST",
|
||||||
url=base_url+f"/api/projects/{project_id}/reports/models/{model_id}",
|
url=base_url+f"/api/projects/{project_id}/reports/models/{model_id}",
|
||||||
json=data.model_dump(),
|
json=data.model_dump(),
|
||||||
headers=headers
|
headers=headers,
|
||||||
|
timeout=10
|
||||||
)
|
)
|
||||||
# status에 따라 예외 발생
|
# status에 따라 예외 발생
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
@ -3,7 +3,7 @@ import shutil
|
|||||||
import yaml
|
import yaml
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from schemas.train_request import TrainDataInfo
|
from schemas.train_request import TrainDataInfo
|
||||||
from schemas.predict_response import LabelData
|
from schemas.train_label_data import DetectionLabelData, SegmentationLabelData, Segment
|
||||||
import urllib
|
import urllib
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@ -67,27 +67,33 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat
|
|||||||
def create_detection_train_label(label:dict, label_path:str, label_converter:dict[int, int]):
|
def create_detection_train_label(label:dict, label_path:str, label_converter:dict[int, int]):
|
||||||
with open(label_path, "w") as train_label_txt:
|
with open(label_path, "w") as train_label_txt:
|
||||||
for shape in label["shapes"]:
|
for shape in label["shapes"]:
|
||||||
train_label = []
|
|
||||||
x1 = shape["points"][0][0]
|
x1 = shape["points"][0][0]
|
||||||
y1 = shape["points"][0][1]
|
y1 = shape["points"][0][1]
|
||||||
x2 = shape["points"][1][0]
|
x2 = shape["points"][1][0]
|
||||||
y2 = shape["points"][1][1]
|
y2 = shape["points"][1][1]
|
||||||
train_label.append(str(label_converter[shape["group_id"]])) # label Id
|
detection_label = DetectionLabelData(
|
||||||
train_label.append(str((x1 + x2) / 2 / label["imageWidth"])) # 중심 x 좌표
|
label_id= label_converter[shape["group_id"]], # 모델의 id (converter : pjt category pk -> model category id)
|
||||||
train_label.append(str((y1 + y2) / 2 / label["imageHeight"])) # 중심 y 좌표
|
center_x= (x1 + x2) / 2 / label["imageWidth"], # 중심 x 좌표
|
||||||
train_label.append(str((x2 - x1) / label["imageWidth"])) # 너비
|
center_y= (y1 + y2) / 2 / label["imageHeight"], # 중심 y 좌표
|
||||||
train_label.append(str((y2 - y1) / label["imageHeight"] )) # 높이
|
width= (x2 - x1) / label["imageWidth"], # 너비
|
||||||
train_label_txt.write(" ".join(train_label)+"\n")
|
height= (y2 - y1) / label["imageHeight"] # 높이
|
||||||
|
)
|
||||||
|
|
||||||
|
train_label_txt.write(detection_label.to_string()+"\n") # str변환 후 txt에 쓰기
|
||||||
|
|
||||||
def create_segmentation_train_label(label:dict, label_path:str, label_converter:dict[int, int]):
|
def create_segmentation_train_label(label:dict, label_path:str, label_converter:dict[int, int]):
|
||||||
with open(label_path, "w") as train_label_txt:
|
with open(label_path, "w") as train_label_txt:
|
||||||
for shape in label["shapes"]:
|
for shape in label["shapes"]:
|
||||||
train_label = []
|
segmentation_label = SegmentationLabelData(
|
||||||
train_label.append(str(label_converter[shape["group_id"]])) # label Id
|
label_id = label_converter[shape["group_id"]], # label Id
|
||||||
for x, y in shape["points"]:
|
segments = [
|
||||||
train_label.append(str(x / label["imageWidth"]))
|
Segment(
|
||||||
train_label.append(str(y / label["imageHeight"]))
|
x=x / label["imageWidth"], # shapes의 points 갯수만큼 x, y 반복
|
||||||
train_label_txt.write(" ".join(train_label)+"\n")
|
y=y / label["imageHeight"]
|
||||||
|
) for x, y in shape["points"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
train_label_txt.write(segmentation_label.to_string()+"\n")
|
||||||
|
|
||||||
def join_path(path, *paths):
|
def join_path(path, *paths):
|
||||||
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
|
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
|
||||||
@ -135,6 +141,10 @@ def process_image_and_label_in_cls(data:TrainDataInfo, dataset_root_path:str, ch
|
|||||||
# 레이블 객체 불러오기
|
# 레이블 객체 불러오기
|
||||||
label = json.loads(urllib.request.urlopen(data.data_url).read())
|
label = json.loads(urllib.request.urlopen(data.data_url).read())
|
||||||
|
|
||||||
|
if not label["shapes"]:
|
||||||
|
# assert label["shapes"], No Label. Failed Download" # AssertionError 발생
|
||||||
|
print("No Label. Failed Download")
|
||||||
|
return
|
||||||
label_name = label["shapes"][0]["label"]
|
label_name = label["shapes"][0]["label"]
|
||||||
|
|
||||||
label_path = os.path.join(dataset_root_path,child_path,label_name)
|
label_path = os.path.join(dataset_root_path,child_path,label_name)
|
||||||
@ -143,8 +153,8 @@ def process_image_and_label_in_cls(data:TrainDataInfo, dataset_root_path:str, ch
|
|||||||
if os.path.exists(label_path):
|
if os.path.exists(label_path):
|
||||||
urllib.request.urlretrieve(data.image_url, os.path.join(label_path, img_name))
|
urllib.request.urlretrieve(data.image_url, os.path.join(label_path, img_name))
|
||||||
else:
|
else:
|
||||||
# raise FileNotFoundError("failed download")
|
# raise FileNotFoundError("No Label Category. Failed Download")
|
||||||
print("Not Found Label Category. Failed Download")
|
print("No Label Category. Failed Download")
|
||||||
# 레이블 데이터 중에서 프로젝트 카테고리에 해당되지않는 데이터가 있는 경우 처리 1. 에러 raise 2. 무시(+ warning)
|
# 레이블 데이터 중에서 프로젝트 카테고리에 해당되지않는 데이터가 있는 경우 처리 1. 에러 raise 2. 무시(+ warning)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user