Merge branch 'be/refactor/image' into 'be/develop'
Refactor: 이미지 상태 변경 및 모델 상태 변경 See merge request s11-s-project/S11P21S002!203
This commit is contained in:
commit
cc37d3048f
@ -134,7 +134,7 @@ public class ImageService {
|
||||
Image image = imageRepository.findById(imageId)
|
||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||
String dataPath = image.getDataPath();
|
||||
image.updateStatus(LabelStatus.COMPLETED);
|
||||
image.updateStatus(LabelStatus.SAVE);
|
||||
imageRepository.save(image);
|
||||
s3UploadService.uploadJson(labelRequest.getData(), dataPath);
|
||||
}
|
||||
|
@ -20,7 +20,10 @@ public class AiModelResponse {
|
||||
@Schema(description = "Default 모델 여부", example = "true")
|
||||
private Boolean isDefault;
|
||||
|
||||
public static AiModelResponse of(final AiModel aiModel) {
|
||||
return new AiModelResponse(aiModel.getId(), aiModel.getName(), aiModel.getProject() == null);
|
||||
@Schema(description = "모델 학습 여부", example = "true")
|
||||
private Boolean isTrain;
|
||||
|
||||
public static AiModelResponse of(final AiModel aiModel, final int progressModelId) {
|
||||
return new AiModelResponse(aiModel.getId(), aiModel.getName(), aiModel.getProject() == null, aiModel.getId() == progressModelId);
|
||||
}
|
||||
}
|
||||
|
@ -104,9 +104,10 @@ public class AiModelService {
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public List<AiModelResponse> getModelList(final Integer projectId) {
|
||||
int progressModelId = progressService.getProgressModelByProjectId(projectId);
|
||||
return aiModelRepository.findAllByProjectId(projectId)
|
||||
.stream()
|
||||
.map(AiModelResponse::of)
|
||||
.map(o -> AiModelResponse.of(o, progressModelId))
|
||||
.toList();
|
||||
}
|
||||
|
||||
@ -140,8 +141,6 @@ public class AiModelService {
|
||||
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||
// progressService.trainProgressCheck(projectId);
|
||||
|
||||
// FastAPI 서버로 학습 요청을 전송
|
||||
Project project = getProject(projectId);
|
||||
AiModel model = getModel(trainRequest.getModelId());
|
||||
@ -158,9 +157,7 @@ public class AiModelService {
|
||||
.map(TrainDataInfo::of)
|
||||
.toList();
|
||||
|
||||
// progressService.registerTrainProgress(projectId);
|
||||
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getId(), model.getModelKey(), labelMap, data, trainRequest);
|
||||
// progressService.removeTrainProgress(projectId);
|
||||
|
||||
String endPoint = project.getProjectType().getValue() + "/train";
|
||||
|
||||
|
@ -9,6 +9,7 @@ import org.springframework.data.redis.core.RedisTemplate;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@ -69,7 +70,10 @@ public class ProgressCacheRepository {
|
||||
* 진행 상황을 Redis에 추가 (리스트 형식 유지)
|
||||
*/
|
||||
public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) {
|
||||
redisTemplate.opsForList().rightPush(CacheKey.trainKey(projectId, modelId), gson.toJson(data));
|
||||
String jsonData = gson.toJson(data);
|
||||
String key = CacheKey.trainKey(projectId, modelId);
|
||||
log.debug("key{} : data {}",key, jsonData);
|
||||
redisTemplate.opsForList().rightPush(key, jsonData);
|
||||
}
|
||||
|
||||
public List<ReportResponse> getProgressModel(final int projectId, final int modelId) {
|
||||
@ -82,6 +86,22 @@ public class ProgressCacheRepository {
|
||||
.toList();
|
||||
}
|
||||
|
||||
public int getProgressModelByProjectId(final int projectId) {
|
||||
String key = CacheKey.trainModelKey(projectId);
|
||||
log.debug("key : {}", key);
|
||||
// train:<projectId>:* 형태의 첫 번째 키를 가져옴
|
||||
Set<String> keys = redisTemplate.keys(key);
|
||||
|
||||
// 모델 ID를 추출하여 반환
|
||||
if (keys != null && !keys.isEmpty()) {
|
||||
String firstKey = keys.iterator().next(); // 첫 번째 키 가져오기
|
||||
String[] parts = firstKey.split(":");
|
||||
return Integer.parseInt(parts[2]); // modelId가 세 번째 위치
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
public void clearProgressModel(final int modelId) {
|
||||
redisTemplate.delete(CacheKey.progressStatusKey(modelId));
|
||||
}
|
||||
@ -89,4 +109,6 @@ public class ProgressCacheRepository {
|
||||
private ReportResponse convert(String data) {
|
||||
return gson.fromJson(data, ReportResponse.class);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -45,10 +45,14 @@ public class ProgressService {
|
||||
return progressCacheRepository.trainProgressCheck(projectId, modelId);
|
||||
}
|
||||
|
||||
public void registerTrainProgress(final int projectId, final int modelId, final ReportResponse data) {
|
||||
public void addProgressModel(final int projectId, final int modelId, final ReportResponse data) {
|
||||
progressCacheRepository.addProgressModel(projectId, modelId, data);
|
||||
}
|
||||
|
||||
public int getProgressModelByProjectId(final int projectId) {
|
||||
return progressCacheRepository.getProgressModelByProjectId(projectId);
|
||||
}
|
||||
|
||||
public void removeTrainProgress(final int projectId, final int modelId) {
|
||||
progressCacheRepository.removeTrainProgress(projectId, modelId);
|
||||
}
|
||||
|
@ -163,8 +163,6 @@ public class ProjectService {
|
||||
*/
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
|
||||
// progressService.predictCheck(projectId);
|
||||
|
||||
Project project = getProject(projectId);
|
||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||
|
||||
@ -186,15 +184,15 @@ public class ProjectService {
|
||||
}
|
||||
|
||||
// TODO: 트랜잭션 설정
|
||||
// TODO: 어떤 상황까지 덮어쓸껀지 물어보기
|
||||
@Transactional
|
||||
public void saveAutoLabelList(final List<AutoLabelingResult> resultList) {
|
||||
for(AutoLabelingResult result: resultList) {
|
||||
Image image = getImage(result.getImageId());
|
||||
if(image.getStatus() == LabelStatus.SAVE) continue;
|
||||
if(image.getStatus() == LabelStatus.SAVE || image.getStatus() == LabelStatus.IN_PROGRESS) continue;
|
||||
String dataPath = image.getDataPath();
|
||||
s3UploadService.uploadJson(result.getData(), dataPath);
|
||||
image.updateStatus(LabelStatus.IN_PROGRESS);
|
||||
imageRepository.save(image);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,12 +9,14 @@ import com.worlabel.domain.report.entity.dto.ReportRequest;
|
||||
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||
import com.worlabel.domain.report.repository.ReportRepository;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@Transactional
|
||||
@RequiredArgsConstructor
|
||||
@ -29,36 +31,16 @@ public class ReportService {
|
||||
.toList();
|
||||
}
|
||||
|
||||
private List<ReportResponse> getDummyList() {
|
||||
List<ReportResponse> dummyList = new ArrayList<>();
|
||||
|
||||
// 더미 데이터 15개 생성
|
||||
for (int i = 1; i <= 15; i++) {
|
||||
ReportResponse dummy = new ReportResponse(
|
||||
i, // modelId
|
||||
100, // totalEpochs
|
||||
i, // epoch
|
||||
Math.random(), // boxLoss
|
||||
Math.random(), // clsLoss
|
||||
Math.random(), // dflLoss
|
||||
Math.random(), // fitness
|
||||
Math.random() * 10,// epochTime
|
||||
Math.random() * 100 // leftSecond
|
||||
);
|
||||
dummyList.add(dummy);
|
||||
}
|
||||
|
||||
return dummyList;
|
||||
}
|
||||
|
||||
public void addReportByModelId(final Integer projectId, final Integer modelId, final ReportRequest reportRequest) {
|
||||
ReportResponse reportResponse = ReportResponse.of(reportRequest, modelId);
|
||||
|
||||
if (progressService.isProgressTrain(projectId, modelId)) { // 이미 존재하면 뒤에 추가
|
||||
progressService.registerTrainProgress(projectId, modelId, reportResponse);
|
||||
} else { // 새로추가
|
||||
progressService.registerTrainProgress(projectId, modelId, reportResponse);
|
||||
}
|
||||
boolean result = progressService.isProgressTrain(projectId, modelId);
|
||||
// log.debug("result {}" ,result);
|
||||
// if (result) { // 이미 존재하면 뒤에 추가
|
||||
// }
|
||||
progressService.addProgressModel(projectId, modelId, reportResponse);
|
||||
// progressService.registerTrainProgress(projectId, modelId);
|
||||
|
||||
}
|
||||
|
||||
public List<ReportResponse> getReportsProgressByModelId(final Integer projectId, final Integer modelId) {
|
||||
|
@ -25,6 +25,10 @@ public class CacheKey {
|
||||
return "train:" + projectId + ":" + modelId;
|
||||
}
|
||||
|
||||
public static String trainModelKey(int projectId) {
|
||||
return "train:" + projectId + ":*";
|
||||
}
|
||||
|
||||
public static String alarmIdKey(){
|
||||
return "alarm:id";
|
||||
}
|
||||
@ -36,4 +40,6 @@ public class CacheKey {
|
||||
public static String alarmMemberAllKey(int memberId) {
|
||||
return "member:" + memberId + ":alarm:*";
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user