Merge branch 'ai/refactor/predict' into 'ai/develop'
Refactor: 학습 레이블 data_url로 수정 See merge request s11-s-project/S11P21S002!175
This commit is contained in:
commit
03d4a87907
@ -10,7 +10,7 @@ from services.create_model import save_model
|
|||||||
from utils.dataset_utils import split_data
|
from utils.dataset_utils import split_data
|
||||||
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
|
from utils.file_utils import get_dataset_root_path, process_directories, process_image_and_label, join_path
|
||||||
from utils.slackMessage import send_slack_message
|
from utils.slackMessage import send_slack_message
|
||||||
import asyncio
|
import asyncio, httpx
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@ -18,7 +18,7 @@ router = APIRouter()
|
|||||||
@router.post("/predict")
|
@router.post("/predict")
|
||||||
async def detection_predict(request: PredictRequest):
|
async def detection_predict(request: PredictRequest):
|
||||||
|
|
||||||
# send_slack_message(f"predict 요청{request}", status="success")
|
send_slack_message(f"predict 요청{request}", status="success")
|
||||||
|
|
||||||
# 모델 로드
|
# 모델 로드
|
||||||
model = get_model(request)
|
model = get_model(request)
|
||||||
|
@ -4,7 +4,7 @@ from schemas.predict_response import LabelData
|
|||||||
|
|
||||||
class TrainDataInfo(BaseModel):
|
class TrainDataInfo(BaseModel):
|
||||||
image_url: str
|
image_url: str
|
||||||
label: str
|
data_url: str
|
||||||
|
|
||||||
class TrainRequest(BaseModel):
|
class TrainRequest(BaseModel):
|
||||||
project_id: int
|
project_id: int
|
||||||
|
@ -55,8 +55,8 @@ def process_image_and_label(data:TrainDataInfo, dataset_root_path:str, child_pat
|
|||||||
# 레이블 파일 경로
|
# 레이블 파일 경로
|
||||||
label_path = os.path.join(dataset_root_path, child_path, f"{img_title}.txt")
|
label_path = os.path.join(dataset_root_path, child_path, f"{img_title}.txt")
|
||||||
|
|
||||||
# 레이블 역직렬화
|
# 레이블 객체 불러오기
|
||||||
label = json_to_object(data.label)
|
label = json.loads(urllib.request.urlopen(data.data_url).read())
|
||||||
|
|
||||||
# 레이블 -> 학습용 레이블 데이터 파싱 후 생성
|
# 레이블 -> 학습용 레이블 데이터 파싱 후 생성
|
||||||
create_detection_train_label(label, label_path, label_map)
|
create_detection_train_label(label, label_path, label_map)
|
||||||
@ -104,13 +104,3 @@ def get_file_name(path):
|
|||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
return os.path.basename(path)
|
return os.path.basename(path)
|
||||||
|
|
||||||
def json_to_object(json_string):
|
|
||||||
try:
|
|
||||||
# JSON 문자열을 Python 객체로 변환
|
|
||||||
python_object = json.loads(json_string)
|
|
||||||
return python_object
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise json.JSONDecodeError("json_decode_error:"+str(e))
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception("exception at json_to_object:"+str(e))
|
|
Loading…
Reference in New Issue
Block a user