From b84e06baba2226950062f989b30b6f6c3e6c89ed Mon Sep 17 00:00:00 2001 From: kimtaesoo7 Date: Fri, 27 Sep 2024 11:15:24 +0900 Subject: [PATCH] =?UTF-8?q?Refactor:=20=EC=B9=B4=ED=85=8C=EA=B3=A0?= =?UTF-8?q?=EB=A6=AC=20=EC=88=98=EC=A0=95=20=20-=20S11P21S002-223?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controller/CategoryController.java | 37 +---------- .../labelcategory/entity/LabelCategory.java | 53 ---------------- .../labelcategory/entity/ProjectCategory.java | 15 +++-- .../entity/dto/CategoryResponse.java | 22 +++++++ .../entity/dto/LabelCategoryRequest.java | 21 ------- .../entity/dto/LabelCategoryResponse.java | 22 ------- .../repository/LabelCategoryRepository.java | 22 ------- .../ProjectLabelCategoryRepository.java | 11 +--- .../service/ProjectLabelCategoryService.java | 52 ++-------------- .../model/controller/AiModelController.java | 10 --- .../worlabel/domain/model/entity/AiModel.java | 10 --- .../domain/model/service/AiModelService.java | 62 +------------------ .../project/controller/ProjectController.java | 7 +-- .../worlabel/domain/project/dto/AiDto.java | 8 +-- .../dto/ProjectWithCategoryRequest.java | 31 ++++++++++ .../project/service/ProjectService.java | 44 +++++++------ 16 files changed, 101 insertions(+), 326 deletions(-) delete mode 100644 backend/src/main/java/com/worlabel/domain/labelcategory/entity/LabelCategory.java create mode 100644 backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/CategoryResponse.java delete mode 100644 backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryRequest.java delete mode 100644 backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryResponse.java delete mode 100644 backend/src/main/java/com/worlabel/domain/labelcategory/repository/LabelCategoryRepository.java create mode 100644 backend/src/main/java/com/worlabel/domain/project/entity/dto/ProjectWithCategoryRequest.java diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/controller/CategoryController.java b/backend/src/main/java/com/worlabel/domain/labelcategory/controller/CategoryController.java index 69c5c71..87f5120 100644 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/controller/CategoryController.java +++ b/backend/src/main/java/com/worlabel/domain/labelcategory/controller/CategoryController.java @@ -1,16 +1,13 @@ package com.worlabel.domain.labelcategory.controller; -import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest; -import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; +import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse; import com.worlabel.domain.labelcategory.service.ProjectLabelCategoryService; -import com.worlabel.global.annotation.CurrentUser; import com.worlabel.global.config.swagger.SwaggerApiError; import com.worlabel.global.config.swagger.SwaggerApiSuccess; import com.worlabel.global.exception.ErrorCode; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import lombok.RequiredArgsConstructor; -import org.springframework.data.repository.query.Param; import org.springframework.web.bind.annotation.*; import java.util.List; @@ -23,39 +20,11 @@ public class CategoryController { private final ProjectLabelCategoryService categoryService; - @Operation(summary = "프로젝트 레이블 카테고리 선택", description = "프로젝트 레이블 카테고리를 추가합니다.") - @SwaggerApiSuccess(description = "카테고리 성공적으로 추가합니다.") - @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) - @PostMapping - public void createFolder( - @CurrentUser final Integer memberId, - @RequestBody final LabelCategoryRequest categoryRequest) { - categoryService.createCategory(memberId, categoryRequest); - } - - @Operation(summary = "레이블 카테고리 존재 여부 조회", description = "해당 프로젝트에 같은 레이블 카테고리 이름이 있는지 조회합니다.") - @SwaggerApiSuccess(description = "카테고리 존재 여부를 조회합니다.") - @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) - @GetMapping("/exist") - public boolean existByCategoryName( - @PathVariable("project_id") final Integer projectId, - @Param("categoryName") final String categoryName) { - return categoryService.existByCategoryName(projectId, categoryName); - } - @Operation(summary = "프로젝트 레이블 카테고리 리스트 조회", description = "레이블 카테고리 리스트를 조회합니다..") @SwaggerApiSuccess(description = "카테고리 리스트를 성공적으로 조회합니다.") @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @GetMapping - public List getCategoryList(@PathVariable("project_id") final Integer projectId) { - return categoryService.getCategoryList(projectId); - } - - @Operation(summary = "카테고리 삭제", description = "카테고리를 삭제합니다.") - @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) - @SwaggerApiSuccess(description = "카테고리를 성공적으로 삭제합니다.") - @DeleteMapping("/{category_id}") - public void deleteCategoryById(@PathVariable("project_id") final Integer projectId, @PathVariable("category_id") final Integer categoryId) { - categoryService.deleteCategory(projectId, categoryId); + public List getCategoryList(@PathVariable("project_id") final Integer projectId) { + return categoryService.getCategoryById(projectId); } } diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/LabelCategory.java b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/LabelCategory.java deleted file mode 100644 index 5844334..0000000 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/LabelCategory.java +++ /dev/null @@ -1,53 +0,0 @@ -package com.worlabel.domain.labelcategory.entity; - - -import com.worlabel.domain.model.entity.AiModel; -import com.worlabel.global.common.BaseEntity; -import jakarta.persistence.*; -import lombok.AccessLevel; -import lombok.Getter; -import lombok.NoArgsConstructor; - -@Getter -@Entity -@Table(name = "label_category") -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class LabelCategory extends BaseEntity { - - /** - * 레이블 카테고리 PK - */ - @Id - @Column(name = "label_category_id", nullable = false) - @GeneratedValue(strategy = GenerationType.IDENTITY) - private Integer id; - - /** - * 속한 모델 - */ - @ManyToOne(fetch = FetchType.LAZY) - @JoinColumn(name = "model_id", nullable = false) - private AiModel aiModel; - - /** - * 레이블 카테고리 이름 - */ - @Column(name = "label_category_name", nullable = false) - private String name; - - /** - * 실제 AI 모델의 ai 카테고리 id - */ - @Column(name = "ai_category_id", nullable = false) - private Integer aiCategoryId; - - private LabelCategory(final AiModel aiModel, final String name, final int aiCategoryId) { - this.aiModel = aiModel; - this.name = name; - this.aiCategoryId = aiCategoryId; - } - - public static LabelCategory of(final AiModel aiModel, final String name, final int aiCategoryId) { - return new LabelCategory(aiModel, name, aiCategoryId); - } -} diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/ProjectCategory.java b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/ProjectCategory.java index eee719d..f540b11 100644 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/ProjectCategory.java +++ b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/ProjectCategory.java @@ -22,11 +22,10 @@ public class ProjectCategory extends BaseEntity { private Integer id; /** - * 레이블 카테고리 + * Model name */ - @ManyToOne(fetch = FetchType.LAZY) - @JoinColumn(name = "label_category_id", nullable = false) - private LabelCategory labelCategory; + @Column(name = "label_category_name", length = 50) + private String labelName; /** * 프로젝트 @@ -35,12 +34,12 @@ public class ProjectCategory extends BaseEntity { @JoinColumn(name = "project_id", nullable = false) private Project project; - private ProjectCategory(LabelCategory labelCategory, Project project) { - this.labelCategory = labelCategory; + private ProjectCategory(String labelName, Project project) { + this.labelName = labelName; this.project = project; } - public static ProjectCategory of(LabelCategory labelCategory, Project project) { - return new ProjectCategory(labelCategory, project); + public static ProjectCategory of(String labelName, Project project) { + return new ProjectCategory(labelName, project); } } diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/CategoryResponse.java b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/CategoryResponse.java new file mode 100644 index 0000000..e628e4d --- /dev/null +++ b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/CategoryResponse.java @@ -0,0 +1,22 @@ +package com.worlabel.domain.labelcategory.entity.dto; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Schema(name = "카테고리 응답 DTO", description = "프로젝트 내 카테고리 종류 응답 DTO") +@Getter +@AllArgsConstructor(access = AccessLevel.PRIVATE) +public class CategoryResponse { + + @Schema(description = "카테고리 ID", example = "1") + private Integer id; + + @Schema(description = "라벨링 이름", example = "사람") + private String labelName; + + public static CategoryResponse of(final Integer id, final String labelName) { + return new CategoryResponse(id, labelName); + } +} \ No newline at end of file diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryRequest.java b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryRequest.java deleted file mode 100644 index fb01e63..0000000 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryRequest.java +++ /dev/null @@ -1,21 +0,0 @@ -package com.worlabel.domain.labelcategory.entity.dto; - -import io.swagger.v3.oas.annotations.media.Schema; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; - -import java.util.List; - -@Schema(name = "카테고리 요청 DTO", description = "카테고리 생성 및 수정을 위한 요청 DTO") -@Getter -@NoArgsConstructor -@AllArgsConstructor -public class LabelCategoryRequest { - - @Schema(description = "Model Id", example = "1") - private Integer modelId; - - @Schema(description = "카테고리 Id 리스트", example = "[0,3,6,8]") - private List labelCategoryList; -} diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryResponse.java b/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryResponse.java deleted file mode 100644 index ab5507b..0000000 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/entity/dto/LabelCategoryResponse.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.worlabel.domain.labelcategory.entity.dto; - -import com.worlabel.domain.labelcategory.entity.LabelCategory; -import io.swagger.v3.oas.annotations.media.Schema; -import lombok.AllArgsConstructor; -import lombok.Getter; - -@Getter -@AllArgsConstructor -@Schema(name = "카테고리 응답 DTO", description = "카테고리 조회 응답 DTO") -public class LabelCategoryResponse { - - @Schema(description = "카테고리 ID", example = "1") - private Integer id; - - @Schema(description = "카테고리 이름", example = "Car") - private String name; - - public static LabelCategoryResponse from(LabelCategory labelCategory) { - return new LabelCategoryResponse(labelCategory.getId(), labelCategory.getName()); - } -} diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/repository/LabelCategoryRepository.java b/backend/src/main/java/com/worlabel/domain/labelcategory/repository/LabelCategoryRepository.java deleted file mode 100644 index 5b8a7ce..0000000 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/repository/LabelCategoryRepository.java +++ /dev/null @@ -1,22 +0,0 @@ -package com.worlabel.domain.labelcategory.repository; - -import com.worlabel.domain.labelcategory.entity.LabelCategory; -import org.springframework.data.jpa.repository.JpaRepository; -import org.springframework.data.jpa.repository.Query; -import org.springframework.data.repository.query.Param; - -import java.util.List; - -public interface LabelCategoryRepository extends JpaRepository { - - @Query("SELECT l FROM LabelCategory l " + - "WHERE l.aiModel.id = :modelId") - List findAllByModelId(@Param("modelId") Integer modelId); - - - @Query("SELECT l FROM LabelCategory l " + - "WHERE l.aiModel.id = :modelId AND " + - "l.id IN :categoryList") - List findAllByIdsAndModelId(@Param("categoryList") List labelCategoryList,@Param("modelId") Integer modelId); - -} diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/repository/ProjectLabelCategoryRepository.java b/backend/src/main/java/com/worlabel/domain/labelcategory/repository/ProjectLabelCategoryRepository.java index e195386..dc71729 100644 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/repository/ProjectLabelCategoryRepository.java +++ b/backend/src/main/java/com/worlabel/domain/labelcategory/repository/ProjectLabelCategoryRepository.java @@ -9,14 +9,5 @@ import java.util.List; public interface ProjectLabelCategoryRepository extends JpaRepository { - - @Query("SELECT COUNT(pc) >= 1 FROM ProjectCategory pc " + - "WHERE pc.project.id = :projectId AND " + - "pc.labelCategory.name = :categoryName ") - boolean existsByNameAndProjectId(@Param("categoryName") String categoryName , @Param("projectId") int projectId); - - @Query("SELECT pc FROM ProjectCategory pc " + - "JOIN FETCH pc.labelCategory " + - "WHERE pc.project.id = :projectId ") - List findAllByProjectId(@Param("projectId") Integer projectId); + List findAllByProjectId(Integer projectId); } diff --git a/backend/src/main/java/com/worlabel/domain/labelcategory/service/ProjectLabelCategoryService.java b/backend/src/main/java/com/worlabel/domain/labelcategory/service/ProjectLabelCategoryService.java index a84a9f7..3d4f6cc 100644 --- a/backend/src/main/java/com/worlabel/domain/labelcategory/service/ProjectLabelCategoryService.java +++ b/backend/src/main/java/com/worlabel/domain/labelcategory/service/ProjectLabelCategoryService.java @@ -1,17 +1,10 @@ package com.worlabel.domain.labelcategory.service; -import com.worlabel.domain.labelcategory.entity.LabelCategory; import com.worlabel.domain.labelcategory.entity.ProjectCategory; -import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest; -import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; -import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository; +import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse; import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository; import com.worlabel.domain.participant.entity.PrivilegeType; -import com.worlabel.domain.project.entity.Project; -import com.worlabel.domain.project.repository.ProjectRepository; import com.worlabel.global.annotation.CheckPrivilege; -import com.worlabel.global.exception.CustomException; -import com.worlabel.global.exception.ErrorCode; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -26,48 +19,13 @@ import java.util.List; public class ProjectLabelCategoryService { private final ProjectLabelCategoryRepository projectLabelCategoryRepository; - private final LabelCategoryRepository labelCategoryRepository; - private final ProjectRepository projectRepository; - - @CheckPrivilege(PrivilegeType.EDITOR) - public void createCategory(final Integer projectId, final LabelCategoryRequest categoryRequest) { - Project project = getProject(projectId); - List labelCategoryList = labelCategoryRepository.findAllByIdsAndModelId(categoryRequest.getLabelCategoryList(), categoryRequest.getModelId()); - List projectCategoryList = labelCategoryList.stream().map(o -> ProjectCategory.of(o, project)).toList(); - projectLabelCategoryRepository.saveAll(projectCategoryList); - } - - @CheckPrivilege(PrivilegeType.EDITOR) - public void deleteCategory(final int projectId, final int projectCategoryId) { - ProjectCategory projectCategory = getProjectCategory(projectCategoryId); - projectLabelCategoryRepository.delete(projectCategory); - } @CheckPrivilege(PrivilegeType.VIEWER) - public LabelCategoryResponse getCategoryById(final int projectId, final int categoryId) { - return LabelCategoryResponse.from(getProjectCategory(categoryId).getLabelCategory()); - } + public List getCategoryById(final int projectId) { + List categories = projectLabelCategoryRepository.findAllByProjectId(projectId); - public boolean existByCategoryName(final int projectId, final String categoryName) { - return projectLabelCategoryRepository.existsByNameAndProjectId(categoryName, projectId); - } - - @CheckPrivilege(PrivilegeType.VIEWER) - public List getCategoryList(final Integer projectId) { - return projectLabelCategoryRepository.findAllByProjectId(projectId) - .stream() - .map(ProjectCategory::getLabelCategory) - .map(LabelCategoryResponse::from) + return categories.stream() + .map(category -> CategoryResponse.of(category.getId(), category.getLabelName())) .toList(); } - - private Project getProject(Integer projectId) { - return projectRepository.findById(projectId) - .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); - } - - private ProjectCategory getProjectCategory(final Integer categoryId) { - return projectLabelCategoryRepository.findById(categoryId) - .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); - } } diff --git a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java index 3fcbfc0..1cd5545 100644 --- a/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java +++ b/backend/src/main/java/com/worlabel/domain/model/controller/AiModelController.java @@ -1,6 +1,5 @@ package com.worlabel.domain.model.controller; -import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; import com.worlabel.domain.model.entity.dto.AiModelRequest; import com.worlabel.domain.model.entity.dto.AiModelResponse; import com.worlabel.domain.model.entity.dto.ModelTrainRequest; @@ -39,15 +38,6 @@ public class AiModelController { return aiModelService.getModelList(projectId); } - @Operation(summary = "특정 모델 카테고리", description = "모델의 카테고리를 조회합니다.") - @SwaggerApiSuccess(description = "카테고리를 조회합니다.") - @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) - @GetMapping("/models/{model_id}/categories") - public List getCategories( - @PathVariable("model_id") final Integer modelId) { - return aiModelService.getCategories(modelId); - } - @Operation(summary = "프로젝트 모델 추가", description = "프로젝트에 있는 모델을 추가합니다.") @SwaggerApiSuccess(description = "프로젝트 모델을 추가합니다.") @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) diff --git a/backend/src/main/java/com/worlabel/domain/model/entity/AiModel.java b/backend/src/main/java/com/worlabel/domain/model/entity/AiModel.java index bf9cad9..aeb1c30 100644 --- a/backend/src/main/java/com/worlabel/domain/model/entity/AiModel.java +++ b/backend/src/main/java/com/worlabel/domain/model/entity/AiModel.java @@ -1,7 +1,6 @@ package com.worlabel.domain.model.entity; import com.fasterxml.jackson.annotation.JsonIgnore; -import com.worlabel.domain.labelcategory.entity.LabelCategory; import com.worlabel.domain.project.entity.Project; import com.worlabel.global.common.BaseEntity; import jakarta.persistence.*; @@ -9,9 +8,6 @@ import lombok.AccessLevel; import lombok.Getter; import lombok.NoArgsConstructor; -import java.util.ArrayList; -import java.util.List; - @Getter @Entity @Table(name = "ai_model") @@ -52,12 +48,6 @@ public class AiModel extends BaseEntity { @JsonIgnore private Project project; - /** - * 모델에 속한 카테고리 - */ - @OneToMany(mappedBy = "aiModel", fetch = FetchType.LAZY) - private List categoryList = new ArrayList<>(); - private AiModel(String name, String modelKey, int version, Project project) { this.name = name; this.modelKey = modelKey; diff --git a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java index 15f544e..957a73c 100644 --- a/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java +++ b/backend/src/main/java/com/worlabel/domain/model/service/AiModelService.java @@ -3,19 +3,13 @@ package com.worlabel.domain.model.service; import com.google.gson.Gson; import com.google.gson.reflect.TypeToken; import com.worlabel.domain.image.entity.Image; -import com.worlabel.domain.image.entity.LabelStatus; import com.worlabel.domain.image.repository.ImageRepository; -import com.worlabel.domain.labelcategory.entity.LabelCategory; import com.worlabel.domain.labelcategory.entity.ProjectCategory; -import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse; -import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse; -import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository; import com.worlabel.domain.model.entity.AiModel; import com.worlabel.domain.model.entity.dto.*; import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.participant.entity.PrivilegeType; import com.worlabel.domain.progress.service.ProgressService; -import com.worlabel.domain.project.dto.AiDto; import com.worlabel.domain.project.dto.AiDto.TrainDataInfo; import com.worlabel.domain.project.dto.AiDto.TrainRequest; import com.worlabel.domain.project.entity.Project; @@ -25,7 +19,6 @@ import com.worlabel.domain.report.service.ReportService; import com.worlabel.domain.result.entity.Result; import com.worlabel.domain.result.repository.ResultRepository; import com.worlabel.global.annotation.CheckPrivilege; -import com.worlabel.global.cache.CacheKey; import com.worlabel.global.exception.CustomException; import com.worlabel.global.exception.ErrorCode; import com.worlabel.global.service.AiRequestService; @@ -38,8 +31,6 @@ import org.springframework.transaction.annotation.Transactional; import java.lang.reflect.Type; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; -import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -50,7 +41,6 @@ import java.util.stream.Collectors; @RequiredArgsConstructor public class AiModelService { - private final LabelCategoryRepository labelCategoryRepository; private final RedisTemplate redisTemplate; private final AiModelRepository aiModelRepository; private final ProjectRepository projectRepository; @@ -62,46 +52,6 @@ public class AiModelService { private final ProgressService progressService; private final ReportService reportService; - // @PostConstruct - public void loadDefaultModel() { - String url = "model/default"; - List defaultResponseList = aiRequestService.getRequest(url, this::converter); - - // 1. DefaultResponse의 Key값만 모아서 리스트로 만든다. - List allModelKeys = defaultResponseList.stream() - .map(response -> response.getDefaultAiModelResponse().getModelKey()) - .toList(); - - // 2. 해당 Key값이 DB에 있는지 확인하기 (한 번의 쿼리로) - List existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream() - .map(AiModel::getModelKey) - .toList(); - - // 3. DB에 없는 Key만 필터링해서 처리 - List newModel = defaultResponseList.stream() - .filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey())) - .toList(); - - - // 새롭게 추가된 값을 디비에 저장 - List aiModels = new ArrayList<>(); - List categories = new ArrayList<>(); - for (DefaultResponse defaultResponse : newModel) { - DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse(); - AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null); - aiModels.add(newAiModel); - - List defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList(); - - for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) { - categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId())); - } - } - - aiModelRepository.saveAll(aiModels); - labelCategoryRepository.saveAll(categories); - } - @Transactional(readOnly = true) public List getModelList(final Integer projectId) { int progressModelId = progressService.getProgressModelByProjectId(projectId); @@ -131,23 +81,15 @@ public class AiModelService { return aiModelRepository.findCustomModelById(modelId).orElseThrow(() -> new CustomException(ErrorCode.BAD_REQUEST)); } - @Transactional(readOnly = true) - public List getCategories(final Integer modelId) { - List categoryList = labelCategoryRepository.findAllByModelId(modelId); - return categoryList.stream() - .map(LabelCategoryResponse::from) - .toList(); - } - @CheckPrivilege(PrivilegeType.EDITOR) public void train(final Integer projectId, final ModelTrainRequest trainRequest) { // FastAPI 서버로 학습 요청을 전송 Project project = getProject(projectId); AiModel model = getModel(trainRequest.getModelId()); - Map labelMap = project.getCategoryList().stream() + Map labelMap = project.getCategoryList().stream() .collect(Collectors.toMap( - category -> category.getLabelCategory().getId(), + ProjectCategory::getLabelName, ProjectCategory::getId )); diff --git a/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java b/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java index 5775a64..3d09dfa 100644 --- a/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java +++ b/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java @@ -2,10 +2,7 @@ package com.worlabel.domain.project.controller; import com.worlabel.domain.participant.entity.dto.ParticipantRequest; import com.worlabel.domain.project.dto.AutoModelRequest; -import com.worlabel.domain.project.entity.dto.ProjectMemberResponse; -import com.worlabel.domain.project.entity.dto.ProjectRequest; -import com.worlabel.domain.project.entity.dto.ProjectResponse; -import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse; +import com.worlabel.domain.project.entity.dto.*; import com.worlabel.domain.project.service.ProjectService; import com.worlabel.global.annotation.CurrentUser; import com.worlabel.global.config.swagger.SwaggerApiError; @@ -38,7 +35,7 @@ public class ProjectController { public ProjectResponse createProject( @CurrentUser final Integer memberId, @PathVariable("workspace_id") final Integer workspaceId, - @Valid @RequestBody final ProjectRequest projectRequest) { + @Valid @RequestBody final ProjectWithCategoryRequest projectRequest) { return projectService.createProject(memberId, workspaceId, projectRequest); } diff --git a/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java b/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java index cb3a3fa..e869ada 100644 --- a/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java +++ b/backend/src/main/java/com/worlabel/domain/project/dto/AiDto.java @@ -45,7 +45,7 @@ public class AiDto { private Integer modelId; @JsonProperty("label_map") - private Map labelMap; + private Map labelMap; @JsonProperty("data") private List data; @@ -62,7 +62,7 @@ public class AiDto { private Optimizer optimizer; - public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map labelMap, final List data, final ModelTrainRequest trainRequest) { + public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map labelMap, final List data, final ModelTrainRequest trainRequest) { TrainRequest request = new TrainRequest(); request.projectId = projectId; request.modelId = modelId; @@ -93,7 +93,7 @@ public class AiDto { private String modelKey; @JsonProperty("label_map") - private HashMap labelMap; + private HashMap labelMap; @JsonProperty("image_list") private List imageList; @@ -104,7 +104,7 @@ public class AiDto { @JsonProperty("iou_threshold") private Double iouThreshold; - public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap labelMap, final List imageList) { + public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap labelMap, final List imageList) { return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45); } } diff --git a/backend/src/main/java/com/worlabel/domain/project/entity/dto/ProjectWithCategoryRequest.java b/backend/src/main/java/com/worlabel/domain/project/entity/dto/ProjectWithCategoryRequest.java new file mode 100644 index 0000000..2300e14 --- /dev/null +++ b/backend/src/main/java/com/worlabel/domain/project/entity/dto/ProjectWithCategoryRequest.java @@ -0,0 +1,31 @@ +package com.worlabel.domain.project.entity.dto; + +import com.worlabel.domain.project.entity.ProjectType; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotEmpty; +import jakarta.validation.constraints.NotNull; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Schema(name = "프로젝트 + 카테고리 요청 dto", description = "프로젝트 + 카테고리 요청 DTO") +@NoArgsConstructor(access = AccessLevel.PRIVATE) +@AllArgsConstructor +@Getter +public class ProjectWithCategoryRequest { + + @Schema(description = "프로젝트 제목", example = "삼성 갤럭시 s23") + @NotEmpty(message = "제목을 입력하세요.") + private String title; + + @Schema(description = "카테고리 목록 이름", example = "자동차, 사람") + @NotEmpty(message = "카테고리 목록 이름") + private List categories; + + @Schema(description = "프로젝트 유형", example = "classification") + @NotNull(message = "카테고리를 입력하세요.") + private ProjectType projectType; +} diff --git a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java index 0ab5937..b567542 100644 --- a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java +++ b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java @@ -7,6 +7,7 @@ import com.worlabel.domain.image.entity.Image; import com.worlabel.domain.image.entity.LabelStatus; import com.worlabel.domain.image.repository.ImageRepository; import com.worlabel.domain.labelcategory.entity.ProjectCategory; +import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository; import com.worlabel.domain.member.entity.Member; import com.worlabel.domain.member.repository.MemberRepository; import com.worlabel.domain.model.entity.AiModel; @@ -23,10 +24,7 @@ import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest; import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult; import com.worlabel.domain.project.dto.AutoModelRequest; import com.worlabel.domain.project.entity.Project; -import com.worlabel.domain.project.entity.dto.ProjectMemberResponse; -import com.worlabel.domain.project.entity.dto.ProjectRequest; -import com.worlabel.domain.project.entity.dto.ProjectResponse; -import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse; +import com.worlabel.domain.project.entity.dto.*; import com.worlabel.domain.project.repository.ProjectRepository; import com.worlabel.domain.workspace.entity.Workspace; import com.worlabel.domain.workspace.repository.WorkspaceRepository; @@ -58,15 +56,15 @@ public class ProjectService { private final ProjectRepository projectRepository; private final MemberRepository memberRepository; private final ProgressService progressService; + private final ProjectLabelCategoryRepository projectLabelCategoryRepository; private final S3UploadService s3UploadService; private final ImageRepository imageRepository; private final AiRequestService aiService; - private final Gson gson; @Transactional - public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) { + public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectWithCategoryRequest projectRequest) { Workspace workspace = getWorkspace(memberId, workspaceId); Member member = getMember(memberId); @@ -74,6 +72,11 @@ public class ProjectService { Participant participant = Participant.of(project, member, PrivilegeType.ADMIN); projectRepository.save(project); + + for (String labelName : projectRequest.getCategories()) { + projectLabelCategoryRepository.save(ProjectCategory.of(labelName, project)); + } + participantRepository.save(participant); return ProjectResponse.from(project); @@ -88,8 +91,8 @@ public class ProjectService { @Transactional(readOnly = true) public List getProjectsByWorkspaceId(final Integer workspaceId, final Integer memberId, final Integer lastProjectId, final Integer pageSize) { return projectRepository.findProjectsByWorkspaceIdAndMemberIdWithPagination(workspaceId, memberId, lastProjectId, pageSize).stream() - .map(project -> ProjectWithThumbnailResponse.from(project, getFirstImageWithProject(project))) - .toList(); + .map(project -> ProjectWithThumbnailResponse.from(project, getFirstImageWithProject(project))) + .toList(); } @Transactional @@ -165,7 +168,7 @@ public class ProjectService { String endPoint = project.getProjectType().getValue() + "/predict"; List imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId); - if(imageList.isEmpty()){ + if (imageList.isEmpty()) { throw new CustomException(ErrorCode.DATA_NOT_FOUND); } @@ -173,7 +176,7 @@ public class ProjectService { .map(AutoLabelingImageRequest::of) .toList(); - HashMap labelMap = getLabelMap(project); + HashMap labelMap = getLabelMap(project); AiModel aiModel = getAiModel(request); AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList); @@ -187,9 +190,9 @@ public class ProjectService { // TODO: 트랜잭션 설정 @Transactional public void saveAutoLabelList(final List resultList) { - for(AutoLabelingResult result: resultList) { + for (AutoLabelingResult result : resultList) { Image image = getImage(result.getImageId()); - if(image.getStatus() == LabelStatus.SAVE || image.getStatus() == LabelStatus.IN_PROGRESS) continue; + if (image.getStatus() == LabelStatus.SAVE || image.getStatus() == LabelStatus.IN_PROGRESS) continue; String dataPath = image.getDataPath(); s3UploadService.uploadJson(result.getData(), dataPath); image.updateStatus(LabelStatus.IN_PROGRESS); @@ -219,21 +222,22 @@ public class ProjectService { .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); } - private HashMap getLabelMap(Project project) { - HashMap labelMap = new HashMap<>(); + private HashMap getLabelMap(Project project) { + HashMap labelMap = new HashMap<>(); List category = project.getCategoryList(); for (ProjectCategory projectCategory : category) { - int aiId = projectCategory.getLabelCategory().getAiCategoryId(); - if (labelMap.containsKey(aiId)) continue; - - labelMap.put(aiId, projectCategory.getId()); + if (labelMap.containsKey(projectCategory.getLabelName())) { + continue; + } + labelMap.put(projectCategory.getLabelName(), projectCategory.getId()); } + return labelMap; } - private Image getImage(final Long imageId){ + private Image getImage(final Long imageId) { return imageRepository.findById(imageId) - .orElseThrow(()-> new CustomException(ErrorCode.DATA_NOT_FOUND)); + .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); } private Project getProject(final Integer projectId) {