From 6b40135215e200427b714882ec589bde89b09f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Wed, 25 Sep 2024 02:01:22 +0900 Subject: [PATCH 1/2] =?UTF-8?q?Refactor:=20=EB=AA=A8=EB=8D=B8=20=ED=95=99?= =?UTF-8?q?=EC=8A=B5=20=EB=A6=AC=ED=8C=A9=ED=86=A0=EB=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model/controller/AiModelController.java | 22 +++--- .../model/entity/dto/ModelTrainRequest.java | 35 +++++++++ .../domain/model/service/AiModelService.java | 74 +++++++++---------- .../repository/ProgressCacheRepository.java | 38 ++++++++-- .../progress/service/ProgressService.java | 25 ++++++- .../worlabel/domain/project/dto/AiDto.java | 63 ++++++++++++---- .../project/service/ProjectService.java | 2 +- 7 files changed, 181 insertions(+), 78 deletions(-) create mode 100644 backend/src/main/java/com/worlabel/domain/model/entity/dto/ModelTrainRequest.java diff --git a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java index 1352994..e11dacd 100644 --- a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java +++ b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java @@ -3,6 +3,7 @@ package com.worlabel.domain.model.controller; import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; import com.worlabel.domain.model.entity.dto.AiModelRequest; import com.worlabel.domain.model.entity.dto.AiModelResponse; +import com.worlabel.domain.model.entity.dto.ModelTrainRequest; import com.worlabel.domain.model.service.AiModelService; import com.worlabel.domain.project.entity.dto.ProjectRequest; import com.worlabel.global.annotation.CurrentUser; @@ -32,7 +33,7 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @GetMapping("/projects/{project_id}/models") public List getModelList( - @PathVariable("project_id") final Integer projectId) { + @PathVariable("project_id") final Integer projectId) { return aiModelService.getModelList(projectId); } @@ -41,7 +42,7 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @GetMapping("/models/{model_id}/categories") public List getCategories( - @PathVariable("model_id") final Integer modelId) { + @PathVariable("model_id") final Integer modelId) { return aiModelService.getCategories(modelId); } @@ -50,8 +51,8 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @PostMapping("/projects/{project_id}/models") public void addModel( - @PathVariable("project_id") final Integer projectId, - @Valid @RequestBody final AiModelRequest aiModelRequest) { + @PathVariable("project_id") final Integer projectId, + @Valid @RequestBody final AiModelRequest aiModelRequest) { aiModelService.addModel(projectId, aiModelRequest); } @@ -60,9 +61,9 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @PutMapping("/projects/{project_id}/models/{model_id}") public void renameModel( - @PathVariable("project_id") final Integer projectId, - @PathVariable("model_id") final Integer modelId, - @Valid @RequestBody final AiModelRequest aiModelRequest) { + @PathVariable("project_id") final Integer projectId, + @PathVariable("model_id") final Integer modelId, + @Valid @RequestBody final AiModelRequest aiModelRequest) { aiModelService.renameModel(projectId, modelId, aiModelRequest); } @@ -71,8 +72,9 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @PostMapping("/projects/{project_id}/train") public void trainModel( - @PathVariable("project_id") final Integer projectId, - @RequestBody final Integer modelId) { - aiModelService.train(projectId, modelId); + @PathVariable("project_id") final Integer projectId, + @RequestBody final ModelTrainRequest trainRequest) { + log.debug("모델 학습 요청 {}", trainRequest); + aiModelService.train(projectId, trainRequest); } } diff --git a/backend/src/main/java/com/worlabel/domain/model/entity/dto/ModelTrainRequest.java b/backend/src/main/java/com/worlabel/domain/model/entity/dto/ModelTrainRequest.java new file mode 100644 index 0000000..fea99f5 --- /dev/null +++ b/backend/src/main/java/com/worlabel/domain/model/entity/dto/ModelTrainRequest.java @@ -0,0 +1,35 @@ +package com.worlabel.domain.model.entity.dto; + +import com.worlabel.domain.result.entity.Optimizer; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotEmpty; +import lombok.*; + +@Getter +@AllArgsConstructor +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@Schema(name = "모델 훈련 요청 dto", description = "모델 훈련 요청 DTO") +public class ModelTrainRequest { + + @Schema(description = "모델 ID", example = "1") + @NotEmpty(message = "아이디를 입력하세요") + private Integer modelId; + + @Schema(description = "ratio", example = "Default = 0.8") + private double ratio; + + @Schema(description = "epochs", example = "Default = 50") + private int epochs; + + @Schema(description = "batch", example = "Default = -1") + private int batch; + + @Schema(description = "lr0", example = "Default = 0.01") + private double lr0; + + @Schema(description = "lrf", example = "Default = 0.01") + private double lrf; + + @Schema(description = "optimizer", example = "Default = auto") + private Optimizer optimizer; +} diff --git a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java index 47b500c..34e4205 100644 --- a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java +++ b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java @@ -6,19 +6,21 @@ import com.worlabel.domain.image.entity.Image; import com.worlabel.domain.image.entity.LabelStatus; import com.worlabel.domain.image.repository.ImageRepository; import com.worlabel.domain.labelcategory.entity.LabelCategory; +import com.worlabel.domain.labelcategory.entity.ProjectCategory; import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse; import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository; import com.worlabel.domain.model.entity.AiModel; -import com.worlabel.domain.model.entity.dto.AiModelRequest; -import com.worlabel.domain.model.entity.dto.AiModelResponse; -import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse; -import com.worlabel.domain.model.entity.dto.DefaultResponse; +import com.worlabel.domain.model.entity.dto.*; import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.participant.entity.PrivilegeType; +import com.worlabel.domain.progress.service.ProgressService; import com.worlabel.domain.project.dto.AiDto; +import com.worlabel.domain.project.dto.AiDto.TrainDataInfo; +import com.worlabel.domain.project.dto.AiDto.TrainRequest; import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.repository.ProjectRepository; +import com.worlabel.domain.project.service.ProjectService; import com.worlabel.global.annotation.CheckPrivilege; import com.worlabel.global.cache.CacheKey; import com.worlabel.global.exception.CustomException; @@ -34,7 +36,10 @@ import java.lang.reflect.Type; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; @Slf4j @Service @@ -42,13 +47,15 @@ import java.util.List; @RequiredArgsConstructor public class AiModelService { + private final LabelCategoryRepository labelCategoryRepository; + private final RedisTemplate redisTemplate; private final AiModelRepository aiModelRepository; private final ProjectRepository projectRepository; - private final LabelCategoryRepository labelCategoryRepository; - private final ImageRepository imageRepository; private final AiRequestService aiRequestService; - private final RedisTemplate redisTemplate; + private final ImageRepository imageRepository; + private final ProjectService projectService; private final Gson gson; + private final ProgressService progressService; // @PostConstruct public void loadDefaultModel() { @@ -127,56 +134,41 @@ public class AiModelService { } @CheckPrivilege(PrivilegeType.EDITOR) - public void train(final Integer projectId, final Integer modelId) { - trainProgressCheck(projectId); + public void train(final Integer projectId, final ModelTrainRequest trainRequest) { +// progressService.trainProgressCheck(projectId); // FastAPI 서버로 학습 요청을 전송 Project project = getProject(projectId); - AiModel model = getModel(modelId); - List labelCategories = labelCategoryRepository.findAllByModelId(modelId); - List categories = labelCategories.stream() - .map(LabelCategory::getAiCategoryId).toList(); + AiModel model = getModel(trainRequest.getModelId()); + + Map labelMap = project.getCategoryList().stream() + .collect(Collectors.toMap( + category -> category.getLabelCategory().getId(), + ProjectCategory::getId + )); List images = imageRepository.findImagesByProjectId(projectId); - - List data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED) - .map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath())) + List data = images.stream() + .filter(image -> image.getStatus() == LabelStatus.COMPLETED) + .map(TrainDataInfo::of) .toList(); + TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getModelKey(), labelMap, data, trainRequest); + String endPoint = project.getProjectType().getValue() + "/train"; - AiDto.TrainRequest trainRequest = new AiDto.TrainRequest(); - trainRequest.setProjectId(projectId); - trainRequest.setCategoryId(categories); - trainRequest.setData(data); - trainRequest.setModelKey(model.getModelKey()); - // FastAPI 서버로 POST 요청 전송 - String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response); + String modelKey = aiRequestService.postRequest(endPoint, aiRequest, String.class, response -> response); // 가져온 modelKey -> version 업된 모델 다시 새롭게 저장 - String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")); + String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmm")); + int newVersion = model.getVersion() + 1; + String newName = currentDateTime + String.format("%03d", newVersion); - AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project); + AiModel newModel = AiModel.of(newName, modelKey, newVersion, project); aiModelRepository.save(newModel); } - /** - * Redis 중복 요청 체크 - */ - private void trainProgressCheck(Integer projectId) { - String trainProgressKey = CacheKey.trainProgressKey(); - - // 존재 확인 - Boolean isProjectExist = redisTemplate.opsForSet().isMember(trainProgressKey, projectId); - if (Boolean.TRUE.equals(isProjectExist)) { - throw new CustomException(ErrorCode.AI_IN_PROGRESS); - } - - // 학습 진행 중으로 상태 등록 - redisTemplate.opsForSet().add(trainProgressKey, projectId); - } - /** * Json -> List */ diff --git a/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java b/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java index 1d0fb37..f739389 100644 --- a/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java +++ b/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java @@ -16,22 +16,44 @@ public class ProgressCacheRepository { /** * 현재 오토레이블링중인지 확인하는 메서드 */ - public boolean predictCheck(final int projectId) { - String key = CacheKey.autoLabelingProgressKey(); - Boolean isProgress = redisTemplate.opsForSet().isMember(key, projectId); + public boolean predictProgressCheck(final int projectId) { + Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.autoLabelingProgressKey(), projectId); return Boolean.TRUE.equals(isProgress); } /** - * 학습 진행 중 등록 메서드 + * 오토레이블링 진행 중 등록 메서드 */ public void registerPredictProgress(final int projectId) { - String key = CacheKey.autoLabelingProgressKey(); - redisTemplate.opsForSet().add(key, projectId); + redisTemplate.opsForSet().add(CacheKey.autoLabelingProgressKey(), projectId); } + /** + * 오토레이블링 진행 제거 메서드 + */ public void removePredictProgress(final int projectId) { - String key = CacheKey.autoLabelingProgressKey(); - redisTemplate.opsForSet().remove(key, projectId); + redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), projectId); + } + + /** + * 학습 진행 확인 메서드 + */ + public boolean trainProgressCheck(final int projectId) { + Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.trainProgressKey(), projectId); + return Boolean.TRUE.equals(isProgress); + } + + /** + * 학습 진행 등록 메서드 + */ + public void registerTrainProgress(final int projectId) { + redisTemplate.opsForSet().add(CacheKey.trainProgressKey(), projectId); + } + + /** + * 학습 진행 제거 메서드 + */ + public void removeTrainProgress(final int projectId) { + redisTemplate.opsForSet().remove(CacheKey.trainProgressKey(), projectId); } } diff --git a/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java b/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java index a5b13a8..49fc227 100644 --- a/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java +++ b/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java @@ -14,14 +14,31 @@ public class ProgressService { private final ProgressCacheRepository progressCacheRepository; - public void predictCheck(final int projectId){ - if(progressCacheRepository.predictCheck(projectId)){ -// throw new CustomException(ErrorCode.AI_IN_PROGRESS); - progressCacheRepository.removePredictProgress(projectId); + public void predictProgressCheck(final int projectId){ + if(progressCacheRepository.predictProgressCheck(projectId)){ + throw new CustomException(ErrorCode.AI_IN_PROGRESS); } } public void registerPredictProgress(final int projectId){ progressCacheRepository.registerPredictProgress(projectId); } + + public void removePredictProgress(final int projectId){ + progressCacheRepository.removePredictProgress(projectId); + } + + public void trainProgressCheck(final int projectId){ + if(progressCacheRepository.trainProgressCheck(projectId)){ + throw new CustomException(ErrorCode.AI_IN_PROGRESS); + } + } + + public void registerTrainProgress(final int projectId){ + progressCacheRepository.registerTrainProgress(projectId); + } + + public void removeTrainProgress(final int projectId){ + progressCacheRepository.removeTrainProgress(projectId); + } } diff --git a/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java b/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java index aae8ace..278a5fb 100644 --- a/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java +++ b/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java @@ -3,41 +3,76 @@ package com.worlabel.domain.project.dto; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.gson.annotations.SerializedName; import com.worlabel.domain.image.entity.Image; +import com.worlabel.domain.model.entity.dto.ModelTrainRequest; +import com.worlabel.domain.result.entity.Optimizer; import lombok.*; import java.util.HashMap; import java.util.List; +import java.util.Map; public class AiDto { - @Data + @Getter + @AllArgsConstructor(access = AccessLevel.PRIVATE) + @NoArgsConstructor(access = AccessLevel.PRIVATE) public static class TrainDataInfo { + + @JsonProperty("image_url") private String imagePath; + + @JsonProperty("data_url") private String dataPath; - public TrainDataInfo(String imagePath, String dataPath) { - this.imagePath = imagePath; - this.dataPath = dataPath; + public static TrainDataInfo of(Image image) { + return new TrainDataInfo(image.getImagePath(), image.getDataPath()); } } - @Data + @Getter + @AllArgsConstructor(access = AccessLevel.PRIVATE) + @NoArgsConstructor(access = AccessLevel.PRIVATE) public static class TrainRequest { + @JsonProperty("project_id") private int projectId; - @JsonProperty("category_id") - private List categoryId; + @JsonProperty("model_key") + private String modelKey; + + @JsonProperty("label_map") + private Map labelMap; @JsonProperty("data") private List data; - @JsonProperty("model_key") - private String modelKey; -// private int seed; // Optional -// private float ratio; // Default = 0.8 -// private int epochs; // Default = 50 -// private float batch; // Default = -1 + private double ratio; // Default = 0.8 + + private int epochs; // Default = 50 + + private double batch; // Default = -1 + + private double lr0; + + private double lrf; + + private Optimizer optimizer; + + public static TrainRequest of(final Integer projectId, final String modelKey, final Map labelMap, final List data, final ModelTrainRequest trainRequest) { + TrainRequest request = new TrainRequest(); + request.projectId = projectId; + request.modelKey = modelKey; + request.labelMap = labelMap; + request.data = data; + request.ratio = request.getRatio(); + request.epochs = trainRequest.getEpochs(); + request.batch = trainRequest.getBatch(); + request.lr0 = trainRequest.getLr0(); + request.lrf = trainRequest.getLrf(); + request.optimizer = trainRequest.getOptimizer(); + + return request; + } } @Getter @@ -89,7 +124,7 @@ public class AiDto { @AllArgsConstructor(access = AccessLevel.PRIVATE) @Getter @ToString - public static class AutoLabelingResult{ + public static class AutoLabelingResult { @SerializedName("image_id") private Long imageId; diff --git a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java index 8d6b56d..3b67b8e 100644 --- a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java +++ b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java @@ -162,7 +162,7 @@ public class ProjectService { */ @CheckPrivilege(PrivilegeType.EDITOR) public void autoLabeling(final Integer projectId, final AutoModelRequest request) { - progressService.predictCheck(projectId); +// progressService.predictCheck(projectId); Project project = getProject(projectId); String endPoint = project.getProjectType().getValue() + "/predict"; From d90b2be0c9ae9469f0538c634680a12508154e19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Wed, 25 Sep 2024 02:36:26 +0900 Subject: [PATCH 2/2] =?UTF-8?q?Feat:=20Model=20=EC=A1=B0=ED=9A=8C=20?= =?UTF-8?q?=EB=B0=8F=20=EB=8D=94=EB=AF=B8=20=EB=8D=B0=EC=9D=B4=ED=84=B0=20?= =?UTF-8?q?API=20=EC=83=9D=EC=84=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../model/controller/AiModelController.java | 2 + .../domain/model/service/AiModelService.java | 2 + .../repository/ProgressCacheRepository.java | 46 ++++++++++++++++--- .../progress/service/ProgressService.java | 11 +++++ .../report/controller/ReportController.java | 6 +-- .../worlabel/domain/report/entity/Report.java | 17 +++++-- .../report/entity/dto/ReportResponse.java | 19 +++++--- .../domain/report/service/ReportService.java | 44 ++++++++++++++++-- .../com/worlabel/global/cache/CacheKey.java | 4 ++ 9 files changed, 125 insertions(+), 26 deletions(-) diff --git a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java index e11dacd..3fcbfc0 100644 --- a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java +++ b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java @@ -5,6 +5,7 @@ import com.worlabel.domain.model.entity.dto.AiModelRequest; import com.worlabel.domain.model.entity.dto.AiModelResponse; import com.worlabel.domain.model.entity.dto.ModelTrainRequest; import com.worlabel.domain.model.service.AiModelService; +import com.worlabel.domain.progress.service.ProgressService; import com.worlabel.domain.project.entity.dto.ProjectRequest; import com.worlabel.global.annotation.CurrentUser; import com.worlabel.global.config.swagger.SwaggerApiError; @@ -27,6 +28,7 @@ import java.util.List; public class AiModelController { private final AiModelService aiModelService; + private final ProgressService progressService; @Operation(summary = "프로젝트 모델 조회", description = "프로젝트에 있는 모델을 조회합니다.") @SwaggerApiSuccess(description = "프로젝트 멤버를 성공적으로 조회합니다.") diff --git a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java index 34e4205..7d5426b 100644 --- a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java +++ b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java @@ -153,7 +153,9 @@ public class AiModelService { .map(TrainDataInfo::of) .toList(); +// progressService.registerTrainProgress(projectId); TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getModelKey(), labelMap, data, trainRequest); +// progressService.removeTrainProgress(projectId); String endPoint = project.getProjectType().getValue() + "/train"; diff --git a/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java b/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java index f739389..94c28ab 100644 --- a/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java +++ b/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java @@ -1,23 +1,29 @@ package com.worlabel.domain.progress.repository; +import com.google.gson.Gson; +import com.worlabel.domain.report.entity.dto.ReportResponse; import com.worlabel.global.cache.CacheKey; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.stereotype.Repository; +import java.util.List; +import java.util.stream.Collectors; + @Slf4j @Repository @RequiredArgsConstructor public class ProgressCacheRepository { - private final RedisTemplate redisTemplate; + private final RedisTemplate redisTemplate; + private final Gson gson; /** * 현재 오토레이블링중인지 확인하는 메서드 */ public boolean predictProgressCheck(final int projectId) { - Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.autoLabelingProgressKey(), projectId); + Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId)); return Boolean.TRUE.equals(isProgress); } @@ -25,21 +31,21 @@ public class ProgressCacheRepository { * 오토레이블링 진행 중 등록 메서드 */ public void registerPredictProgress(final int projectId) { - redisTemplate.opsForSet().add(CacheKey.autoLabelingProgressKey(), projectId); + redisTemplate.opsForSet().add(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId)); } /** * 오토레이블링 진행 제거 메서드 */ public void removePredictProgress(final int projectId) { - redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), projectId); + redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId)); } /** * 학습 진행 확인 메서드 */ public boolean trainProgressCheck(final int projectId) { - Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.trainProgressKey(), projectId); + Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.trainProgressKey(), String.valueOf(projectId)); return Boolean.TRUE.equals(isProgress); } @@ -47,13 +53,39 @@ public class ProgressCacheRepository { * 학습 진행 등록 메서드 */ public void registerTrainProgress(final int projectId) { - redisTemplate.opsForSet().add(CacheKey.trainProgressKey(), projectId); + redisTemplate.opsForSet().add(CacheKey.trainProgressKey(), String.valueOf(projectId)); } /** * 학습 진행 제거 메서드 */ public void removeTrainProgress(final int projectId) { - redisTemplate.opsForSet().remove(CacheKey.trainProgressKey(), projectId); + redisTemplate.opsForSet().remove(CacheKey.trainProgressKey(), String.valueOf(projectId)); + } + + /** + * 진행 상황을 Redis에 추가 + */ + public void addProgressModel(final int modelId,final String data){ + ReportResponse reportResponse = convert(data); + redisTemplate.opsForList().rightPush(CacheKey.progressStatusKey(modelId), gson.toJson(reportResponse)); + } + + public List getProgressModel(final int modelId) { + // 저장된 걸 주어진 응답에 맞추어 리턴 + String key = CacheKey.progressStatusKey(modelId); + List progressList = redisTemplate.opsForList().range(key, 0, -1); + + return progressList.stream() + .map(this::convert) + .toList(); + } + + public void clearProgressModel(final int modelId) { + redisTemplate.delete(CacheKey.progressStatusKey(modelId)); + } + + private ReportResponse convert(String data){ + return gson.fromJson(data, ReportResponse.class); } } diff --git a/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java b/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java index 49fc227..6ef7e1b 100644 --- a/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java +++ b/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java @@ -1,12 +1,15 @@ package com.worlabel.domain.progress.service; import com.worlabel.domain.progress.repository.ProgressCacheRepository; +import com.worlabel.domain.report.entity.dto.ReportResponse; import com.worlabel.global.exception.CustomException; import com.worlabel.global.exception.ErrorCode; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; +import java.util.List; + @Slf4j @Service @RequiredArgsConstructor @@ -34,6 +37,10 @@ public class ProgressService { } } + public boolean isProgressTrain(final int projectId){ + return progressCacheRepository.trainProgressCheck(projectId); + } + public void registerTrainProgress(final int projectId){ progressCacheRepository.registerTrainProgress(projectId); } @@ -41,4 +48,8 @@ public class ProgressService { public void removeTrainProgress(final int projectId){ progressCacheRepository.removeTrainProgress(projectId); } + + public List getProgressResponse(final int modelId) { + return progressCacheRepository.getProgressModel(modelId); + } } diff --git a/backend/src/main/java/com/worlabel/domain/report/controller/ReportController.java b/backend/src/main/java/com/worlabel/domain/report/controller/ReportController.java index 8652b37..8ae0da5 100644 --- a/backend/src/main/java/com/worlabel/domain/report/controller/ReportController.java +++ b/backend/src/main/java/com/worlabel/domain/report/controller/ReportController.java @@ -11,14 +11,14 @@ import org.springframework.web.bind.annotation.RestController; import java.util.List; @RestController -@RequestMapping("/api/reports") +@RequestMapping("/api/projects/{project_id}/reports") @RequiredArgsConstructor public class ReportController { private final ReportService reportService; @GetMapping("/model/{model_id}") - public List getReportsByModelId(@PathVariable("model_id") final Integer modelId) { - return reportService.getReportsByModelId(modelId); + public List getReportsByModelId(@PathVariable("model_id") final Integer modelId, @PathVariable("project_id") final Integer projectId) { + return reportService.getReportsByModelId(projectId,modelId); } } \ No newline at end of file diff --git a/backend/src/main/java/com/worlabel/domain/report/entity/Report.java b/backend/src/main/java/com/worlabel/domain/report/entity/Report.java index 55172c7..b66e920 100644 --- a/backend/src/main/java/com/worlabel/domain/report/entity/Report.java +++ b/backend/src/main/java/com/worlabel/domain/report/entity/Report.java @@ -25,17 +25,18 @@ public class Report extends BaseEntity { @JoinColumn(name = "model_id", nullable = false) private AiModel aiModel; + /** + * 현재 에포크 + */ + @Column(name = "epoch", nullable = false) + private Integer epoch; + /** * 전체 에포크 */ @Column(name = "total_epochs", nullable = false) private Integer totalEpochs; - /** - * 현재 에포크 - */ - @Column(name = "epoch", nullable = false) - private Integer epoch; @Column(name = "box_loss", nullable = false) private double boxLoss; @@ -48,4 +49,10 @@ public class Report extends BaseEntity { @Column(name = "fitness", nullable = false) private double fitness; + + @Column(name = "epoch_time", nullable = false) + private double epochTime; + + @Column(name = "left_second", nullable = false) + private double leftSecond; } diff --git a/backend/src/main/java/com/worlabel/domain/report/entity/dto/ReportResponse.java b/backend/src/main/java/com/worlabel/domain/report/entity/dto/ReportResponse.java index 6ad1d3d..30825c2 100644 --- a/backend/src/main/java/com/worlabel/domain/report/entity/dto/ReportResponse.java +++ b/backend/src/main/java/com/worlabel/domain/report/entity/dto/ReportResponse.java @@ -4,26 +4,33 @@ import com.worlabel.domain.report.entity.Report; import lombok.AccessLevel; import lombok.AllArgsConstructor; import lombok.Getter; +import lombok.NoArgsConstructor; @Getter -@AllArgsConstructor(access = AccessLevel.PRIVATE) +@NoArgsConstructor +@AllArgsConstructor public class ReportResponse { - private Integer id; - private Integer totalEpochs; - private Integer epoch; + private int modelId; + private int totalEpochs; + private int epoch; private double boxLoss; private double clsLoss; private double dflLoss; private double fitness; + private double epochTime; + private double leftSecond; public static ReportResponse from(final Report report) { return new ReportResponse( - report.getId(), + report.getAiModel().getId(), report.getTotalEpochs(), report.getEpoch(), report.getBoxLoss(), report.getClsLoss(), report.getDflLoss(), - report.getFitness()); + report.getFitness(), + report.getEpochTime(), + report.getLeftSecond() + ); } } \ No newline at end of file diff --git a/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java b/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java index d0315ee..f3a8f9e 100644 --- a/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java +++ b/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java @@ -1,5 +1,6 @@ package com.worlabel.domain.report.service; +import com.worlabel.domain.progress.service.ProgressService; import com.worlabel.domain.report.entity.Report; import com.worlabel.domain.report.entity.dto.ReportResponse; import com.worlabel.domain.report.repository.ReportRepository; @@ -7,6 +8,7 @@ import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; +import java.util.ArrayList; import java.util.List; @Service @@ -15,11 +17,43 @@ import java.util.List; public class ReportService { private final ReportRepository reportRepository; + private final ProgressService progressService; - public List getReportsByModelId(final Integer modelId) { - List reports = reportRepository.findByAiModelId(modelId); - return reports.stream() - .map(ReportResponse::from) - .toList(); + public List getReportsByModelId(final Integer projectId, final Integer modelId) { + // 진행중이면 진행중에서 받아오기 + return getDummyList(); +// if(progressService.isProgressTrain(projectId)){ +// return progressService.getProgressResponse(modelId); +// } +// // 작업 완료시에는 RDB +// else{ +// List reports = reportRepository.findByAiModelId(modelId); +// return reports.stream() +// .map(ReportResponse::from) +// .toList(); +// } } + + private List getDummyList() { + List dummyList = new ArrayList<>(); + + // 더미 데이터 15개 생성 + for (int i = 1; i <= 15; i++) { + ReportResponse dummy = new ReportResponse( + i, // modelId + 100, // totalEpochs + i, // epoch + Math.random(), // boxLoss + Math.random(), // clsLoss + Math.random(), // dflLoss + Math.random(), // fitness + Math.random() * 10,// epochTime + Math.random() * 100 // leftSecond + ); + dummyList.add(dummy); + } + + return dummyList; + } + } \ No newline at end of file diff --git a/backend/src/main/java/com/worlabel/global/cache/CacheKey.java b/backend/src/main/java/com/worlabel/global/cache/CacheKey.java index ab1f71c..7b3100b 100644 --- a/backend/src/main/java/com/worlabel/global/cache/CacheKey.java +++ b/backend/src/main/java/com/worlabel/global/cache/CacheKey.java @@ -14,4 +14,8 @@ public class CacheKey { public static String fcmTokenKey(){ return "fcmToken"; } + + public static String progressStatusKey(int modelId) { + return "progress:" + modelId; + } }