Fix: 모델 학습 에러 해결

This commit is contained in:
김용수 2024-09-30 12:18:48 +09:00
parent eec903e233
commit 6d6cd8b134
4 changed files with 42 additions and 43 deletions

View File

@ -88,13 +88,15 @@ public class AiModelService {
@CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) {
progressService.trainProgressCheck(projectId, trainRequest.getModelId());
progressService.trainProgressCheck(projectId);
try {
progressService.registerTrainProgress(projectId, trainRequest.getModelId());
// 학습 상황 등록
progressService.registerTrainProcess(projectId, trainRequest.getModelId());
Project project = getProject(projectId);
AiModel model = getModel(trainRequest.getModelId());
TrainRequest aiRequest = getTrainRequest(trainRequest, project, model);
// FastAPI 서버로 POST 요청 전송
@ -121,7 +123,7 @@ public class AiModelService {
// 알람 전송
alarmService.save(memberId, Alarm.AlarmType.TRAIN);
} finally {
progressService.removeTrainProgress(projectId, trainRequest.getModelId());
progressService.removeTrainProgress(projectId);
}
}

View File

@ -10,7 +10,6 @@ import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Repository
@ -42,28 +41,35 @@ public class ProgressCacheRepository {
redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
}
/**
* 현재 프로젝트 등록
*/
public void registerTrainProject(final int projectId, final int modelId) {
String key = CacheKey.trainProgressKey();
redisTemplate.opsForHash().put(key, String.valueOf(projectId), String.valueOf(modelId));
}
/**
* 현재 오토레이블링중인지 확인하는 메서드
*/
public boolean trainProgressCheck(final int projectId) {
String key = CacheKey.trainProgressKey();
return redisTemplate.opsForHash().hasKey(key, String.valueOf(projectId));
}
public void removeTrainProgress(final int projectId){
String key = CacheKey.trainProgressKey();
redisTemplate.opsForHash().delete(key, String.valueOf(projectId));
}
/**
* 현재 학습 진행 여부 확인 메서드 (단일 사용)
*/
public boolean trainProgressCheck(final int projectId, final int modelId) {
public boolean trainModelProgressCheck(final int projectId, final int modelId) {
String key = CacheKey.trainKey(projectId, modelId);
return Boolean.TRUE.equals(redisTemplate.hasKey(key));
}
Long listSize = redisTemplate.opsForList().size(key);
/**
* 학습 진행 등록 메서드 (단일 사용)
*/
public void registerTrainProgress(final int projectId, final int modelId) {
String key = CacheKey.trainKey(projectId, modelId);
redisTemplate.opsForValue().set(key, String.valueOf(modelId));
}
/**
* 학습 진행 제거 메서드 (단일 사용)
*/
public void removeTrainProgress(final int projectId, final int modelId) {
String key = CacheKey.trainKey(projectId, modelId);
redisTemplate.delete(key);
return listSize != null && listSize > 0;
}
/**
@ -72,7 +78,7 @@ public class ProgressCacheRepository {
public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) {
String jsonData = gson.toJson(data);
String key = CacheKey.trainKey(projectId, modelId);
log.debug("key{} : data {}",key, jsonData);
log.debug("key{} : data {}", key, jsonData);
redisTemplate.opsForList().rightPush(key, jsonData);
}

View File

@ -22,7 +22,6 @@ public class ProgressService {
throw new CustomException(ErrorCode.AI_IN_PROGRESS, "해당 프로젝트 오토레이블링 진행 중");
}
}
public void registerPredictProgress(final int projectId) {
progressCacheRepository.registerPredictProgress(projectId);
}
@ -31,18 +30,22 @@ public class ProgressService {
progressCacheRepository.removePredictProgress(projectId);
}
public void trainProgressCheck(final int projectId, final int modelId) {
if (progressCacheRepository.trainProgressCheck(projectId, modelId)) {
public void registerTrainProcess(final int projectId, final int modelId) {
progressCacheRepository.registerTrainProject(projectId, modelId);
}
public void removeTrainProgress(final int projectId){
progressCacheRepository.removeTrainProgress(projectId);
}
public void trainProgressCheck(final int projectId) {
if (progressCacheRepository.trainProgressCheck(projectId)) {
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
}
}
public void registerTrainProgress(final int projectId, final int modelId) {
progressCacheRepository.registerTrainProgress(projectId, modelId);
}
public boolean isProgressTrain(final int projectId, final int modelId) {
return progressCacheRepository.trainProgressCheck(projectId, modelId);
return progressCacheRepository.trainModelProgressCheck(projectId, modelId);
}
public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) {
@ -53,11 +56,8 @@ public class ProgressService {
return progressCacheRepository.getProgressModelByProjectId(projectId);
}
public void removeTrainProgress(final int projectId, final int modelId) {
progressCacheRepository.removeTrainProgress(projectId, modelId);
}
public List<ReportResponse> getProgressResponse(final int projectId, final int modelId) {
return progressCacheRepository.getProgressModel(projectId, modelId);
}
}

View File

@ -33,21 +33,13 @@ public class ReportService {
public void addReportByModelId(final Integer projectId, final Integer modelId, final ReportRequest reportRequest) {
ReportResponse reportResponse = ReportResponse.of(reportRequest, modelId);
boolean result = progressService.isProgressTrain(projectId, modelId);
// log.debug("result {}" ,result);
// if (result) { // 이미 존재하면 뒤에 추가
// }
progressService.addProgressModel(projectId, modelId, reportResponse);
// progressService.registerTrainProgress(projectId, modelId);
}
public List<ReportResponse> getReportsProgressByModelId(final Integer projectId, final Integer modelId) {
if (progressService.isProgressTrain(projectId, modelId)) {
return progressService.getProgressResponse(projectId, modelId);
}
return List.of();
}
@ -70,6 +62,5 @@ public class ReportService {
reports.add(report);
}
reportRepository.saveAll(reports);
progressService.removeTrainProgress(projectId, modelId);
}
}