Merge branch 'be/refactor/train' into 'be/develop'
Feat: Model 조회 및 더미 데이터 API 생성 See merge request s11-s-project/S11P21S002!162
This commit is contained in:
commit
a138b04ee7
@ -3,7 +3,9 @@ package com.worlabel.domain.model.controller;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
||||
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
||||
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
|
||||
import com.worlabel.domain.model.service.AiModelService;
|
||||
import com.worlabel.domain.progress.service.ProgressService;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
||||
import com.worlabel.global.annotation.CurrentUser;
|
||||
import com.worlabel.global.config.swagger.SwaggerApiError;
|
||||
@ -26,13 +28,14 @@ import java.util.List;
|
||||
public class AiModelController {
|
||||
|
||||
private final AiModelService aiModelService;
|
||||
private final ProgressService progressService;
|
||||
|
||||
@Operation(summary = "프로젝트 모델 조회", description = "프로젝트에 있는 모델을 조회합니다.")
|
||||
@SwaggerApiSuccess(description = "프로젝트 멤버를 성공적으로 조회합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@GetMapping("/projects/{project_id}/models")
|
||||
public List<AiModelResponse> getModelList(
|
||||
@PathVariable("project_id") final Integer projectId) {
|
||||
@PathVariable("project_id") final Integer projectId) {
|
||||
return aiModelService.getModelList(projectId);
|
||||
}
|
||||
|
||||
@ -41,7 +44,7 @@ public class AiModelController {
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@GetMapping("/models/{model_id}/categories")
|
||||
public List<LabelCategoryResponse> getCategories(
|
||||
@PathVariable("model_id") final Integer modelId) {
|
||||
@PathVariable("model_id") final Integer modelId) {
|
||||
return aiModelService.getCategories(modelId);
|
||||
}
|
||||
|
||||
@ -50,8 +53,8 @@ public class AiModelController {
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@PostMapping("/projects/{project_id}/models")
|
||||
public void addModel(
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
||||
aiModelService.addModel(projectId, aiModelRequest);
|
||||
}
|
||||
|
||||
@ -60,9 +63,9 @@ public class AiModelController {
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@PutMapping("/projects/{project_id}/models/{model_id}")
|
||||
public void renameModel(
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@PathVariable("model_id") final Integer modelId,
|
||||
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@PathVariable("model_id") final Integer modelId,
|
||||
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
||||
aiModelService.renameModel(projectId, modelId, aiModelRequest);
|
||||
}
|
||||
|
||||
@ -71,8 +74,9 @@ public class AiModelController {
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@PostMapping("/projects/{project_id}/train")
|
||||
public void trainModel(
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@RequestBody final Integer modelId) {
|
||||
aiModelService.train(projectId, modelId);
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@RequestBody final ModelTrainRequest trainRequest) {
|
||||
log.debug("모델 학습 요청 {}", trainRequest);
|
||||
aiModelService.train(projectId, trainRequest);
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,35 @@
|
||||
package com.worlabel.domain.model.entity.dto;
|
||||
|
||||
import com.worlabel.domain.result.entity.Optimizer;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import lombok.*;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@Schema(name = "모델 훈련 요청 dto", description = "모델 훈련 요청 DTO")
|
||||
public class ModelTrainRequest {
|
||||
|
||||
@Schema(description = "모델 ID", example = "1")
|
||||
@NotEmpty(message = "아이디를 입력하세요")
|
||||
private Integer modelId;
|
||||
|
||||
@Schema(description = "ratio", example = "Default = 0.8")
|
||||
private double ratio;
|
||||
|
||||
@Schema(description = "epochs", example = "Default = 50")
|
||||
private int epochs;
|
||||
|
||||
@Schema(description = "batch", example = "Default = -1")
|
||||
private int batch;
|
||||
|
||||
@Schema(description = "lr0", example = "Default = 0.01")
|
||||
private double lr0;
|
||||
|
||||
@Schema(description = "lrf", example = "Default = 0.01")
|
||||
private double lrf;
|
||||
|
||||
@Schema(description = "optimizer", example = "Default = auto")
|
||||
private Optimizer optimizer;
|
||||
}
|
@ -6,19 +6,21 @@ import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.image.entity.LabelStatus;
|
||||
import com.worlabel.domain.image.repository.ImageRepository;
|
||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
||||
import com.worlabel.domain.model.entity.AiModel;
|
||||
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
||||
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
||||
import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
|
||||
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
||||
import com.worlabel.domain.model.entity.dto.*;
|
||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||
import com.worlabel.domain.progress.service.ProgressService;
|
||||
import com.worlabel.domain.project.dto.AiDto;
|
||||
import com.worlabel.domain.project.dto.AiDto.TrainDataInfo;
|
||||
import com.worlabel.domain.project.dto.AiDto.TrainRequest;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||
import com.worlabel.domain.project.service.ProjectService;
|
||||
import com.worlabel.global.annotation.CheckPrivilege;
|
||||
import com.worlabel.global.cache.CacheKey;
|
||||
import com.worlabel.global.exception.CustomException;
|
||||
@ -34,7 +36,10 @@ import java.lang.reflect.Type;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@ -42,13 +47,15 @@ import java.util.List;
|
||||
@RequiredArgsConstructor
|
||||
public class AiModelService {
|
||||
|
||||
private final LabelCategoryRepository labelCategoryRepository;
|
||||
private final RedisTemplate<String, Object> redisTemplate;
|
||||
private final AiModelRepository aiModelRepository;
|
||||
private final ProjectRepository projectRepository;
|
||||
private final LabelCategoryRepository labelCategoryRepository;
|
||||
private final ImageRepository imageRepository;
|
||||
private final AiRequestService aiRequestService;
|
||||
private final RedisTemplate<String, Object> redisTemplate;
|
||||
private final ImageRepository imageRepository;
|
||||
private final ProjectService projectService;
|
||||
private final Gson gson;
|
||||
private final ProgressService progressService;
|
||||
|
||||
// @PostConstruct
|
||||
public void loadDefaultModel() {
|
||||
@ -127,56 +134,43 @@ public class AiModelService {
|
||||
}
|
||||
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void train(final Integer projectId, final Integer modelId) {
|
||||
trainProgressCheck(projectId);
|
||||
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||
// progressService.trainProgressCheck(projectId);
|
||||
|
||||
// FastAPI 서버로 학습 요청을 전송
|
||||
Project project = getProject(projectId);
|
||||
AiModel model = getModel(modelId);
|
||||
List<LabelCategory> labelCategories = labelCategoryRepository.findAllByModelId(modelId);
|
||||
List<Integer> categories = labelCategories.stream()
|
||||
.map(LabelCategory::getAiCategoryId).toList();
|
||||
AiModel model = getModel(trainRequest.getModelId());
|
||||
|
||||
Map<Integer, Integer> labelMap = project.getCategoryList().stream()
|
||||
.collect(Collectors.toMap(
|
||||
category -> category.getLabelCategory().getId(),
|
||||
ProjectCategory::getId
|
||||
));
|
||||
|
||||
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
||||
|
||||
List<AiDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||
.map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
||||
List<TrainDataInfo> data = images.stream()
|
||||
.filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||
.map(TrainDataInfo::of)
|
||||
.toList();
|
||||
|
||||
// progressService.registerTrainProgress(projectId);
|
||||
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getModelKey(), labelMap, data, trainRequest);
|
||||
// progressService.removeTrainProgress(projectId);
|
||||
|
||||
String endPoint = project.getProjectType().getValue() + "/train";
|
||||
|
||||
AiDto.TrainRequest trainRequest = new AiDto.TrainRequest();
|
||||
trainRequest.setProjectId(projectId);
|
||||
trainRequest.setCategoryId(categories);
|
||||
trainRequest.setData(data);
|
||||
trainRequest.setModelKey(model.getModelKey());
|
||||
|
||||
// FastAPI 서버로 POST 요청 전송
|
||||
String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
||||
String modelKey = aiRequestService.postRequest(endPoint, aiRequest, String.class, response -> response);
|
||||
|
||||
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
|
||||
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
|
||||
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmm"));
|
||||
int newVersion = model.getVersion() + 1;
|
||||
String newName = currentDateTime + String.format("%03d", newVersion);
|
||||
|
||||
AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project);
|
||||
AiModel newModel = AiModel.of(newName, modelKey, newVersion, project);
|
||||
aiModelRepository.save(newModel);
|
||||
}
|
||||
|
||||
/**
|
||||
* Redis 중복 요청 체크
|
||||
*/
|
||||
private void trainProgressCheck(Integer projectId) {
|
||||
String trainProgressKey = CacheKey.trainProgressKey();
|
||||
|
||||
// 존재 확인
|
||||
Boolean isProjectExist = redisTemplate.opsForSet().isMember(trainProgressKey, projectId);
|
||||
if (Boolean.TRUE.equals(isProjectExist)) {
|
||||
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||
}
|
||||
|
||||
// 학습 진행 중으로 상태 등록
|
||||
redisTemplate.opsForSet().add(trainProgressKey, projectId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Json -> List<DefaultResponse>
|
||||
*/
|
||||
|
@ -1,37 +1,91 @@
|
||||
package com.worlabel.domain.progress.repository;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||
import com.worlabel.global.cache.CacheKey;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.data.redis.core.RedisTemplate;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Repository
|
||||
@RequiredArgsConstructor
|
||||
public class ProgressCacheRepository {
|
||||
|
||||
private final RedisTemplate<String, Object> redisTemplate;
|
||||
private final RedisTemplate<String, String> redisTemplate;
|
||||
private final Gson gson;
|
||||
|
||||
/**
|
||||
* 현재 오토레이블링중인지 확인하는 메서드
|
||||
*/
|
||||
public boolean predictCheck(final int projectId) {
|
||||
String key = CacheKey.autoLabelingProgressKey();
|
||||
Boolean isProgress = redisTemplate.opsForSet().isMember(key, projectId);
|
||||
public boolean predictProgressCheck(final int projectId) {
|
||||
Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
|
||||
return Boolean.TRUE.equals(isProgress);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 진행 중 등록 메서드
|
||||
* 오토레이블링 진행 중 등록 메서드
|
||||
*/
|
||||
public void registerPredictProgress(final int projectId) {
|
||||
String key = CacheKey.autoLabelingProgressKey();
|
||||
redisTemplate.opsForSet().add(key, projectId);
|
||||
redisTemplate.opsForSet().add(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 오토레이블링 진행 제거 메서드
|
||||
*/
|
||||
public void removePredictProgress(final int projectId) {
|
||||
String key = CacheKey.autoLabelingProgressKey();
|
||||
redisTemplate.opsForSet().remove(key, projectId);
|
||||
redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 진행 확인 메서드
|
||||
*/
|
||||
public boolean trainProgressCheck(final int projectId) {
|
||||
Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.trainProgressKey(), String.valueOf(projectId));
|
||||
return Boolean.TRUE.equals(isProgress);
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 진행 등록 메서드
|
||||
*/
|
||||
public void registerTrainProgress(final int projectId) {
|
||||
redisTemplate.opsForSet().add(CacheKey.trainProgressKey(), String.valueOf(projectId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 학습 진행 제거 메서드
|
||||
*/
|
||||
public void removeTrainProgress(final int projectId) {
|
||||
redisTemplate.opsForSet().remove(CacheKey.trainProgressKey(), String.valueOf(projectId));
|
||||
}
|
||||
|
||||
/**
|
||||
* 진행 상황을 Redis에 추가
|
||||
*/
|
||||
public void addProgressModel(final int modelId,final String data){
|
||||
ReportResponse reportResponse = convert(data);
|
||||
redisTemplate.opsForList().rightPush(CacheKey.progressStatusKey(modelId), gson.toJson(reportResponse));
|
||||
}
|
||||
|
||||
public List<ReportResponse> getProgressModel(final int modelId) {
|
||||
// 저장된 걸 주어진 응답에 맞추어 리턴
|
||||
String key = CacheKey.progressStatusKey(modelId);
|
||||
List<String> progressList = redisTemplate.opsForList().range(key, 0, -1);
|
||||
|
||||
return progressList.stream()
|
||||
.map(this::convert)
|
||||
.toList();
|
||||
}
|
||||
|
||||
public void clearProgressModel(final int modelId) {
|
||||
redisTemplate.delete(CacheKey.progressStatusKey(modelId));
|
||||
}
|
||||
|
||||
private ReportResponse convert(String data){
|
||||
return gson.fromJson(data, ReportResponse.class);
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,15 @@
|
||||
package com.worlabel.domain.progress.service;
|
||||
|
||||
import com.worlabel.domain.progress.repository.ProgressCacheRepository;
|
||||
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||
import com.worlabel.global.exception.CustomException;
|
||||
import com.worlabel.global.exception.ErrorCode;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
@ -14,14 +17,39 @@ public class ProgressService {
|
||||
|
||||
private final ProgressCacheRepository progressCacheRepository;
|
||||
|
||||
public void predictCheck(final int projectId){
|
||||
if(progressCacheRepository.predictCheck(projectId)){
|
||||
// throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||
progressCacheRepository.removePredictProgress(projectId);
|
||||
public void predictProgressCheck(final int projectId){
|
||||
if(progressCacheRepository.predictProgressCheck(projectId)){
|
||||
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||
}
|
||||
}
|
||||
|
||||
public void registerPredictProgress(final int projectId){
|
||||
progressCacheRepository.registerPredictProgress(projectId);
|
||||
}
|
||||
|
||||
public void removePredictProgress(final int projectId){
|
||||
progressCacheRepository.removePredictProgress(projectId);
|
||||
}
|
||||
|
||||
public void trainProgressCheck(final int projectId){
|
||||
if(progressCacheRepository.trainProgressCheck(projectId)){
|
||||
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isProgressTrain(final int projectId){
|
||||
return progressCacheRepository.trainProgressCheck(projectId);
|
||||
}
|
||||
|
||||
public void registerTrainProgress(final int projectId){
|
||||
progressCacheRepository.registerTrainProgress(projectId);
|
||||
}
|
||||
|
||||
public void removeTrainProgress(final int projectId){
|
||||
progressCacheRepository.removeTrainProgress(projectId);
|
||||
}
|
||||
|
||||
public List<ReportResponse> getProgressResponse(final int modelId) {
|
||||
return progressCacheRepository.getProgressModel(modelId);
|
||||
}
|
||||
}
|
||||
|
@ -3,41 +3,76 @@ package com.worlabel.domain.project.dto;
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
|
||||
import com.worlabel.domain.result.entity.Optimizer;
|
||||
import lombok.*;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class AiDto {
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class TrainDataInfo {
|
||||
|
||||
@JsonProperty("image_url")
|
||||
private String imagePath;
|
||||
|
||||
@JsonProperty("data_url")
|
||||
private String dataPath;
|
||||
|
||||
public TrainDataInfo(String imagePath, String dataPath) {
|
||||
this.imagePath = imagePath;
|
||||
this.dataPath = dataPath;
|
||||
public static TrainDataInfo of(Image image) {
|
||||
return new TrainDataInfo(image.getImagePath(), image.getDataPath());
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
@Getter
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class TrainRequest {
|
||||
|
||||
@JsonProperty("project_id")
|
||||
private int projectId;
|
||||
|
||||
@JsonProperty("category_id")
|
||||
private List<Integer> categoryId;
|
||||
@JsonProperty("model_key")
|
||||
private String modelKey;
|
||||
|
||||
@JsonProperty("label_map")
|
||||
private Map<Integer,Integer> labelMap;
|
||||
|
||||
@JsonProperty("data")
|
||||
private List<TrainDataInfo> data;
|
||||
|
||||
@JsonProperty("model_key")
|
||||
private String modelKey;
|
||||
// private int seed; // Optional
|
||||
// private float ratio; // Default = 0.8
|
||||
// private int epochs; // Default = 50
|
||||
// private float batch; // Default = -1
|
||||
private double ratio; // Default = 0.8
|
||||
|
||||
private int epochs; // Default = 50
|
||||
|
||||
private double batch; // Default = -1
|
||||
|
||||
private double lr0;
|
||||
|
||||
private double lrf;
|
||||
|
||||
private Optimizer optimizer;
|
||||
|
||||
public static TrainRequest of(final Integer projectId, final String modelKey, final Map<Integer, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
|
||||
TrainRequest request = new TrainRequest();
|
||||
request.projectId = projectId;
|
||||
request.modelKey = modelKey;
|
||||
request.labelMap = labelMap;
|
||||
request.data = data;
|
||||
request.ratio = request.getRatio();
|
||||
request.epochs = trainRequest.getEpochs();
|
||||
request.batch = trainRequest.getBatch();
|
||||
request.lr0 = trainRequest.getLr0();
|
||||
request.lrf = trainRequest.getLrf();
|
||||
request.optimizer = trainRequest.getOptimizer();
|
||||
|
||||
return request;
|
||||
}
|
||||
}
|
||||
|
||||
@Getter
|
||||
@ -89,7 +124,7 @@ public class AiDto {
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@Getter
|
||||
@ToString
|
||||
public static class AutoLabelingResult{
|
||||
public static class AutoLabelingResult {
|
||||
|
||||
@SerializedName("image_id")
|
||||
private Long imageId;
|
||||
|
@ -162,7 +162,7 @@ public class ProjectService {
|
||||
*/
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
|
||||
progressService.predictCheck(projectId);
|
||||
// progressService.predictCheck(projectId);
|
||||
|
||||
Project project = getProject(projectId);
|
||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||
|
@ -11,14 +11,14 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
import java.util.List;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/reports")
|
||||
@RequestMapping("/api/projects/{project_id}/reports")
|
||||
@RequiredArgsConstructor
|
||||
public class ReportController {
|
||||
|
||||
private final ReportService reportService;
|
||||
|
||||
@GetMapping("/model/{model_id}")
|
||||
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
|
||||
return reportService.getReportsByModelId(modelId);
|
||||
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId, @PathVariable("project_id") final Integer projectId) {
|
||||
return reportService.getReportsByModelId(projectId,modelId);
|
||||
}
|
||||
}
|
@ -25,17 +25,18 @@ public class Report extends BaseEntity {
|
||||
@JoinColumn(name = "model_id", nullable = false)
|
||||
private AiModel aiModel;
|
||||
|
||||
/**
|
||||
* 현재 에포크
|
||||
*/
|
||||
@Column(name = "epoch", nullable = false)
|
||||
private Integer epoch;
|
||||
|
||||
/**
|
||||
* 전체 에포크
|
||||
*/
|
||||
@Column(name = "total_epochs", nullable = false)
|
||||
private Integer totalEpochs;
|
||||
|
||||
/**
|
||||
* 현재 에포크
|
||||
*/
|
||||
@Column(name = "epoch", nullable = false)
|
||||
private Integer epoch;
|
||||
|
||||
@Column(name = "box_loss", nullable = false)
|
||||
private double boxLoss;
|
||||
@ -48,4 +49,10 @@ public class Report extends BaseEntity {
|
||||
|
||||
@Column(name = "fitness", nullable = false)
|
||||
private double fitness;
|
||||
|
||||
@Column(name = "epoch_time", nullable = false)
|
||||
private double epochTime;
|
||||
|
||||
@Column(name = "left_second", nullable = false)
|
||||
private double leftSecond;
|
||||
}
|
||||
|
@ -4,26 +4,33 @@ import com.worlabel.domain.report.entity.Report;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ReportResponse {
|
||||
private Integer id;
|
||||
private Integer totalEpochs;
|
||||
private Integer epoch;
|
||||
private int modelId;
|
||||
private int totalEpochs;
|
||||
private int epoch;
|
||||
private double boxLoss;
|
||||
private double clsLoss;
|
||||
private double dflLoss;
|
||||
private double fitness;
|
||||
private double epochTime;
|
||||
private double leftSecond;
|
||||
|
||||
public static ReportResponse from(final Report report) {
|
||||
return new ReportResponse(
|
||||
report.getId(),
|
||||
report.getAiModel().getId(),
|
||||
report.getTotalEpochs(),
|
||||
report.getEpoch(),
|
||||
report.getBoxLoss(),
|
||||
report.getClsLoss(),
|
||||
report.getDflLoss(),
|
||||
report.getFitness());
|
||||
report.getFitness(),
|
||||
report.getEpochTime(),
|
||||
report.getLeftSecond()
|
||||
);
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package com.worlabel.domain.report.service;
|
||||
|
||||
import com.worlabel.domain.progress.service.ProgressService;
|
||||
import com.worlabel.domain.report.entity.Report;
|
||||
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||
import com.worlabel.domain.report.repository.ReportRepository;
|
||||
@ -7,6 +8,7 @@ import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Service
|
||||
@ -15,11 +17,43 @@ import java.util.List;
|
||||
public class ReportService {
|
||||
|
||||
private final ReportRepository reportRepository;
|
||||
private final ProgressService progressService;
|
||||
|
||||
public List<ReportResponse> getReportsByModelId(final Integer modelId) {
|
||||
List<Report> reports = reportRepository.findByAiModelId(modelId);
|
||||
return reports.stream()
|
||||
.map(ReportResponse::from)
|
||||
.toList();
|
||||
public List<ReportResponse> getReportsByModelId(final Integer projectId, final Integer modelId) {
|
||||
// 진행중이면 진행중에서 받아오기
|
||||
return getDummyList();
|
||||
// if(progressService.isProgressTrain(projectId)){
|
||||
// return progressService.getProgressResponse(modelId);
|
||||
// }
|
||||
// // 작업 완료시에는 RDB
|
||||
// else{
|
||||
// List<Report> reports = reportRepository.findByAiModelId(modelId);
|
||||
// return reports.stream()
|
||||
// .map(ReportResponse::from)
|
||||
// .toList();
|
||||
// }
|
||||
}
|
||||
|
||||
private List<ReportResponse> getDummyList() {
|
||||
List<ReportResponse> dummyList = new ArrayList<>();
|
||||
|
||||
// 더미 데이터 15개 생성
|
||||
for (int i = 1; i <= 15; i++) {
|
||||
ReportResponse dummy = new ReportResponse(
|
||||
i, // modelId
|
||||
100, // totalEpochs
|
||||
i, // epoch
|
||||
Math.random(), // boxLoss
|
||||
Math.random(), // clsLoss
|
||||
Math.random(), // dflLoss
|
||||
Math.random(), // fitness
|
||||
Math.random() * 10,// epochTime
|
||||
Math.random() * 100 // leftSecond
|
||||
);
|
||||
dummyList.add(dummy);
|
||||
}
|
||||
|
||||
return dummyList;
|
||||
}
|
||||
|
||||
}
|
@ -14,4 +14,8 @@ public class CacheKey {
|
||||
public static String fcmTokenKey(){
|
||||
return "fcmToken";
|
||||
}
|
||||
|
||||
public static String progressStatusKey(int modelId) {
|
||||
return "progress:" + modelId;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user