Merge branch 'be/refactor/train' into 'be/develop'
Feat: Model 조회 및 더미 데이터 API 생성 See merge request s11-s-project/S11P21S002!162
This commit is contained in:
commit
a138b04ee7
@ -3,7 +3,9 @@ package com.worlabel.domain.model.controller;
|
|||||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||||
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
||||||
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
||||||
|
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
|
||||||
import com.worlabel.domain.model.service.AiModelService;
|
import com.worlabel.domain.model.service.AiModelService;
|
||||||
|
import com.worlabel.domain.progress.service.ProgressService;
|
||||||
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
||||||
import com.worlabel.global.annotation.CurrentUser;
|
import com.worlabel.global.annotation.CurrentUser;
|
||||||
import com.worlabel.global.config.swagger.SwaggerApiError;
|
import com.worlabel.global.config.swagger.SwaggerApiError;
|
||||||
@ -26,13 +28,14 @@ import java.util.List;
|
|||||||
public class AiModelController {
|
public class AiModelController {
|
||||||
|
|
||||||
private final AiModelService aiModelService;
|
private final AiModelService aiModelService;
|
||||||
|
private final ProgressService progressService;
|
||||||
|
|
||||||
@Operation(summary = "프로젝트 모델 조회", description = "프로젝트에 있는 모델을 조회합니다.")
|
@Operation(summary = "프로젝트 모델 조회", description = "프로젝트에 있는 모델을 조회합니다.")
|
||||||
@SwaggerApiSuccess(description = "프로젝트 멤버를 성공적으로 조회합니다.")
|
@SwaggerApiSuccess(description = "프로젝트 멤버를 성공적으로 조회합니다.")
|
||||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
@GetMapping("/projects/{project_id}/models")
|
@GetMapping("/projects/{project_id}/models")
|
||||||
public List<AiModelResponse> getModelList(
|
public List<AiModelResponse> getModelList(
|
||||||
@PathVariable("project_id") final Integer projectId) {
|
@PathVariable("project_id") final Integer projectId) {
|
||||||
return aiModelService.getModelList(projectId);
|
return aiModelService.getModelList(projectId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,7 +44,7 @@ public class AiModelController {
|
|||||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
@GetMapping("/models/{model_id}/categories")
|
@GetMapping("/models/{model_id}/categories")
|
||||||
public List<LabelCategoryResponse> getCategories(
|
public List<LabelCategoryResponse> getCategories(
|
||||||
@PathVariable("model_id") final Integer modelId) {
|
@PathVariable("model_id") final Integer modelId) {
|
||||||
return aiModelService.getCategories(modelId);
|
return aiModelService.getCategories(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,8 +53,8 @@ public class AiModelController {
|
|||||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
@PostMapping("/projects/{project_id}/models")
|
@PostMapping("/projects/{project_id}/models")
|
||||||
public void addModel(
|
public void addModel(
|
||||||
@PathVariable("project_id") final Integer projectId,
|
@PathVariable("project_id") final Integer projectId,
|
||||||
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
||||||
aiModelService.addModel(projectId, aiModelRequest);
|
aiModelService.addModel(projectId, aiModelRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,9 +63,9 @@ public class AiModelController {
|
|||||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
@PutMapping("/projects/{project_id}/models/{model_id}")
|
@PutMapping("/projects/{project_id}/models/{model_id}")
|
||||||
public void renameModel(
|
public void renameModel(
|
||||||
@PathVariable("project_id") final Integer projectId,
|
@PathVariable("project_id") final Integer projectId,
|
||||||
@PathVariable("model_id") final Integer modelId,
|
@PathVariable("model_id") final Integer modelId,
|
||||||
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
@Valid @RequestBody final AiModelRequest aiModelRequest) {
|
||||||
aiModelService.renameModel(projectId, modelId, aiModelRequest);
|
aiModelService.renameModel(projectId, modelId, aiModelRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,8 +74,9 @@ public class AiModelController {
|
|||||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
@PostMapping("/projects/{project_id}/train")
|
@PostMapping("/projects/{project_id}/train")
|
||||||
public void trainModel(
|
public void trainModel(
|
||||||
@PathVariable("project_id") final Integer projectId,
|
@PathVariable("project_id") final Integer projectId,
|
||||||
@RequestBody final Integer modelId) {
|
@RequestBody final ModelTrainRequest trainRequest) {
|
||||||
aiModelService.train(projectId, modelId);
|
log.debug("모델 학습 요청 {}", trainRequest);
|
||||||
|
aiModelService.train(projectId, trainRequest);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
package com.worlabel.domain.model.entity.dto;
|
||||||
|
|
||||||
|
import com.worlabel.domain.result.entity.Optimizer;
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import jakarta.validation.constraints.NotEmpty;
|
||||||
|
import lombok.*;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@AllArgsConstructor
|
||||||
|
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@Schema(name = "모델 훈련 요청 dto", description = "모델 훈련 요청 DTO")
|
||||||
|
public class ModelTrainRequest {
|
||||||
|
|
||||||
|
@Schema(description = "모델 ID", example = "1")
|
||||||
|
@NotEmpty(message = "아이디를 입력하세요")
|
||||||
|
private Integer modelId;
|
||||||
|
|
||||||
|
@Schema(description = "ratio", example = "Default = 0.8")
|
||||||
|
private double ratio;
|
||||||
|
|
||||||
|
@Schema(description = "epochs", example = "Default = 50")
|
||||||
|
private int epochs;
|
||||||
|
|
||||||
|
@Schema(description = "batch", example = "Default = -1")
|
||||||
|
private int batch;
|
||||||
|
|
||||||
|
@Schema(description = "lr0", example = "Default = 0.01")
|
||||||
|
private double lr0;
|
||||||
|
|
||||||
|
@Schema(description = "lrf", example = "Default = 0.01")
|
||||||
|
private double lrf;
|
||||||
|
|
||||||
|
@Schema(description = "optimizer", example = "Default = auto")
|
||||||
|
private Optimizer optimizer;
|
||||||
|
}
|
@ -6,19 +6,21 @@ import com.worlabel.domain.image.entity.Image;
|
|||||||
import com.worlabel.domain.image.entity.LabelStatus;
|
import com.worlabel.domain.image.entity.LabelStatus;
|
||||||
import com.worlabel.domain.image.repository.ImageRepository;
|
import com.worlabel.domain.image.repository.ImageRepository;
|
||||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||||
|
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||||
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
|
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
|
||||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||||
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
||||||
import com.worlabel.domain.model.entity.AiModel;
|
import com.worlabel.domain.model.entity.AiModel;
|
||||||
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
import com.worlabel.domain.model.entity.dto.*;
|
||||||
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
|
||||||
import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
|
|
||||||
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
|
||||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||||
|
import com.worlabel.domain.progress.service.ProgressService;
|
||||||
import com.worlabel.domain.project.dto.AiDto;
|
import com.worlabel.domain.project.dto.AiDto;
|
||||||
|
import com.worlabel.domain.project.dto.AiDto.TrainDataInfo;
|
||||||
|
import com.worlabel.domain.project.dto.AiDto.TrainRequest;
|
||||||
import com.worlabel.domain.project.entity.Project;
|
import com.worlabel.domain.project.entity.Project;
|
||||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||||
|
import com.worlabel.domain.project.service.ProjectService;
|
||||||
import com.worlabel.global.annotation.CheckPrivilege;
|
import com.worlabel.global.annotation.CheckPrivilege;
|
||||||
import com.worlabel.global.cache.CacheKey;
|
import com.worlabel.global.cache.CacheKey;
|
||||||
import com.worlabel.global.exception.CustomException;
|
import com.worlabel.global.exception.CustomException;
|
||||||
@ -34,7 +36,10 @@ import java.lang.reflect.Type;
|
|||||||
import java.time.LocalDateTime;
|
import java.time.LocalDateTime;
|
||||||
import java.time.format.DateTimeFormatter;
|
import java.time.format.DateTimeFormatter;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
@ -42,13 +47,15 @@ import java.util.List;
|
|||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class AiModelService {
|
public class AiModelService {
|
||||||
|
|
||||||
|
private final LabelCategoryRepository labelCategoryRepository;
|
||||||
|
private final RedisTemplate<String, Object> redisTemplate;
|
||||||
private final AiModelRepository aiModelRepository;
|
private final AiModelRepository aiModelRepository;
|
||||||
private final ProjectRepository projectRepository;
|
private final ProjectRepository projectRepository;
|
||||||
private final LabelCategoryRepository labelCategoryRepository;
|
|
||||||
private final ImageRepository imageRepository;
|
|
||||||
private final AiRequestService aiRequestService;
|
private final AiRequestService aiRequestService;
|
||||||
private final RedisTemplate<String, Object> redisTemplate;
|
private final ImageRepository imageRepository;
|
||||||
|
private final ProjectService projectService;
|
||||||
private final Gson gson;
|
private final Gson gson;
|
||||||
|
private final ProgressService progressService;
|
||||||
|
|
||||||
// @PostConstruct
|
// @PostConstruct
|
||||||
public void loadDefaultModel() {
|
public void loadDefaultModel() {
|
||||||
@ -127,56 +134,43 @@ public class AiModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
public void train(final Integer projectId, final Integer modelId) {
|
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||||
trainProgressCheck(projectId);
|
// progressService.trainProgressCheck(projectId);
|
||||||
|
|
||||||
// FastAPI 서버로 학습 요청을 전송
|
// FastAPI 서버로 학습 요청을 전송
|
||||||
Project project = getProject(projectId);
|
Project project = getProject(projectId);
|
||||||
AiModel model = getModel(modelId);
|
AiModel model = getModel(trainRequest.getModelId());
|
||||||
List<LabelCategory> labelCategories = labelCategoryRepository.findAllByModelId(modelId);
|
|
||||||
List<Integer> categories = labelCategories.stream()
|
Map<Integer, Integer> labelMap = project.getCategoryList().stream()
|
||||||
.map(LabelCategory::getAiCategoryId).toList();
|
.collect(Collectors.toMap(
|
||||||
|
category -> category.getLabelCategory().getId(),
|
||||||
|
ProjectCategory::getId
|
||||||
|
));
|
||||||
|
|
||||||
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
||||||
|
List<TrainDataInfo> data = images.stream()
|
||||||
List<AiDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
.filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||||
.map(image -> new AiDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
.map(TrainDataInfo::of)
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
|
// progressService.registerTrainProgress(projectId);
|
||||||
|
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getModelKey(), labelMap, data, trainRequest);
|
||||||
|
// progressService.removeTrainProgress(projectId);
|
||||||
|
|
||||||
String endPoint = project.getProjectType().getValue() + "/train";
|
String endPoint = project.getProjectType().getValue() + "/train";
|
||||||
|
|
||||||
AiDto.TrainRequest trainRequest = new AiDto.TrainRequest();
|
|
||||||
trainRequest.setProjectId(projectId);
|
|
||||||
trainRequest.setCategoryId(categories);
|
|
||||||
trainRequest.setData(data);
|
|
||||||
trainRequest.setModelKey(model.getModelKey());
|
|
||||||
|
|
||||||
// FastAPI 서버로 POST 요청 전송
|
// FastAPI 서버로 POST 요청 전송
|
||||||
String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
String modelKey = aiRequestService.postRequest(endPoint, aiRequest, String.class, response -> response);
|
||||||
|
|
||||||
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
|
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
|
||||||
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
|
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd_HHmm"));
|
||||||
|
int newVersion = model.getVersion() + 1;
|
||||||
|
String newName = currentDateTime + String.format("%03d", newVersion);
|
||||||
|
|
||||||
AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project);
|
AiModel newModel = AiModel.of(newName, modelKey, newVersion, project);
|
||||||
aiModelRepository.save(newModel);
|
aiModelRepository.save(newModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Redis 중복 요청 체크
|
|
||||||
*/
|
|
||||||
private void trainProgressCheck(Integer projectId) {
|
|
||||||
String trainProgressKey = CacheKey.trainProgressKey();
|
|
||||||
|
|
||||||
// 존재 확인
|
|
||||||
Boolean isProjectExist = redisTemplate.opsForSet().isMember(trainProgressKey, projectId);
|
|
||||||
if (Boolean.TRUE.equals(isProjectExist)) {
|
|
||||||
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 학습 진행 중으로 상태 등록
|
|
||||||
redisTemplate.opsForSet().add(trainProgressKey, projectId);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Json -> List<DefaultResponse>
|
* Json -> List<DefaultResponse>
|
||||||
*/
|
*/
|
||||||
|
@ -1,37 +1,91 @@
|
|||||||
package com.worlabel.domain.progress.repository;
|
package com.worlabel.domain.progress.repository;
|
||||||
|
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||||
import com.worlabel.global.cache.CacheKey;
|
import com.worlabel.global.cache.CacheKey;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.data.redis.core.RedisTemplate;
|
import org.springframework.data.redis.core.RedisTemplate;
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Repository
|
@Repository
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class ProgressCacheRepository {
|
public class ProgressCacheRepository {
|
||||||
|
|
||||||
private final RedisTemplate<String, Object> redisTemplate;
|
private final RedisTemplate<String, String> redisTemplate;
|
||||||
|
private final Gson gson;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 현재 오토레이블링중인지 확인하는 메서드
|
* 현재 오토레이블링중인지 확인하는 메서드
|
||||||
*/
|
*/
|
||||||
public boolean predictCheck(final int projectId) {
|
public boolean predictProgressCheck(final int projectId) {
|
||||||
String key = CacheKey.autoLabelingProgressKey();
|
Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
|
||||||
Boolean isProgress = redisTemplate.opsForSet().isMember(key, projectId);
|
|
||||||
return Boolean.TRUE.equals(isProgress);
|
return Boolean.TRUE.equals(isProgress);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 학습 진행 중 등록 메서드
|
* 오토레이블링 진행 중 등록 메서드
|
||||||
*/
|
*/
|
||||||
public void registerPredictProgress(final int projectId) {
|
public void registerPredictProgress(final int projectId) {
|
||||||
String key = CacheKey.autoLabelingProgressKey();
|
redisTemplate.opsForSet().add(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
|
||||||
redisTemplate.opsForSet().add(key, projectId);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 오토레이블링 진행 제거 메서드
|
||||||
|
*/
|
||||||
public void removePredictProgress(final int projectId) {
|
public void removePredictProgress(final int projectId) {
|
||||||
String key = CacheKey.autoLabelingProgressKey();
|
redisTemplate.opsForSet().remove(CacheKey.autoLabelingProgressKey(), String.valueOf(projectId));
|
||||||
redisTemplate.opsForSet().remove(key, projectId);
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습 진행 확인 메서드
|
||||||
|
*/
|
||||||
|
public boolean trainProgressCheck(final int projectId) {
|
||||||
|
Boolean isProgress = redisTemplate.opsForSet().isMember(CacheKey.trainProgressKey(), String.valueOf(projectId));
|
||||||
|
return Boolean.TRUE.equals(isProgress);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습 진행 등록 메서드
|
||||||
|
*/
|
||||||
|
public void registerTrainProgress(final int projectId) {
|
||||||
|
redisTemplate.opsForSet().add(CacheKey.trainProgressKey(), String.valueOf(projectId));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 학습 진행 제거 메서드
|
||||||
|
*/
|
||||||
|
public void removeTrainProgress(final int projectId) {
|
||||||
|
redisTemplate.opsForSet().remove(CacheKey.trainProgressKey(), String.valueOf(projectId));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 진행 상황을 Redis에 추가
|
||||||
|
*/
|
||||||
|
public void addProgressModel(final int modelId,final String data){
|
||||||
|
ReportResponse reportResponse = convert(data);
|
||||||
|
redisTemplate.opsForList().rightPush(CacheKey.progressStatusKey(modelId), gson.toJson(reportResponse));
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<ReportResponse> getProgressModel(final int modelId) {
|
||||||
|
// 저장된 걸 주어진 응답에 맞추어 리턴
|
||||||
|
String key = CacheKey.progressStatusKey(modelId);
|
||||||
|
List<String> progressList = redisTemplate.opsForList().range(key, 0, -1);
|
||||||
|
|
||||||
|
return progressList.stream()
|
||||||
|
.map(this::convert)
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void clearProgressModel(final int modelId) {
|
||||||
|
redisTemplate.delete(CacheKey.progressStatusKey(modelId));
|
||||||
|
}
|
||||||
|
|
||||||
|
private ReportResponse convert(String data){
|
||||||
|
return gson.fromJson(data, ReportResponse.class);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
package com.worlabel.domain.progress.service;
|
package com.worlabel.domain.progress.service;
|
||||||
|
|
||||||
import com.worlabel.domain.progress.repository.ProgressCacheRepository;
|
import com.worlabel.domain.progress.repository.ProgressCacheRepository;
|
||||||
|
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||||
import com.worlabel.global.exception.CustomException;
|
import com.worlabel.global.exception.CustomException;
|
||||||
import com.worlabel.global.exception.ErrorCode;
|
import com.worlabel.global.exception.ErrorCode;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
@ -14,14 +17,39 @@ public class ProgressService {
|
|||||||
|
|
||||||
private final ProgressCacheRepository progressCacheRepository;
|
private final ProgressCacheRepository progressCacheRepository;
|
||||||
|
|
||||||
public void predictCheck(final int projectId){
|
public void predictProgressCheck(final int projectId){
|
||||||
if(progressCacheRepository.predictCheck(projectId)){
|
if(progressCacheRepository.predictProgressCheck(projectId)){
|
||||||
// throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||||
progressCacheRepository.removePredictProgress(projectId);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void registerPredictProgress(final int projectId){
|
public void registerPredictProgress(final int projectId){
|
||||||
progressCacheRepository.registerPredictProgress(projectId);
|
progressCacheRepository.registerPredictProgress(projectId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void removePredictProgress(final int projectId){
|
||||||
|
progressCacheRepository.removePredictProgress(projectId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void trainProgressCheck(final int projectId){
|
||||||
|
if(progressCacheRepository.trainProgressCheck(projectId)){
|
||||||
|
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isProgressTrain(final int projectId){
|
||||||
|
return progressCacheRepository.trainProgressCheck(projectId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void registerTrainProgress(final int projectId){
|
||||||
|
progressCacheRepository.registerTrainProgress(projectId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void removeTrainProgress(final int projectId){
|
||||||
|
progressCacheRepository.removeTrainProgress(projectId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<ReportResponse> getProgressResponse(final int modelId) {
|
||||||
|
return progressCacheRepository.getProgressModel(modelId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,41 +3,76 @@ package com.worlabel.domain.project.dto;
|
|||||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
import com.google.gson.annotations.SerializedName;
|
import com.google.gson.annotations.SerializedName;
|
||||||
import com.worlabel.domain.image.entity.Image;
|
import com.worlabel.domain.image.entity.Image;
|
||||||
|
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
|
||||||
|
import com.worlabel.domain.result.entity.Optimizer;
|
||||||
import lombok.*;
|
import lombok.*;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
public class AiDto {
|
public class AiDto {
|
||||||
|
|
||||||
@Data
|
@Getter
|
||||||
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
public static class TrainDataInfo {
|
public static class TrainDataInfo {
|
||||||
|
|
||||||
|
@JsonProperty("image_url")
|
||||||
private String imagePath;
|
private String imagePath;
|
||||||
|
|
||||||
|
@JsonProperty("data_url")
|
||||||
private String dataPath;
|
private String dataPath;
|
||||||
|
|
||||||
public TrainDataInfo(String imagePath, String dataPath) {
|
public static TrainDataInfo of(Image image) {
|
||||||
this.imagePath = imagePath;
|
return new TrainDataInfo(image.getImagePath(), image.getDataPath());
|
||||||
this.dataPath = dataPath;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Getter
|
||||||
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
|
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
public static class TrainRequest {
|
public static class TrainRequest {
|
||||||
|
|
||||||
@JsonProperty("project_id")
|
@JsonProperty("project_id")
|
||||||
private int projectId;
|
private int projectId;
|
||||||
|
|
||||||
@JsonProperty("category_id")
|
@JsonProperty("model_key")
|
||||||
private List<Integer> categoryId;
|
private String modelKey;
|
||||||
|
|
||||||
|
@JsonProperty("label_map")
|
||||||
|
private Map<Integer,Integer> labelMap;
|
||||||
|
|
||||||
@JsonProperty("data")
|
@JsonProperty("data")
|
||||||
private List<TrainDataInfo> data;
|
private List<TrainDataInfo> data;
|
||||||
|
|
||||||
@JsonProperty("model_key")
|
private double ratio; // Default = 0.8
|
||||||
private String modelKey;
|
|
||||||
// private int seed; // Optional
|
private int epochs; // Default = 50
|
||||||
// private float ratio; // Default = 0.8
|
|
||||||
// private int epochs; // Default = 50
|
private double batch; // Default = -1
|
||||||
// private float batch; // Default = -1
|
|
||||||
|
private double lr0;
|
||||||
|
|
||||||
|
private double lrf;
|
||||||
|
|
||||||
|
private Optimizer optimizer;
|
||||||
|
|
||||||
|
public static TrainRequest of(final Integer projectId, final String modelKey, final Map<Integer, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
|
||||||
|
TrainRequest request = new TrainRequest();
|
||||||
|
request.projectId = projectId;
|
||||||
|
request.modelKey = modelKey;
|
||||||
|
request.labelMap = labelMap;
|
||||||
|
request.data = data;
|
||||||
|
request.ratio = request.getRatio();
|
||||||
|
request.epochs = trainRequest.getEpochs();
|
||||||
|
request.batch = trainRequest.getBatch();
|
||||||
|
request.lr0 = trainRequest.getLr0();
|
||||||
|
request.lrf = trainRequest.getLrf();
|
||||||
|
request.optimizer = trainRequest.getOptimizer();
|
||||||
|
|
||||||
|
return request;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@ -89,7 +124,7 @@ public class AiDto {
|
|||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||||
@Getter
|
@Getter
|
||||||
@ToString
|
@ToString
|
||||||
public static class AutoLabelingResult{
|
public static class AutoLabelingResult {
|
||||||
|
|
||||||
@SerializedName("image_id")
|
@SerializedName("image_id")
|
||||||
private Long imageId;
|
private Long imageId;
|
||||||
|
@ -162,7 +162,7 @@ public class ProjectService {
|
|||||||
*/
|
*/
|
||||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
|
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
|
||||||
progressService.predictCheck(projectId);
|
// progressService.predictCheck(projectId);
|
||||||
|
|
||||||
Project project = getProject(projectId);
|
Project project = getProject(projectId);
|
||||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||||
|
@ -11,14 +11,14 @@ import org.springframework.web.bind.annotation.RestController;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@RestController
|
@RestController
|
||||||
@RequestMapping("/api/reports")
|
@RequestMapping("/api/projects/{project_id}/reports")
|
||||||
@RequiredArgsConstructor
|
@RequiredArgsConstructor
|
||||||
public class ReportController {
|
public class ReportController {
|
||||||
|
|
||||||
private final ReportService reportService;
|
private final ReportService reportService;
|
||||||
|
|
||||||
@GetMapping("/model/{model_id}")
|
@GetMapping("/model/{model_id}")
|
||||||
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId) {
|
public List<ReportResponse> getReportsByModelId(@PathVariable("model_id") final Integer modelId, @PathVariable("project_id") final Integer projectId) {
|
||||||
return reportService.getReportsByModelId(modelId);
|
return reportService.getReportsByModelId(projectId,modelId);
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -25,17 +25,18 @@ public class Report extends BaseEntity {
|
|||||||
@JoinColumn(name = "model_id", nullable = false)
|
@JoinColumn(name = "model_id", nullable = false)
|
||||||
private AiModel aiModel;
|
private AiModel aiModel;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 현재 에포크
|
||||||
|
*/
|
||||||
|
@Column(name = "epoch", nullable = false)
|
||||||
|
private Integer epoch;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 전체 에포크
|
* 전체 에포크
|
||||||
*/
|
*/
|
||||||
@Column(name = "total_epochs", nullable = false)
|
@Column(name = "total_epochs", nullable = false)
|
||||||
private Integer totalEpochs;
|
private Integer totalEpochs;
|
||||||
|
|
||||||
/**
|
|
||||||
* 현재 에포크
|
|
||||||
*/
|
|
||||||
@Column(name = "epoch", nullable = false)
|
|
||||||
private Integer epoch;
|
|
||||||
|
|
||||||
@Column(name = "box_loss", nullable = false)
|
@Column(name = "box_loss", nullable = false)
|
||||||
private double boxLoss;
|
private double boxLoss;
|
||||||
@ -48,4 +49,10 @@ public class Report extends BaseEntity {
|
|||||||
|
|
||||||
@Column(name = "fitness", nullable = false)
|
@Column(name = "fitness", nullable = false)
|
||||||
private double fitness;
|
private double fitness;
|
||||||
|
|
||||||
|
@Column(name = "epoch_time", nullable = false)
|
||||||
|
private double epochTime;
|
||||||
|
|
||||||
|
@Column(name = "left_second", nullable = false)
|
||||||
|
private double leftSecond;
|
||||||
}
|
}
|
||||||
|
@ -4,26 +4,33 @@ import com.worlabel.domain.report.entity.Report;
|
|||||||
import lombok.AccessLevel;
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
public class ReportResponse {
|
public class ReportResponse {
|
||||||
private Integer id;
|
private int modelId;
|
||||||
private Integer totalEpochs;
|
private int totalEpochs;
|
||||||
private Integer epoch;
|
private int epoch;
|
||||||
private double boxLoss;
|
private double boxLoss;
|
||||||
private double clsLoss;
|
private double clsLoss;
|
||||||
private double dflLoss;
|
private double dflLoss;
|
||||||
private double fitness;
|
private double fitness;
|
||||||
|
private double epochTime;
|
||||||
|
private double leftSecond;
|
||||||
|
|
||||||
public static ReportResponse from(final Report report) {
|
public static ReportResponse from(final Report report) {
|
||||||
return new ReportResponse(
|
return new ReportResponse(
|
||||||
report.getId(),
|
report.getAiModel().getId(),
|
||||||
report.getTotalEpochs(),
|
report.getTotalEpochs(),
|
||||||
report.getEpoch(),
|
report.getEpoch(),
|
||||||
report.getBoxLoss(),
|
report.getBoxLoss(),
|
||||||
report.getClsLoss(),
|
report.getClsLoss(),
|
||||||
report.getDflLoss(),
|
report.getDflLoss(),
|
||||||
report.getFitness());
|
report.getFitness(),
|
||||||
|
report.getEpochTime(),
|
||||||
|
report.getLeftSecond()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,5 +1,6 @@
|
|||||||
package com.worlabel.domain.report.service;
|
package com.worlabel.domain.report.service;
|
||||||
|
|
||||||
|
import com.worlabel.domain.progress.service.ProgressService;
|
||||||
import com.worlabel.domain.report.entity.Report;
|
import com.worlabel.domain.report.entity.Report;
|
||||||
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
import com.worlabel.domain.report.entity.dto.ReportResponse;
|
||||||
import com.worlabel.domain.report.repository.ReportRepository;
|
import com.worlabel.domain.report.repository.ReportRepository;
|
||||||
@ -7,6 +8,7 @@ import lombok.RequiredArgsConstructor;
|
|||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@ -15,11 +17,43 @@ import java.util.List;
|
|||||||
public class ReportService {
|
public class ReportService {
|
||||||
|
|
||||||
private final ReportRepository reportRepository;
|
private final ReportRepository reportRepository;
|
||||||
|
private final ProgressService progressService;
|
||||||
|
|
||||||
public List<ReportResponse> getReportsByModelId(final Integer modelId) {
|
public List<ReportResponse> getReportsByModelId(final Integer projectId, final Integer modelId) {
|
||||||
List<Report> reports = reportRepository.findByAiModelId(modelId);
|
// 진행중이면 진행중에서 받아오기
|
||||||
return reports.stream()
|
return getDummyList();
|
||||||
.map(ReportResponse::from)
|
// if(progressService.isProgressTrain(projectId)){
|
||||||
.toList();
|
// return progressService.getProgressResponse(modelId);
|
||||||
|
// }
|
||||||
|
// // 작업 완료시에는 RDB
|
||||||
|
// else{
|
||||||
|
// List<Report> reports = reportRepository.findByAiModelId(modelId);
|
||||||
|
// return reports.stream()
|
||||||
|
// .map(ReportResponse::from)
|
||||||
|
// .toList();
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<ReportResponse> getDummyList() {
|
||||||
|
List<ReportResponse> dummyList = new ArrayList<>();
|
||||||
|
|
||||||
|
// 더미 데이터 15개 생성
|
||||||
|
for (int i = 1; i <= 15; i++) {
|
||||||
|
ReportResponse dummy = new ReportResponse(
|
||||||
|
i, // modelId
|
||||||
|
100, // totalEpochs
|
||||||
|
i, // epoch
|
||||||
|
Math.random(), // boxLoss
|
||||||
|
Math.random(), // clsLoss
|
||||||
|
Math.random(), // dflLoss
|
||||||
|
Math.random(), // fitness
|
||||||
|
Math.random() * 10,// epochTime
|
||||||
|
Math.random() * 100 // leftSecond
|
||||||
|
);
|
||||||
|
dummyList.add(dummy);
|
||||||
|
}
|
||||||
|
|
||||||
|
return dummyList;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
@ -14,4 +14,8 @@ public class CacheKey {
|
|||||||
public static String fcmTokenKey(){
|
public static String fcmTokenKey(){
|
||||||
return "fcmToken";
|
return "fcmToken";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String progressStatusKey(int modelId) {
|
||||||
|
return "progress:" + modelId;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user