Refactor: train_label pydantic model 이용하여 데이터 유효성 검사 추가
This commit is contained in:
parent
1b7fea1415
commit
4dfa4b1c3e
28
ai/app/schemas/train_label_data.py
Normal file
28
ai/app/schemas/train_label_data.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Segment(BaseModel):
    """A single point of a segmentation label.

    Both coordinates are validated by pydantic to lie in the closed
    interval [0, 1]; constructing a ``Segment`` with a value outside that
    range fails validation.
    """

    # Horizontal coordinate, constrained to 0 <= x <= 1.
    x: float = Field(..., ge=0, le=1)
    # Vertical coordinate, constrained to 0 <= y <= 1.
    y: float = Field(..., ge=0, le=1)

    def to_string(self) -> str:
        """Return the point as ``"<x> <y>"``, separated by one space."""
        coordinates = (self.x, self.y)
        return " ".join(map(str, coordinates))
|
||||||
|
|
||||||
|
class DetectionLabelData(BaseModel):
    """One bounding-box label line for detection training.

    ``label_id`` must be a non-negative integer; the box centre and size
    are each validated to lie in the closed interval [0, 1].
    """

    # Class id written as the first token of the label line.
    label_id: int = Field(..., ge=0)
    # Box centre, constrained to [0, 1] on both axes.
    center_x: float = Field(..., ge=0, le=1)
    center_y: float = Field(..., ge=0, le=1)
    # Box extent, constrained to [0, 1] on both axes.
    width: float = Field(..., ge=0, le=1)
    height: float = Field(..., ge=0, le=1)

    def to_string(self) -> str:
        """Return ``"<label_id> <center_x> <center_y> <width> <height>"``."""
        ordered_values = (
            self.label_id,
            self.center_x,
            self.center_y,
            self.width,
            self.height,
        )
        return " ".join(str(value) for value in ordered_values)
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationLabelData(BaseModel):
    """One segmentation label line: a class id followed by its points.

    ``label_id`` is validated to be non-negative, matching the constraint
    used by ``DetectionLabelData``; each entry of ``segments`` is validated
    by the ``Segment`` model.
    """

    # Class id written as the first token of the label line. A negative id
    # is rejected at construction time instead of being written to the file.
    label_id: int = Field(..., ge=0)
    # Points of the shape, in the order they should be written out.
    segments: list[Segment]

    def to_string(self) -> str:
        """Return ``"<label_id> <x1> <y1> <x2> <y2> ..."``.

        Points are emitted in list order, each rendered by
        ``Segment.to_string``. With an empty ``segments`` list the result is
        the label id followed by a single trailing space.
        """
        points_str = " ".join(segment.to_string() for segment in self.segments)
        return f"{self.label_id} {points_str}"
|
@ -3,7 +3,7 @@ import shutil
|
|||||||
import yaml
|
import yaml
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from schemas.train_request import TrainDataInfo
|
from schemas.train_request import TrainDataInfo
|
||||||
from schemas.predict_response import LabelData
|
from schemas.train_label_data import DetectionLabelData, SegmentationLabelData, Segment
|
||||||
import urllib
|
import urllib
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@ -71,29 +71,29 @@ def create_detection_train_label(label:dict, label_path:str, label_converter:dic
|
|||||||
y1 = shape["points"][0][1]
|
y1 = shape["points"][0][1]
|
||||||
x2 = shape["points"][1][0]
|
x2 = shape["points"][1][0]
|
||||||
y2 = shape["points"][1][1]
|
y2 = shape["points"][1][1]
|
||||||
train_label = {
|
detection_label = DetectionLabelData(
|
||||||
'label_id': label_converter[shape["group_id"]], # 모델의 id (converter : pjt category pk -> model category id)
|
label_id= label_converter[shape["group_id"]], # 모델의 id (converter : pjt category pk -> model category id)
|
||||||
'center_x': (x1 + x2) / 2 / label["imageWidth"], # 중심 x 좌표
|
center_x= (x1 + x2) / 2 / label["imageWidth"], # 중심 x 좌표
|
||||||
'center_y': (y1 + y2) / 2 / label["imageHeight"], # 중심 y 좌표
|
center_y= (y1 + y2) / 2 / label["imageHeight"], # 중심 y 좌표
|
||||||
'width': (x2 - x1) / label["imageWidth"], # 너비
|
width= (x2 - x1) / label["imageWidth"], # 너비
|
||||||
'height': (y2 - y1) / label["imageHeight"] # 높이
|
height= (y2 - y1) / label["imageHeight"] # 높이
|
||||||
}
|
)
|
||||||
|
|
||||||
for key, value in train_label[1:].items(): # label_id를 제외한 다른 key에 대해
|
train_label_txt.write(detection_label.to_string()+"\n") # str변환 후 txt에 쓰기
|
||||||
if value<0 or value >1: # 0과 1사이가 아니라면 에러
|
|
||||||
raise ValueError(f"Improper value in {label_path}: {key} = {value}")
|
|
||||||
|
|
||||||
train_label_txt.write(" ".join(map(str, train_label.values()))+"\n") # str변환 후 txt에 쓰기
|
|
||||||
|
|
||||||
def create_segmentation_train_label(label:dict, label_path:str, label_converter:dict[int, int]):
|
def create_segmentation_train_label(label:dict, label_path:str, label_converter:dict[int, int]):
|
||||||
with open(label_path, "w") as train_label_txt:
|
with open(label_path, "w") as train_label_txt:
|
||||||
for shape in label["shapes"]:
|
for shape in label["shapes"]:
|
||||||
train_label = []
|
segmentation_label = SegmentationLabelData(
|
||||||
train_label.append(str(label_converter[shape["group_id"]])) # label Id
|
label_id = label_converter[shape["group_id"]], # label Id
|
||||||
for x, y in shape["points"]:
|
segments = [
|
||||||
train_label.append(str(x / label["imageWidth"]))
|
Segment(
|
||||||
train_label.append(str(y / label["imageHeight"]))
|
x=x / label["imageWidth"], # shapes의 points 갯수만큼 x, y 반복
|
||||||
train_label_txt.write(" ".join(train_label)+"\n")
|
y=y / label["imageHeight"]
|
||||||
|
) for x, y in shape["points"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
train_label_txt.write(segmentation_label.to_string()+"\n")
|
||||||
|
|
||||||
def join_path(path, *paths):
|
def join_path(path, *paths):
|
||||||
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
|
"""os.path.join()과 같은 기능, os import 하기 싫어서 만듦"""
|
||||||
|
Loading…
Reference in New Issue
Block a user