Merge branch 'be/feat/model' into 'be/develop'

Feat: 모델 관련 API 구현

See merge request s11-s-project/S11P21S002!119
This commit is contained in:
김용수 2024-09-22 23:32:18 +09:00
commit 64c549c8bc
7 changed files with 171 additions and 72 deletions

View File

@ -10,6 +10,7 @@ import java.util.Optional;
public interface ImageRepository extends JpaRepository<Image, Long> { public interface ImageRepository extends JpaRepository<Image, Long> {
// todo N + 1 발생할듯
@Query("select i from Image i " + @Query("select i from Image i " +
"where i.folder.project.id = :projectId") "where i.folder.project.id = :projectId")
List<Image> findImagesByProjectId(@Param("projectId") Integer projectId); List<Image> findImagesByProjectId(@Param("projectId") Integer projectId);

View File

@ -65,10 +65,17 @@ public class AiModelController {
@PathVariable("project_id") final Integer projectId, @PathVariable("project_id") final Integer projectId,
@PathVariable("model_id") final Integer modelId, @PathVariable("model_id") final Integer modelId,
@Valid @RequestBody final AiModelRequest aiModelRequest) { @Valid @RequestBody final AiModelRequest aiModelRequest) {
aiModelService.renameModel(memberId, projectId,modelId, aiModelRequest); aiModelService.renameModel(memberId, projectId, modelId, aiModelRequest);
} }
// TODO: 여기서 모델 학습을 따로 만들어야 Project 있는 모델 학습을 여기로 옮겨서 진행 @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.")
// 아마도 필요한 요청 값들은 ModelID @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping("/projects/{project_id}/train")
public void trainModel(
@CurrentUser final Integer memberId,
@PathVariable("project_id") final Integer projectId,
@RequestBody final Integer modelId) {
aiModelService.train(memberId, projectId, modelId);
}
} }

View File

@ -18,4 +18,6 @@ public interface AiModelRepository extends JpaRepository<AiModel, Integer> {
@Query("SELECT a FROM AiModel a " + @Query("SELECT a FROM AiModel a " +
"WHERE a.project IS NOT NULL AND a.id = :modelId") "WHERE a.project IS NOT NULL AND a.id = :modelId")
Optional<AiModel> findCustomModelById(@Param("modelId") int modelId); Optional<AiModel> findCustomModelById(@Param("modelId") int modelId);
List<AiModel> findAllByModelKeyIn(List<String> allModelKeys);
} }

View File

@ -1,29 +1,38 @@
package com.worlabel.domain.model.service; package com.worlabel.domain.model.service;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken; import com.google.gson.reflect.TypeToken;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.LabelCategory; import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository; import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
import com.worlabel.domain.model.entity.AiModel; import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.domain.model.entity.dto.AiModelRequest; import com.worlabel.domain.model.entity.dto.AiModelRequest;
import com.worlabel.domain.model.entity.dto.AiModelResponse; import com.worlabel.domain.model.entity.dto.AiModelResponse;
import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
import com.worlabel.domain.model.entity.dto.DefaultResponse; import com.worlabel.domain.model.entity.dto.DefaultResponse;
import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.model.repository.AiModelRepository;
import com.worlabel.domain.participant.entity.PrivilegeType; import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.project.dto.RequestDto;
import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.repository.ProjectRepository; import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.global.annotation.CheckPrivilege; import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.exception.CustomException; import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode; import com.worlabel.global.exception.ErrorCode;
import com.worlabel.global.service.AiRequestService; import com.worlabel.global.service.AiRequestService;
import jakarta.annotation.PostConstruct;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.lang.reflect.Type; import java.lang.reflect.Type;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List; import java.util.List;
@Slf4j @Slf4j
@ -35,9 +44,50 @@ public class AiModelService {
private final AiModelRepository aiModelRepository; private final AiModelRepository aiModelRepository;
private final ProjectRepository projectRepository; private final ProjectRepository projectRepository;
private final LabelCategoryRepository labelCategoryRepository; private final LabelCategoryRepository labelCategoryRepository;
private final ImageRepository imageRepository;
private final AiRequestService aiRequestService; private final AiRequestService aiRequestService;
private final Gson gson; private final Gson gson;
@PostConstruct
public void loadDefaultModel() {
String url = "model/default";
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
// 1. DefaultResponse의 Key값만 모아서 리스트로 만든다.
List<String> allModelKeys = defaultResponseList.stream()
.map(response -> response.getDefaultAiModelResponse().getModelKey())
.toList();
// 2. 해당 Key값이 DB에 있는지 확인하기 ( 번의 쿼리로)
List<String> existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream()
.map(AiModel::getModelKey)
.toList();
// 3. DB에 없는 Key만 필터링해서 처리
List<DefaultResponse> newModel = defaultResponseList.stream()
.filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey()))
.toList();
// 새롭게 추가된 값을 디비에 저장
List<AiModel> aiModels = new ArrayList<>();
List<LabelCategory> categories = new ArrayList<>();
for (DefaultResponse defaultResponse : newModel) {
DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse();
AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null);
aiModels.add(newAiModel);
List<DefaultLabelCategoryResponse> defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList();
for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) {
categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId()));
}
}
aiModelRepository.saveAll(aiModels);
labelCategoryRepository.saveAll(categories);
}
@Transactional(readOnly = true) @Transactional(readOnly = true)
public List<AiModelResponse> getModelList(final Integer projectId) { public List<AiModelResponse> getModelList(final Integer projectId) {
return aiModelRepository.findAllByProjectId(projectId) return aiModelRepository.findAllByProjectId(projectId)
@ -74,21 +124,42 @@ public class AiModelService {
.toList(); .toList();
} }
/** @CheckPrivilege(PrivilegeType.EDITOR)
* 해당 Default 모델 불러오기 API 예시 public void train(Integer memberId, Integer projectId, Integer modelId) {
// TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
/*
없으면 redis 상태 테이블을 만든다. progressTable
*/ */
// TODO : 스프링이 로딩 DefaultModel을 불러온다.
public void loadDefaultModel() {
String url = "model/default";
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
// TODO: defaultModel 현재 DB에 해당하는지 안하는지 확인하기 // FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId);
AiModel model = getModel(modelId);
List<LabelCategory> labelCategories = labelCategoryRepository.findAllByModelId(modelId);
List<Integer> categories = labelCategories.stream()
.map(LabelCategory::getAiCategoryId).toList();
// TODO : 1.DefaultResponse의 Key값만 모아서 리스트로 만든다. List<Image> images = imageRepository.findImagesByProjectId(projectId);
// TODO: 2. IN(key...) 해당되는 Key값 확인하기 List<RequestDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
.map(image -> new RequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
.toList();
// TODO: 3. 현재 DB에 없는 Key만 모아서 DB와 CategoryList에 넣어주면 String endPoint = project.getProjectType().getValue() + "/train";
RequestDto.TrainRequest trainRequest = new RequestDto.TrainRequest();
trainRequest.setProjectId(projectId);
trainRequest.setCategoryId(categories);
trainRequest.setData(data);
trainRequest.setModelKey(model.getModelKey());
// FastAPI 서버로 POST 요청 전송
String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response);
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project);
aiModelRepository.save(newModel);
} }
/** /**
@ -96,17 +167,23 @@ public class AiModelService {
*/ */
// TODO: 추후 리팩토링 해야함 이건 예시 // TODO: 추후 리팩토링 해야함 이건 예시
private List<DefaultResponse> converter(String data) { private List<DefaultResponse> converter(String data) {
try{ try {
Type listType = new TypeToken<List<DefaultResponse>>() {}.getType(); Type listType = new TypeToken<List<DefaultResponse>>() {
}.getType();
return gson.fromJson(data, listType); return gson.fromJson(data, listType);
}catch (Exception e){ } catch (Exception e) {
log.debug("TODO: 추후 리팩토링 해야함 이건 예시"); log.debug("TODO: 추후 리팩토링 해야함 이건 예시");
throw new CustomException(ErrorCode.BAD_REQUEST); throw new CustomException(ErrorCode.BAD_REQUEST);
} }
} }
private Project getProject(Integer projectId) { private Project getProject(Integer projectId) {
return projectRepository.findById(projectId).orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); return projectRepository.findById(projectId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
} }
private AiModel getModel(Integer modelId) {
return aiModelRepository.findById(modelId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
} }

View File

@ -72,15 +72,15 @@ public class ProjectController {
return projectService.updateProject(memberId, projectId, projectRequest); return projectService.updateProject(memberId, projectId, projectRequest);
} }
@Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.") // @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.")
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.") // @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) // @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping("/projects/{project_id}/train") // @PostMapping("/projects/{project_id}/train")
public void trainModel( // public void trainModel(
@CurrentUser final Integer memberId, // @CurrentUser final Integer memberId,
@PathVariable("project_id") final Integer projectId) { // @PathVariable("project_id") final Integer projectId) {
projectService.train(memberId, projectId); // projectService.train(memberId, projectId);
} // }
@Operation(summary = "프로젝트 오토 레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.") @Operation(summary = "프로젝트 오토 레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.")
@SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.") @SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.")

View File

@ -8,8 +8,14 @@ import java.util.List;
public class RequestDto { public class RequestDto {
@Data @Data
public class TrainDataInfo { public static class TrainDataInfo {
private String imageUrl; private String imagePath;
private String dataPath;
public TrainDataInfo(String imagePath, String dataPath) {
this.imagePath = imagePath;
this.dataPath = dataPath;
}
} }
@Data @Data
@ -17,8 +23,14 @@ public class RequestDto {
@JsonProperty("project_id") @JsonProperty("project_id")
private int projectId; private int projectId;
@JsonProperty("category_id")
private List<Integer> categoryId;
@JsonProperty("data") @JsonProperty("data")
private List<TrainDataInfo> data; private List<TrainDataInfo> data;
@JsonProperty("model_key")
private String modelKey;
// private int seed; // Optional // private int seed; // Optional
// private float ratio; // Default = 0.8 // private float ratio; // Default = 0.8
// private int epochs; // Default = 50 // private int epochs; // Default = 50

View File

@ -134,26 +134,26 @@ public class ProjectService {
participantRepository.delete(participant); participantRepository.delete(participant);
} }
@CheckPrivilege(PrivilegeType.EDITOR) // @CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer memberId, final Integer projectId) { // public void train(final Integer memberId, final Integer projectId) {
// TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정 // // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
/* // /*
없으면 redis 상태 테이블을 만든다. progressTable // 없으면 redis 상태 테이블을 만든다. progressTable
*/ // */
//
// FastAPI 서버로 학습 요청을 전송 // // FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId); // Project project = getProject(projectId);
String endPoint = project.getProjectType().getValue() + "/train"; // String endPoint = project.getProjectType().getValue() + "/train";
//
TrainRequest trainRequest = new TrainRequest(); // TrainRequest trainRequest = new TrainRequest();
trainRequest.setProjectId(projectId); // trainRequest.setProjectId(projectId);
trainRequest.setData(List.of()); // trainRequest.setData(List.of());
//
// FastAPI 서버로 POST 요청 전송 // // FastAPI 서버로 POST 요청 전송
String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response); // String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response);
//
// TODO: 모델 생성 Default 이름과 Key 설정 // // TODO: 모델 생성 Default 이름과 Key 설정
} // }
/** /**
* 프로젝트별 오토 레이블링 * 프로젝트별 오토 레이블링