Refactor: Train response에 accuracy 추가

This commit is contained in:
김진현 2024-09-27 14:18:59 +09:00
parent 83737e566d
commit 4b6751a00b
4 changed files with 10 additions and 7 deletions

View File

@ -4,7 +4,7 @@ from schemas.predict_request import PredictRequest
from schemas.train_request import TrainRequest, TrainDataInfo
from schemas.predict_response import PredictResponse, LabelData
from schemas.train_report_data import ReportData
from schemas.train_response import ClassificationTrainResponse
from schemas.train_response import TrainResponse
from services.load_model import load_classification_model
from services.create_model import save_model
from utils.file_utils import get_dataset_root_path, process_directories_in_cls, process_image_and_label_in_cls, join_path
@ -105,11 +105,16 @@ async def classification_train(request: TrainRequest):
result = results.results_dict
response = ClassificationTrainResponse(
response = TrainResponse(
modelKey=model_key,
precision= result["accuracy_top1"],
precision= 0,
recall= 0,
mAP50= 0,
mAP5095= 0,
accuracy=result["accuracy_top1"],
fitness= result["fitness"]
)
send_slack_message(f"train 성공{response}", status="success")
return response

View File

@ -144,6 +144,7 @@ async def detection_train(request: TrainRequest):
recall= result["metrics/recall(B)"],
mAP50= result["metrics/mAP50(B)"],
mAP5095= result["metrics/mAP50-95(B)"],
accuracy=0,
fitness= result["fitness"]
)
send_slack_message(f"train 성공{response}", status="success")

View File

@ -114,6 +114,7 @@ async def segmentation_train(request: TrainRequest):
recall= result["metrics/recall(M)"],
mAP50= result["metrics/mAP50(M)"],
mAP5095= result["metrics/mAP50-95(M)"],
accuracy = 0,
fitness= result["fitness"]
)
send_slack_message(f"train 성공{response}", status="success")

View File

@ -6,9 +6,5 @@ class TrainResponse(BaseModel):
recall: float
mAP50: float
mAP5095: float
fitness: float
class ClassificationTrainResponse(BaseModel):
modelKey: str
accuracy: float
fitness: float