From 6d6cd8b134806fdbacc90fda83f241a601813686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Mon, 30 Sep 2024 12:18:48 +0900 Subject: [PATCH] =?UTF-8?q?Fix:=20=EB=AA=A8=EB=8D=B8=20=ED=95=99=EC=8A=B5?= =?UTF-8?q?=20=EC=97=90=EB=9F=AC=20=ED=95=B4=EA=B2=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/model/service/AiModelService.java | 8 ++-- .../repository/ProgressCacheRepository.java | 44 +++++++++++-------- .../progress/service/ProgressService.java | 24 +++++----- .../domain/report/service/ReportService.java | 9 ---- 4 files changed, 42 insertions(+), 43 deletions(-) diff --git a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java index e6724a1..7b13f5b 100644 --- a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java +++ b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java @@ -88,13 +88,15 @@ public class AiModelService { @CheckPrivilege(PrivilegeType.EDITOR) public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) { - progressService.trainProgressCheck(projectId, trainRequest.getModelId()); + progressService.trainProgressCheck(projectId); try { - progressService.registerTrainProgress(projectId, trainRequest.getModelId()); + // 학습 상황 등록 + progressService.registerTrainProcess(projectId, trainRequest.getModelId()); Project project = getProject(projectId); AiModel model = getModel(trainRequest.getModelId()); + TrainRequest aiRequest = getTrainRequest(trainRequest, project, model); // FastAPI 서버로 POST 요청 전송 @@ -121,7 +123,7 @@ public class AiModelService { // 알람 전송 alarmService.save(memberId, Alarm.AlarmType.TRAIN); } finally { - progressService.removeTrainProgress(projectId, trainRequest.getModelId()); + progressService.removeTrainProgress(projectId); } } diff --git a/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java b/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java index de65bd6..3f914c5 100644 --- a/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java +++ b/backend/src/main/java/com/worlabel/domain/progress/repository/ProgressCacheRepository.java @@ -10,7 +10,6 @@ import org.springframework.stereotype.Repository; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; @Slf4j @Repository @@ -42,28 +41,35 @@ public class ProgressCacheRepository { redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId)); } + /** + * 현재 프로젝트 등록 + */ + public void registerTrainProject(final int projectId, final int modelId) { + String key = CacheKey.trainProgressKey(); + redisTemplate.opsForHash().put(key, String.valueOf(projectId), String.valueOf(modelId)); + } + + /** + * 현재 오토레이블링중인지 확인하는 메서드 + */ + public boolean trainProgressCheck(final int projectId) { + String key = CacheKey.trainProgressKey(); + return redisTemplate.opsForHash().hasKey(key, String.valueOf(projectId)); + } + + public void removeTrainProgress(final int projectId){ + String key = CacheKey.trainProgressKey(); + redisTemplate.opsForHash().delete(key, String.valueOf(projectId)); + } + /** * 현재 학습 진행 여부 확인 메서드 (단일 키 사용) */ - public boolean trainProgressCheck(final int projectId, final int modelId) { + public boolean trainModelProgressCheck(final int projectId, final int modelId) { String key = CacheKey.trainKey(projectId, modelId); - return Boolean.TRUE.equals(redisTemplate.hasKey(key)); - } + Long listSize = redisTemplate.opsForList().size(key); - /** - * 학습 진행 등록 메서드 (단일 키 사용) - */ - public void registerTrainProgress(final int projectId, final int modelId) { - String key = CacheKey.trainKey(projectId, modelId); - redisTemplate.opsForValue().set(key, String.valueOf(modelId)); - } - - /** - * 학습 진행 제거 메서드 (단일 키 사용) - */ - public void removeTrainProgress(final int projectId, final int modelId) { - String key = CacheKey.trainKey(projectId, modelId); - redisTemplate.delete(key); + return listSize != null && listSize > 0; } /** @@ -72,7 +78,7 @@ public class ProgressCacheRepository { public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) { String jsonData = gson.toJson(data); String key = CacheKey.trainKey(projectId, modelId); - log.debug("key{} : data {}",key, jsonData); + log.debug("key{} : data {}", key, jsonData); redisTemplate.opsForList().rightPush(key, jsonData); } diff --git a/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java b/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java index bb758e0..e0c1373 100644 --- a/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java +++ b/backend/src/main/java/com/worlabel/domain/progress/service/ProgressService.java @@ -22,7 +22,6 @@ public class ProgressService { throw new CustomException(ErrorCode.AI_IN_PROGRESS, "해당 프로젝트 오토레이블링 진행 중"); } } - public void registerPredictProgress(final int projectId) { progressCacheRepository.registerPredictProgress(projectId); } @@ -31,18 +30,22 @@ public class ProgressService { progressCacheRepository.removePredictProgress(projectId); } - public void trainProgressCheck(final int projectId, final int modelId) { - if (progressCacheRepository.trainProgressCheck(projectId, modelId)) { + public void registerTrainProcess(final int projectId, final int modelId) { + progressCacheRepository.registerTrainProject(projectId, modelId); + } + + public void removeTrainProgress(final int projectId){ + progressCacheRepository.removeTrainProgress(projectId); + } + + public void trainProgressCheck(final int projectId) { + if (progressCacheRepository.trainProgressCheck(projectId)) { throw new CustomException(ErrorCode.AI_IN_PROGRESS); } } - public void registerTrainProgress(final int projectId, final int modelId) { - progressCacheRepository.registerTrainProgress(projectId, modelId); - } - public boolean isProgressTrain(final int projectId, final int modelId) { - return progressCacheRepository.trainProgressCheck(projectId, modelId); + return progressCacheRepository.trainModelProgressCheck(projectId, modelId); } public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) { @@ -53,11 +56,8 @@ public class ProgressService { return progressCacheRepository.getProgressModelByProjectId(projectId); } - public void removeTrainProgress(final int projectId, final int modelId) { - progressCacheRepository.removeTrainProgress(projectId, modelId); - } - public List getProgressResponse(final int projectId, final int modelId) { return progressCacheRepository.getProgressModel(projectId, modelId); } + } diff --git a/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java b/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java index 7bf4f02..eee69d1 100644 --- a/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java +++ b/backend/src/main/java/com/worlabel/domain/report/service/ReportService.java @@ -33,21 +33,13 @@ public class ReportService { public void addReportByModelId(final Integer projectId, final Integer modelId, final ReportRequest reportRequest) { ReportResponse reportResponse = ReportResponse.of(reportRequest, modelId); - - boolean result = progressService.isProgressTrain(projectId, modelId); -// log.debug("result {}" ,result); -// if (result) { // 이미 존재하면 뒤에 추가 -// } progressService.addProgressModel(projectId, modelId, reportResponse); -// progressService.registerTrainProgress(projectId, modelId); - } public List getReportsProgressByModelId(final Integer projectId, final Integer modelId) { if (progressService.isProgressTrain(projectId, modelId)) { return progressService.getProgressResponse(projectId, modelId); } - return List.of(); } @@ -70,6 +62,5 @@ public class ReportService { reports.add(report); } reportRepository.saveAll(reports); - progressService.removeTrainProgress(projectId, modelId); } } \ No newline at end of file