Merge branch 'be/refactor/train' into 'be/develop'

Feat: Model 조회 및 더미 데이터 API 생성

See merge request s11-s-project/S11P21S002!162
This commit is contained in:
정현조 2024-09-25 03:03:33 +09:00
commit a138b04ee7
12 changed files with 300 additions and 98 deletions

View File

@ -3,7 +3,9 @@ package com.worlabel.domain.model.controller;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.model.entity.dto.AiModelRequest;
import com.worlabel.domain.model.entity.dto.AiModelResponse;
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
import com.worlabel.domain.model.service.AiModelService;
import com.worlabel.domain.progress.service.ProgressService;
import com.worlabel.domain.project.entity.dto.ProjectRequest;
import com.worlabel.global.annotation.CurrentUser;
import com.worlabel.global.config.swagger.SwaggerApiError;
@ -26,6 +28,7 @@ import java.util.List;
public class AiModelController {
private final AiModelService aiModelService;
private final ProgressService progressService;
@Operation(summary = "프로젝트 모델 조회", description = "프로젝트에 있는 모델을 조회합니다.")
@SwaggerApiSuccess(description = "프로젝트 멤버를 성공적으로 조회합니다.")
@ -72,7 +75,8 @@ public class AiModelController {
@PostMapping("/projects/{project_id}/train")
public void trainModel(
@PathVariable("project_id") final Integer projectId,
@RequestBody final Integer modelId) {
aiModelService.train(projectId, modelId);
@RequestBody final ModelTrainRequest trainRequest) {
log.debug("모델 학습 요청 {}", trainRequest);
aiModelService.train(projectId, trainRequest);
}
}

View File

@ -0,0 +1,35 @@
package com.worlabel.domain.model.entity.dto;
import com.worlabel.domain.result.entity.Optimizer;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import lombok.*;
@Getter
@AllArgsConstructor
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Schema(name = "모델 훈련 요청 dto", description = "모델 훈련 요청 DTO")
public class ModelTrainRequest {
@Schema(description = "모델 ID", example = "1")
@NotEmpty(message = "아이디를 입력하세요")
private Integer modelId;
@Schema(description = "ratio", example = "Default = 0.8")
private double ratio;
@Schema(description = "epochs", example = "Default = 50")
private int epochs;
@Schema(description = "batch", example = "Default = -1")
private int batch;
@Schema(description = "lr0", example = "Default = 0.01")
private double lr0;
@Schema(description = "lrf", example = "Default = 0.01")
private double lrf;
@Schema(description = "optimizer", example = "Default = auto")
private Optimizer optimizer;
}

View File

@ -6,19 +6,21 @@ import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.domain.model.entity.dto.AiModelRequest;
import com.worlabel.domain.model.entity.dto.AiModelResponse;
import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
import com.worlabel.domain.model.entity.dto.DefaultResponse;
import com.worlabel.domain.model.entity.dto.*;
import com.worlabel.domain.model.repository.AiModelRepository;
import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.progress.service.ProgressService;
import com.worlabel.domain.project.dto.AiDto;
import com.worlabel.domain.project.dto.AiDto.TrainDataInfo;
import com.worlabel.domain.project.dto.AiDto.TrainRequest;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.domain.project.service.ProjectService;
import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.cache.CacheKey;
import com.worlabel.global.exception.CustomException;
@ -34,7 +36,10 @@ import java.lang.reflect.Type;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
@Service
@ -42,13 +47,15 @@ import java.util.List;
@RequiredArgsConstructor
public class AiModelService {
private final LabelCategoryRepository labelCategoryRepository;
private final RedisTemplate<String, Object> redisTemplate;
private final AiModelRepository aiModelRepository;
private final ProjectRepository projectRepository;
private final LabelCategoryRepository labelCategoryRepository;
private final ImageRepository imageRepository;
private final AiRequestService aiRequestService;
private final RedisTemplate<String, Object> redisTemplate;
private final ImageRepository imageRepository;
private final ProjectService projectService;
private final Gson gson;
private final ProgressService progressService;
// @PostConstruct
public void loadDefaultModel() {
@ -127,56 +134,43 @@ public class AiModelService {
}
@CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer projectId, final Integer modelId) {
trainProgressCheck(projectId);
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
// progressService.trainProgressCheck(projectId);
// FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId);
AiModel model = getModel(modelId);
List<LabelCategory> labelCategories = labelCategoryRepository.findAllByModelId(modelId);
List<Integer> categories = labelCategories.stream()
.map(LabelCategory::getAiCategoryId).toList();
AiModel model = getModel(trainRequest.getModelId());
Map<Integer, Integer> labelMap = project.getCategoryList().stream()
.collect(Collectors.toMap(
category -> category.getLabelCategory().getId(),
ProjectCategory::getId
));
List<Image> images = imageRepository.findImagesByProjectId(projectId);
List<AiDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
.map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
List<TrainDataInfo> data = images.stream()
.filter(image -> image.getStatus() == LabelStatus.COMPLETED)
.map(TrainDataInfo::of)
.toList();
// progressService.registerTrainProgress(projectId);
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getModelKey(), labelMap, data, trainRequest);
// progressService.removeTrainProgress(projectId);
String endPoint = project.getProjectType().getValue() + "/train";
AiDto.TrainRequest trainRequest = new AiDto.TrainRequest();
trainRequest.setProjectId(projectId);
trainRequest.setCategoryId(categories);
trainRequest.setData(data);
trainRequest.setModelKey(model.getModelKey());
// FastAPI 서버로 POST 요청 전송
String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response);
String modelKey = aiRequestService.postRequest(endPoint, aiRequest, String.class, response -> response);
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmm"));
int newVersion = model.getVersion() + 1;
String newName = currentDateTime + String.format("%03d", newVersion);
AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project);
AiModel newModel = AiModel.of(newName, modelKey, newVersion, project);
aiModelRepository.save(newModel);
}
/**
* Redis 중복 요청 체크
*/
private void trainProgressCheck(Integer projectId) {
String trainProgressKey = CacheKey.trainProgressKey();
// 존재 확인
Boolean isProjectExist = redisTemplate.opsForSet().isMember(trainProgressKey, projectId);
if (Boolean.TRUE.equals(isProjectExist)) {
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
}
// 학습 진행 중으로 상태 등록
redisTemplate.opsForSet().add(trainProgressKey, projectId);
}
/**
* Json -> List<DefaultResponse>
*/

View File

@ -1,37 +1,91 @@
package com.worlabel.domain.progress.repository;
import com.google.gson.Gson;
import com.worlabel.domain.report.entity.dto.ReportResponse;
import com.worlabel.global.cache.CacheKey;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Repository;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@Repository
@RequiredArgsConstructor
public class ProgressCacheRepository {
private final RedisTemplate<String, Object> redisTemplate;
private final RedisTemplate<String, String> redisTemplate;
private final Gson gson;
/**
* 현재 오토레이블링중인지 확인하는 메서드
*/
public boolean predictCheck(final int projectId) {
String key = CacheKey.autoLabelingProgressKey();
Boolean isProgress = redisTemplate.opsForSet().isMember(key, projectId);
public boolean predictProgressCheck(final int projectId) {
Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
return Boolean.TRUE.equals(isProgress);
}
/**
* 학습 진행 등록 메서드
* 오토레이블링 진행 등록 메서드
*/
public void registerPredictProgress(final int projectId) {
String key = CacheKey.autoLabelingProgressKey();
redisTemplate.opsForSet().add(key, projectId);
redisTemplate.opsForSet().add(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
}
/**
* 오토레이블링 진행 제거 메서드
*/
public void removePredictProgress(final int projectId) {
String key = CacheKey.autoLabelingProgressKey();
redisTemplate.opsForSet().remove(key, projectId);
redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
}
/**
* 학습 진행 확인 메서드
*/
public boolean trainProgressCheck(final int projectId) {
Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.trainProgressKey(), String.valueOf(projectId));
return Boolean.TRUE.equals(isProgress);
}
/**
* 학습 진행 등록 메서드
*/
public void registerTrainProgress(final int projectId) {
redisTemplate.opsForSet().add(CacheKey.trainProgressKey(), String.valueOf(projectId));
}
/**
* 학습 진행 제거 메서드
*/
public void removeTrainProgress(final int projectId) {
redisTemplate.opsForSet().remove(CacheKey.trainProgressKey(), String.valueOf(projectId));
}
/**
* 진행 상황을 Redis에 추가
*/
public void addProgressModel(final int modelId,final String data){
ReportResponse reportResponse = convert(data);
redisTemplate.opsForList().rightPush(CacheKey.progressStatusKey(modelId), gson.toJson(reportResponse));
}
public List<ReportResponse> getProgressModel(final int modelId) {
// 저장된 주어진 응답에 맞추어 리턴
String key = CacheKey.progressStatusKey(modelId);
List<String> progressList = redisTemplate.opsForList().range(key, 0, -1);
return progressList.stream()
.map(this::convert)
.toList();
}
public void clearProgressModel(final int modelId) {
redisTemplate.delete(CacheKey.progressStatusKey(modelId));
}
private ReportResponse convert(String data){
return gson.fromJson(data, ReportResponse.class);
}
}

View File

@ -1,12 +1,15 @@
package com.worlabel.domain.progress.service;
import com.worlabel.domain.progress.repository.ProgressCacheRepository;
import com.worlabel.domain.report.entity.dto.ReportResponse;
import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.List;
@Slf4j
@Service
@RequiredArgsConstructor
@ -14,14 +17,39 @@ public class ProgressService {
private final ProgressCacheRepository progressCacheRepository;
public void predictCheck(final int projectId){
if(progressCacheRepository.predictCheck(projectId)){
// throw new CustomException(ErrorCode.AI_IN_PROGRESS);
progressCacheRepository.removePredictProgress(projectId);
public void predictProgressCheck(final int projectId){
if(progressCacheRepository.predictProgressCheck(projectId)){
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
}
}
public void registerPredictProgress(final int projectId){
progressCacheRepository.registerPredictProgress(projectId);
}
public void removePredictProgress(final int projectId){
progressCacheRepository.removePredictProgress(projectId);
}
public void trainProgressCheck(final int projectId){
if(progressCacheRepository.trainProgressCheck(projectId)){
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
}
}
public boolean isProgressTrain(final int projectId){
return progressCacheRepository.trainProgressCheck(projectId);
}
public void registerTrainProgress(final int projectId){
progressCacheRepository.registerTrainProgress(projectId);
}
public void removeTrainProgress(final int projectId){
progressCacheRepository.removeTrainProgress(projectId);
}
public List<ReportResponse> getProgressResponse(final int modelId) {
return progressCacheRepository.getProgressModel(modelId);
}
}

View File

@ -3,41 +3,76 @@ package com.worlabel.domain.project.dto;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.gson.annotations.SerializedName;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
import com.worlabel.domain.result.entity.Optimizer;
import lombok.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class AiDto {
@Data
@Getter
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class TrainDataInfo {
@JsonProperty("image_url")
private String imagePath;
@JsonProperty("data_url")
private String dataPath;
public TrainDataInfo(String imagePath, String dataPath) {
this.imagePath = imagePath;
this.dataPath = dataPath;
public static TrainDataInfo of(Image image) {
return new TrainDataInfo(image.getImagePath(), image.getDataPath());
}
}
@Data
@Getter
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static class TrainRequest {
@JsonProperty("project_id")
private int projectId;
@JsonProperty("category_id")
private List<Integer> categoryId;
@JsonProperty("model_key")
private String modelKey;
@JsonProperty("label_map")
private Map<Integer,Integer> labelMap;
@JsonProperty("data")
private List<TrainDataInfo> data;
@JsonProperty("model_key")
private String modelKey;
// private int seed; // Optional
// private float ratio; // Default = 0.8
// private int epochs; // Default = 50
// private float batch; // Default = -1
private double ratio; // Default = 0.8
private int epochs; // Default = 50
private double batch; // Default = -1
private double lr0;
private double lrf;
private Optimizer optimizer;
public static TrainRequest of(final Integer projectId, final String modelKey, final Map<Integer, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
TrainRequest request = new TrainRequest();
request.projectId = projectId;
request.modelKey = modelKey;
request.labelMap = labelMap;
request.data = data;
request.ratio = request.getRatio();
request.epochs = trainRequest.getEpochs();
request.batch = trainRequest.getBatch();
request.lr0 = trainRequest.getLr0();
request.lrf = trainRequest.getLrf();
request.optimizer = trainRequest.getOptimizer();
return request;
}
}
@Getter

View File

@ -162,7 +162,7 @@ public class ProjectService {
*/
@CheckPrivilege(PrivilegeType.EDITOR)
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
progressService.predictCheck(projectId);
// progressService.predictCheck(projectId);
Project project = getProject(projectId);
String endPoint = project.getProjectType().getValue() + "/predict";

View File

@ -11,14 +11,14 @@ import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@RestController
@RequestMapping("/api/reports")
@RequestMapping("/api/projects/{project_id}/reports")
@RequiredArgsConstructor
public class ReportController {
private final ReportService reportService;
@GetMapping("/model/{model_id}")
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
return reportService.getReportsByModelId(modelId);
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId, @PathVariable("project_id") final Integer projectId) {
return reportService.getReportsByModelId(projectId,modelId);
}
}

View File

@ -25,17 +25,18 @@ public class Report extends BaseEntity {
@JoinColumn(name = "model_id", nullable = false)
private AiModel aiModel;
/**
* 현재 에포크
*/
@Column(name = "epoch", nullable = false)
private Integer epoch;
/**
* 전체 에포크
*/
@Column(name = "total_epochs", nullable = false)
private Integer totalEpochs;
/**
* 현재 에포크
*/
@Column(name = "epoch", nullable = false)
private Integer epoch;
@Column(name = "box_loss", nullable = false)
private double boxLoss;
@ -48,4 +49,10 @@ public class Report extends BaseEntity {
@Column(name = "fitness", nullable = false)
private double fitness;
@Column(name = "epoch_time", nullable = false)
private double epochTime;
@Column(name = "left_second", nullable = false)
private double leftSecond;
}

View File

@ -4,26 +4,33 @@ import com.worlabel.domain.report.entity.Report;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
@Getter
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
@AllArgsConstructor
public class ReportResponse {
private Integer id;
private Integer totalEpochs;
private Integer epoch;
private int modelId;
private int totalEpochs;
private int epoch;
private double boxLoss;
private double clsLoss;
private double dflLoss;
private double fitness;
private double epochTime;
private double leftSecond;
public static ReportResponse from(final Report report) {
return new ReportResponse(
report.getId(),
report.getAiModel().getId(),
report.getTotalEpochs(),
report.getEpoch(),
report.getBoxLoss(),
report.getClsLoss(),
report.getDflLoss(),
report.getFitness());
report.getFitness(),
report.getEpochTime(),
report.getLeftSecond()
);
}
}

View File

@ -1,5 +1,6 @@
package com.worlabel.domain.report.service;
import com.worlabel.domain.progress.service.ProgressService;
import com.worlabel.domain.report.entity.Report;
import com.worlabel.domain.report.entity.dto.ReportResponse;
import com.worlabel.domain.report.repository.ReportRepository;
@ -7,6 +8,7 @@ import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.List;
@Service
@ -15,11 +17,43 @@ import java.util.List;
public class ReportService {
private final ReportRepository reportRepository;
private final ProgressService progressService;
public List<ReportResponse> getReportsByModelId(final Integer modelId) {
List<Report> reports = reportRepository.findByAiModelId(modelId);
return reports.stream()
.map(ReportResponse::from)
.toList();
public List<ReportResponse> getReportsByModelId(final Integer projectId, final Integer modelId) {
// 진행중이면 진행중에서 받아오기
return getDummyList();
// if(progressService.isProgressTrain(projectId)){
// return progressService.getProgressResponse(modelId);
// }
// // 작업 완료시에는 RDB
// else{
// List<Report> reports = reportRepository.findByAiModelId(modelId);
// return reports.stream()
// .map(ReportResponse::from)
// .toList();
// }
}
private List<ReportResponse> getDummyList() {
List<ReportResponse> dummyList = new ArrayList<>();
// 더미 데이터 15개 생성
for (int i = 1; i <= 15; i++) {
ReportResponse dummy = new ReportResponse(
i, // modelId
100, // totalEpochs
i, // epoch
Math.random(), // boxLoss
Math.random(), // clsLoss
Math.random(), // dflLoss
Math.random(), // fitness
Math.random() * 10,// epochTime
Math.random() * 100 // leftSecond
);
dummyList.add(dummy);
}
return dummyList;
}
}

View File

@ -14,4 +14,8 @@ public class CacheKey {
public static String fcmTokenKey(){
return "fcmToken";
}
public static String progressStatusKey(int modelId) {
return "progress:" + modelId;
}
}