Merge branch 'ai/refactor/predict' into 'ai/develop'

Refactor: 학습 레이블 data_url로 수정

See merge request s11-s-project/S11P21S002!175
This commit is contained in:
김용수 2024-09-25 17:21:30 +09:00
commit 03d4a87907
3 changed files with 5 additions and 15 deletions

View File

@ -10,7 +10,7 @@ from services.create_model import save_model
from utils.dataset_utils import split_data from utils.dataset_utils import split_data
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
from utils.slackMessage import send_slack_message from utils.slackMessage import send_slack_message
import asyncio import asyncio, httpx
router = APIRouter() router = APIRouter()
@ -18,7 +18,7 @@ router = APIRouter()
@router.post("/predict") @router.post("/predict")
async def detection_predict(request: PredictRequest): async def detection_predict(request: PredictRequest):
# send_slack_message(f"predict 요청{request}", status="success") send_slack_message(f"predict 요청{request}", status="success")
# 모델 로드 # 모델 로드
model = get_model(request) model = get_model(request)

View File

@ -4,7 +4,7 @@ from schemas.predict_response import LabelData
class TrainDataInfo(BaseModel): class TrainDataInfo(BaseModel):
image_url: str image_url: str
label: str data_url: str
class TrainRequest(BaseModel): class TrainRequest(BaseModel):
project_id: int project_id: int

View File

@ -55,8 +55,8 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat
# 레이블 파일 경로 # 레이블 파일 경로
label_path = os.path.join(dataset_root_path, child_path, f"{img_title}.txt") label_path = os.path.join(dataset_root_path, child_path, f"{img_title}.txt")
# 레이블 역직렬화 # 레이블 객체 불러오기
label = json_to_object(data.label) label = json.loads(urllib.request.urlopen(data.data_url).read())
# 레이블 -> 학습용 레이블 데이터 파싱 후 생성 # 레이블 -> 학습용 레이블 데이터 파싱 후 생성
create_detection_train_label(label, label_path, label_map) create_detection_train_label(label, label_path, label_map)
@ -104,13 +104,3 @@ def get_file_name(path):
if not os.path.exists(path): if not os.path.exists(path):
raise FileNotFoundError() raise FileNotFoundError()
return os.path.basename(path) return os.path.basename(path)
def json_to_object(json_string):
try:
# JSON 문자열을 Python 객체로 변환
python_object = json.loads(json_string)
return python_object
except json.JSONDecodeError as e:
raise json.JSONDecodeError("json_decode_error:"+str(e))
except Exception as e:
raise Exception("exception at json_to_object:"+str(e))