Merge branch 'be/refactor/image' into 'be/develop'

Refactor: 이미지 상태 변경 및 모델 상태 변경

See merge request s11-s-project/S11P21S002!203
This commit is contained in:
김진현 2024-09-26 21:28:23 +09:00
commit cc37d3048f
8 changed files with 53 additions and 41 deletions

View File

@ -134,7 +134,7 @@ public class ImageService {
Image image = imageRepository.findById(imageId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
String dataPath = image.getDataPath();
image.updateStatus(LabelStatus.COMPLETED);
image.updateStatus(LabelStatus.SAVE);
imageRepository.save(image);
s3UploadService.uploadJson(labelRequest.getData(), dataPath);
}

View File

@ -20,7 +20,10 @@ public class AiModelResponse {
@Schema(description = "Default 모델 여부", example = "true")
private Boolean isDefault;
public static AiModelResponse of(final AiModel aiModel) {
return new AiModelResponse(aiModel.getId(), aiModel.getName(), aiModel.getProject() == null);
@Schema(description = "모델 학습 여부", example = "true")
private Boolean isTrain;
public static AiModelResponse of(final AiModel aiModel, final int progressModelId) {
return new AiModelResponse(aiModel.getId(), aiModel.getName(), aiModel.getProject() == null, aiModel.getId() == progressModelId);
}
}

View File

@ -104,9 +104,10 @@ public class AiModelService {
@Transactional(readOnly = true)
public List<AiModelResponse> getModelList(final Integer projectId) {
int progressModelId = progressService.getProgressModelByProjectId(projectId);
return aiModelRepository.findAllByProjectId(projectId)
.stream()
.map(AiModelResponse::of)
.map(o -> AiModelResponse.of(o, progressModelId))
.toList();
}
@ -140,8 +141,6 @@ public class AiModelService {
@CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
// progressService.trainProgressCheck(projectId);
// FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId);
AiModel model = getModel(trainRequest.getModelId());
@ -158,9 +157,7 @@ public class AiModelService {
.map(TrainDataInfo::of)
.toList();
// progressService.registerTrainProgress(projectId);
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getId(), model.getModelKey(), labelMap, data, trainRequest);
// progressService.removeTrainProgress(projectId);
String endPoint = project.getProjectType().getValue() + "/train";

View File

@ -9,6 +9,7 @@ import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@ -69,7 +70,10 @@ public class ProgressCacheRepository {
* 진행 상황을 Redis에 추가 (리스트 형식 유지)
*/
public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) {
redisTemplate.opsForList().rightPush(CacheKey.trainKey(projectId, modelId), gson.toJson(data));
String jsonData = gson.toJson(data);
String key = CacheKey.trainKey(projectId, modelId);
log.debug("key{} : data {}",key, jsonData);
redisTemplate.opsForList().rightPush(key, jsonData);
}
public List<ReportResponse> getProgressModel(final int projectId, final int modelId) {
@ -82,6 +86,22 @@ public class ProgressCacheRepository {
.toList();
}
public int getProgressModelByProjectId(final int projectId) {
String key = CacheKey.trainModelKey(projectId);
log.debug("key : {}", key);
// train:<projectId>:* 형태의 번째 키를 가져옴
Set<String> keys = redisTemplate.keys(key);
// 모델 ID를 추출하여 반환
if (keys != null && !keys.isEmpty()) {
String firstKey = keys.iterator().next(); // 번째 가져오기
String[] parts = firstKey.split(":");
return Integer.parseInt(parts[2]); // modelId가 번째 위치
}
return 0;
}
public void clearProgressModel(final int modelId) {
redisTemplate.delete(CacheKey.progressStatusKey(modelId));
}
@ -89,4 +109,6 @@ public class ProgressCacheRepository {
private ReportResponse convert(String data) {
return gson.fromJson(data, ReportResponse.class);
}
}

View File

@ -45,10 +45,14 @@ public class ProgressService {
return progressCacheRepository.trainProgressCheck(projectId, modelId);
}
public void registerTrainProgress(final int projectId, final int modelId, final ReportResponse data) {
public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) {
progressCacheRepository.addProgressModel(projectId, modelId, data);
}
public int getProgressModelByProjectId(final int projectId) {
return progressCacheRepository.getProgressModelByProjectId(projectId);
}
public void removeTrainProgress(final int projectId, final int modelId) {
progressCacheRepository.removeTrainProgress(projectId, modelId);
}

View File

@ -163,8 +163,6 @@ public class ProjectService {
*/
@CheckPrivilege(PrivilegeType.EDITOR)
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
// progressService.predictCheck(projectId);
Project project = getProject(projectId);
String endPoint = project.getProjectType().getValue() + "/predict";
@ -186,15 +184,15 @@ public class ProjectService {
}
// TODO: 트랜잭션 설정
// TODO: 어떤 상황까지 덮어쓸껀지 물어보기
@Transactional
public void saveAutoLabelList(final List<AutoLabelingResult> resultList) {
for(AutoLabelingResult result: resultList) {
Image image = getImage(result.getImageId());
if(image.getStatus() == LabelStatus.SAVE) continue;
if(image.getStatus() == LabelStatus.SAVE || image.getStatus() == LabelStatus.IN_PROGRESS) continue;
String dataPath = image.getDataPath();
s3UploadService.uploadJson(result.getData(), dataPath);
image.updateStatus(LabelStatus.IN_PROGRESS);
imageRepository.save(image);
}
}

View File

@ -9,12 +9,14 @@ import com.worlabel.domain.report.entity.dto.ReportRequest;
import com.worlabel.domain.report.entity.dto.ReportResponse;
import com.worlabel.domain.report.repository.ReportRepository;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.List;
@Slf4j
@Service
@Transactional
@RequiredArgsConstructor
@ -29,36 +31,16 @@ public class ReportService {
.toList();
}
private List<ReportResponse> getDummyList() {
List<ReportResponse> dummyList = new ArrayList<>();
// 더미 데이터 15개 생성
for (int i = 1; i <= 15; i++) {
ReportResponse dummy = new ReportResponse(
i, // modelId
100, // totalEpochs
i, // epoch
Math.random(), // boxLoss
Math.random(), // clsLoss
Math.random(), // dflLoss
Math.random(), // fitness
Math.random() * 10,// epochTime
Math.random() * 100 // leftSecond
);
dummyList.add(dummy);
}
return dummyList;
}
public void addReportByModelId(final Integer projectId, final Integer modelId, final ReportRequest reportRequest) {
ReportResponse reportResponse = ReportResponse.of(reportRequest, modelId);
if (progressService.isProgressTrain(projectId, modelId)) { // 이미 존재하면 뒤에 추가
progressService.registerTrainProgress(projectId, modelId, reportResponse);
} else { // 새로추가
progressService.registerTrainProgress(projectId, modelId, reportResponse);
}
boolean result = progressService.isProgressTrain(projectId, modelId);
// log.debug("result {}" ,result);
// if (result) { // 이미 존재하면 뒤에 추가
// }
progressService.addProgressModel(projectId, modelId, reportResponse);
// progressService.registerTrainProgress(projectId, modelId);
}
public List<ReportResponse> getReportsProgressByModelId(final Integer projectId, final Integer modelId) {

View File

@ -25,6 +25,10 @@ public class CacheKey {
return "train:" + projectId + ":" + modelId;
}
public static String trainModelKey(int projectId) {
return "train:" + projectId + ":*";
}
public static String alarmIdKey(){
return "alarm:id";
}
@ -36,4 +40,6 @@ public class CacheKey {
public static String alarmMemberAllKey(int memberId) {
return "member:" + memberId + ":alarm:*";
}
}