diff --git a/backend/src/main/java/com/worlabel/domain/image/repository/ImageRepository.java b/backend/src/main/java/com/worlabel/domain/image/repository/ImageRepository.java index 4b4b231..b00baae 100644 --- a/backend/src/main/java/com/worlabel/domain/image/repository/ImageRepository.java +++ b/backend/src/main/java/com/worlabel/domain/image/repository/ImageRepository.java @@ -10,6 +10,7 @@ import java.util.Optional; public interface ImageRepository extends JpaRepository { + // todo N + 1 발생할듯 @Query("select i from Image i " + "where i.folder.project.id = :projectId") List findImagesByProjectId(@Param("projectId") Integer projectId); diff --git a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java index c7b1814..f4ebf65 100644 --- a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java +++ b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java @@ -32,7 +32,7 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @GetMapping("/projects/{project_id}/models") public List getModelList( - @PathVariable("project_id") final Integer projectId) { + @PathVariable("project_id") final Integer projectId) { return aiModelService.getModelList(projectId); } @@ -41,7 +41,7 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @GetMapping("/models/{model_id}/categories") public List getCategories( - @PathVariable("model_id") final Integer modelId) { + @PathVariable("model_id") final Integer modelId) { return aiModelService.getCategories(modelId); } @@ -50,9 +50,9 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @PostMapping("/projects/{project_id}/models") public void addModel( - @CurrentUser final Integer memberId, - @PathVariable("project_id") final Integer projectId, - @Valid @RequestBody final AiModelRequest aiModelRequest) { + @CurrentUser final Integer memberId, + @PathVariable("project_id") final Integer projectId, + @Valid @RequestBody final AiModelRequest aiModelRequest) { aiModelService.addModel(memberId, projectId, aiModelRequest); } @@ -61,14 +61,21 @@ public class AiModelController { @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @PutMapping("/projects/{project_id}/models/{model_id}") public void renameModel( - @CurrentUser final Integer memberId, - @PathVariable("project_id") final Integer projectId, - @PathVariable("model_id") final Integer modelId, - @Valid @RequestBody final AiModelRequest aiModelRequest) { - aiModelService.renameModel(memberId, projectId,modelId, aiModelRequest); + @CurrentUser final Integer memberId, + @PathVariable("project_id") final Integer projectId, + @PathVariable("model_id") final Integer modelId, + @Valid @RequestBody final AiModelRequest aiModelRequest) { + aiModelService.renameModel(memberId, projectId, modelId, aiModelRequest); } - // TODO: 여기서 모델 학습을 따로 만들어야 할 듯 Project 있는 모델 학습을 여기로 옮겨서 진행 - // 아마도 필요한 요청 값들은 ModelID - + @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.") + @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.") + @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) + @PostMapping("/projects/{project_id}/train") + public void trainModel( + @CurrentUser final Integer memberId, + @PathVariable("project_id") final Integer projectId, + @RequestBody final Integer modelId) { + aiModelService.train(memberId, projectId, modelId); + } } diff --git a/backend/src/main/java/com/worlabel/domain/model/repository/AiModelRepository.java b/backend/src/main/java/com/worlabel/domain/model/repository/AiModelRepository.java index 7c719a3..15561a4 100644 --- a/backend/src/main/java/com/worlabel/domain/model/repository/AiModelRepository.java +++ b/backend/src/main/java/com/worlabel/domain/model/repository/AiModelRepository.java @@ -18,4 +18,6 @@ public interface AiModelRepository extends JpaRepository { @Query("SELECT a FROM AiModel a " + "WHERE a.project IS NOT NULL AND a.id = :modelId") Optional findCustomModelById(@Param("modelId") int modelId); + + List findAllByModelKeyIn(List allModelKeys); } diff --git a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java index 1ba8e7c..d7b036d 100644 --- a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java +++ b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java @@ -1,29 +1,38 @@ package com.worlabel.domain.model.service; - import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; +import com.worlabel.domain.image.entity.Image; +import com.worlabel.domain.image.entity.LabelStatus; +import com.worlabel.domain.image.repository.ImageRepository; import com.worlabel.domain.labelcategory.entity.LabelCategory; +import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse; import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository; import com.worlabel.domain.model.entity.AiModel; import com.worlabel.domain.model.entity.dto.AiModelRequest; import com.worlabel.domain.model.entity.dto.AiModelResponse; +import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse; import com.worlabel.domain.model.entity.dto.DefaultResponse; import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.participant.entity.PrivilegeType; +import com.worlabel.domain.project.dto.RequestDto; import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.repository.ProjectRepository; import com.worlabel.global.annotation.CheckPrivilege; import com.worlabel.global.exception.CustomException; import com.worlabel.global.exception.ErrorCode; import com.worlabel.global.service.AiRequestService; +import jakarta.annotation.PostConstruct; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; import java.lang.reflect.Type; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; import java.util.List; @Slf4j @@ -35,15 +44,56 @@ public class AiModelService { private final AiModelRepository aiModelRepository; private final ProjectRepository projectRepository; private final LabelCategoryRepository labelCategoryRepository; + private final ImageRepository imageRepository; private final AiRequestService aiRequestService; private final Gson gson; + @PostConstruct + public void loadDefaultModel() { + String url = "model/default"; + List defaultResponseList = aiRequestService.getRequest(url, this::converter); + + // 1. DefaultResponse의 Key값만 모아서 리스트로 만든다. + List allModelKeys = defaultResponseList.stream() + .map(response -> response.getDefaultAiModelResponse().getModelKey()) + .toList(); + + // 2. 해당 Key값이 DB에 있는지 확인하기 (한 번의 쿼리로) + List existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream() + .map(AiModel::getModelKey) + .toList(); + + // 3. DB에 없는 Key만 필터링해서 처리 + List newModel = defaultResponseList.stream() + .filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey())) + .toList(); + + + // 새롭게 추가된 값을 디비에 저장 + List aiModels = new ArrayList<>(); + List categories = new ArrayList<>(); + for (DefaultResponse defaultResponse : newModel) { + DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse(); + AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null); + aiModels.add(newAiModel); + + List defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList(); + + for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) { + categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId())); + } + } + + aiModelRepository.saveAll(aiModels); + labelCategoryRepository.saveAll(categories); + } + @Transactional(readOnly = true) public List getModelList(final Integer projectId) { return aiModelRepository.findAllByProjectId(projectId) - .stream() - .map(AiModelResponse::of) - .toList(); + .stream() + .map(AiModelResponse::of) + .toList(); } @CheckPrivilege(PrivilegeType.EDITOR) @@ -70,43 +120,70 @@ public class AiModelService { public List getCategories(final Integer modelId) { List categoryList = labelCategoryRepository.findAllByModelId(modelId); return categoryList.stream() - .map(LabelCategoryResponse::from) - .toList(); + .map(LabelCategoryResponse::from) + .toList(); + } + + @CheckPrivilege(PrivilegeType.EDITOR) + public void train(Integer memberId, Integer projectId, Integer modelId) { + // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정 + /* + 없으면 redis 상태 테이블을 만든다. progressTable + */ + + // FastAPI 서버로 학습 요청을 전송 + Project project = getProject(projectId); + AiModel model = getModel(modelId); + List labelCategories = labelCategoryRepository.findAllByModelId(modelId); + List categories = labelCategories.stream() + .map(LabelCategory::getAiCategoryId).toList(); + + List images = imageRepository.findImagesByProjectId(projectId); + + List data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED) + .map(image -> new RequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath())) + .toList(); + + String endPoint = project.getProjectType().getValue() + "/train"; + + RequestDto.TrainRequest trainRequest = new RequestDto.TrainRequest(); + trainRequest.setProjectId(projectId); + trainRequest.setCategoryId(categories); + trainRequest.setData(data); + trainRequest.setModelKey(model.getModelKey()); + + // FastAPI 서버로 POST 요청 전송 + String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response); + + // 가져온 modelKey -> version 업된 모델 다시 새롭게 저장 + String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss")); + + AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project); + aiModelRepository.save(newModel); } /** - * 해당 Default 모델 불러오기 API 예시 - */ - // TODO : 스프링이 로딩 후 DefaultModel을 불러온다. - public void loadDefaultModel() { - String url = "model/default"; - List defaultResponseList = aiRequestService.getRequest(url, this::converter); - - // TODO: defaultModel 현재 DB에 해당하는지 안하는지 확인하기 - - // TODO : 1.DefaultResponse의 Key값만 모아서 리스트로 만든다. - - // TODO: 2. 그 중 IN(key...)로 해당되는 Key값 확인하기 - - // TODO: 3. 현재 DB에 없는 Key만 모아서 DB와 CategoryList에 넣어주면 됨 - } - - /** - * Json -> List + * Json -> List */ // TODO: 추후 리팩토링 해야함 이건 예시 private List converter(String data) { - try{ - Type listType = new TypeToken>() {}.getType(); + try { + Type listType = new TypeToken>() { + }.getType(); return gson.fromJson(data, listType); - }catch (Exception e){ + } catch (Exception e) { log.debug("TODO: 추후 리팩토링 해야함 이건 예시"); throw new CustomException(ErrorCode.BAD_REQUEST); } } private Project getProject(Integer projectId) { - return projectRepository.findById(projectId).orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); + return projectRepository.findById(projectId) + .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); } + private AiModel getModel(Integer modelId) { + return aiModelRepository.findById(modelId) + .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); + } } diff --git a/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java b/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java index 0f7153f..05a597f 100644 --- a/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java +++ b/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java @@ -72,15 +72,15 @@ public class ProjectController { return projectService.updateProject(memberId, projectId, projectRequest); } - @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.") - @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.") - @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) - @PostMapping("/projects/{project_id}/train") - public void trainModel( - @CurrentUser final Integer memberId, - @PathVariable("project_id") final Integer projectId) { - projectService.train(memberId, projectId); - } +// @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.") +// @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.") +// @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) +// @PostMapping("/projects/{project_id}/train") +// public void trainModel( +// @CurrentUser final Integer memberId, +// @PathVariable("project_id") final Integer projectId) { +// projectService.train(memberId, projectId); +// } @Operation(summary = "프로젝트 오토 레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.") @SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.") diff --git a/backend/src/main/java/com/worlabel/domain/project/dto/RequestDto.java b/backend/src/main/java/com/worlabel/domain/project/dto/RequestDto.java index 92fe3b5..e0019b7 100644 --- a/backend/src/main/java/com/worlabel/domain/project/dto/RequestDto.java +++ b/backend/src/main/java/com/worlabel/domain/project/dto/RequestDto.java @@ -8,8 +8,14 @@ import java.util.List; public class RequestDto { @Data - public class TrainDataInfo { - private String imageUrl; + public static class TrainDataInfo { + private String imagePath; + private String dataPath; + + public TrainDataInfo(String imagePath, String dataPath) { + this.imagePath = imagePath; + this.dataPath = dataPath; + } } @Data @@ -17,8 +23,14 @@ public class RequestDto { @JsonProperty("project_id") private int projectId; + @JsonProperty("category_id") + private List categoryId; + @JsonProperty("data") private List data; + + @JsonProperty("model_key") + private String modelKey; // private int seed; // Optional // private float ratio; // Default = 0.8 // private int epochs; // Default = 50 diff --git a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java index 00bd51d..51b2415 100644 --- a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java +++ b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java @@ -134,26 +134,26 @@ public class ProjectService { participantRepository.delete(participant); } - @CheckPrivilege(PrivilegeType.EDITOR) - public void train(final Integer memberId, final Integer projectId) { - // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정 - /* - 없으면 redis 상태 테이블을 만든다. progressTable - */ - - // FastAPI 서버로 학습 요청을 전송 - Project project = getProject(projectId); - String endPoint = project.getProjectType().getValue() + "/train"; - - TrainRequest trainRequest = new TrainRequest(); - trainRequest.setProjectId(projectId); - trainRequest.setData(List.of()); - - // FastAPI 서버로 POST 요청 전송 - String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response); - - // TODO: 모델 생성 후 Default 이름과 Key 값 설정 - } +// @CheckPrivilege(PrivilegeType.EDITOR) +// public void train(final Integer memberId, final Integer projectId) { +// // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정 +// /* +// 없으면 redis 상태 테이블을 만든다. progressTable +// */ +// +// // FastAPI 서버로 학습 요청을 전송 +// Project project = getProject(projectId); +// String endPoint = project.getProjectType().getValue() + "/train"; +// +// TrainRequest trainRequest = new TrainRequest(); +// trainRequest.setProjectId(projectId); +// trainRequest.setData(List.of()); +// +// // FastAPI 서버로 POST 요청 전송 +// String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response); +// +// // TODO: 모델 생성 후 Default 이름과 Key 값 설정 +// } /** * 프로젝트별 오토 레이블링