Merge branch 'be/refactor/category' into 'be/develop'

Refactor: 카테고리 수정  - S11P21S002-223

See merge request s11-s-project/S11P21S002!210
This commit is contained in:
홍창기 2024-09-27 11:17:26 +09:00
commit ffe79cc79b
16 changed files with 101 additions and 326 deletions

View File

@ -1,16 +1,13 @@
package com.worlabel.domain.labelcategory.controller; package com.worlabel.domain.labelcategory.controller;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest; import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.service.ProjectLabelCategoryService; import com.worlabel.domain.labelcategory.service.ProjectLabelCategoryService;
import com.worlabel.global.annotation.CurrentUser;
import com.worlabel.global.config.swagger.SwaggerApiError; import com.worlabel.global.config.swagger.SwaggerApiError;
import com.worlabel.global.config.swagger.SwaggerApiSuccess; import com.worlabel.global.config.swagger.SwaggerApiSuccess;
import com.worlabel.global.exception.ErrorCode; import com.worlabel.global.exception.ErrorCode;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import org.springframework.data.repository.query.Param;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.List; import java.util.List;
@ -23,39 +20,11 @@ public class CategoryController {
private final ProjectLabelCategoryService categoryService; private final ProjectLabelCategoryService categoryService;
@Operation(summary = "프로젝트 레이블 카테고리 선택", description = "프로젝트 레이블 카테고리를 추가합니다.")
@SwaggerApiSuccess(description = "카테고리 성공적으로 추가합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping
public void createFolder(
@CurrentUser final Integer memberId,
@RequestBody final LabelCategoryRequest categoryRequest) {
categoryService.createCategory(memberId, categoryRequest);
}
@Operation(summary = "레이블 카테고리 존재 여부 조회", description = "해당 프로젝트에 같은 레이블 카테고리 이름이 있는지 조회합니다.")
@SwaggerApiSuccess(description = "카테고리 존재 여부를 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping("/exist")
public boolean existByCategoryName(
@PathVariable("project_id") final Integer projectId,
@Param("categoryName") final String categoryName) {
return categoryService.existByCategoryName(projectId, categoryName);
}
@Operation(summary = "프로젝트 레이블 카테고리 리스트 조회", description = "레이블 카테고리 리스트를 조회합니다..") @Operation(summary = "프로젝트 레이블 카테고리 리스트 조회", description = "레이블 카테고리 리스트를 조회합니다..")
@SwaggerApiSuccess(description = "카테고리 리스트를 성공적으로 조회합니다.") @SwaggerApiSuccess(description = "카테고리 리스트를 성공적으로 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping @GetMapping
public List<LabelCategoryResponse> getCategoryList(@PathVariable("project_id") final Integer projectId) { public List<CategoryResponse> getCategoryList(@PathVariable("project_id") final Integer projectId) {
return categoryService.getCategoryList(projectId); return categoryService.getCategoryById(projectId);
}
@Operation(summary = "카테고리 삭제", description = "카테고리를 삭제합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@SwaggerApiSuccess(description = "카테고리를 성공적으로 삭제합니다.")
@DeleteMapping("/{category_id}")
public void deleteCategoryById(@PathVariable("project_id") final Integer projectId, @PathVariable("category_id") final Integer categoryId) {
categoryService.deleteCategory(projectId, categoryId);
} }
} }

View File

@ -1,53 +0,0 @@
package com.worlabel.domain.labelcategory.entity;
import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.global.common.BaseEntity;
import jakarta.persistence.*;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NoArgsConstructor;
@Getter
@Entity
@Table(name = "label_category")
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class LabelCategory extends BaseEntity {
/**
* 레이블 카테고리 PK
*/
@Id
@Column(name = "label_category_id", nullable = false)
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Integer id;
/**
* 속한 모델
*/
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", nullable = false)
private AiModel aiModel;
/**
* 레이블 카테고리 이름
*/
@Column(name = "label_category_name", nullable = false)
private String name;
/**
* 실제 AI 모델의 ai 카테고리 id
*/
@Column(name = "ai_category_id", nullable = false)
private Integer aiCategoryId;
private LabelCategory(final AiModel aiModel, final String name, final int aiCategoryId) {
this.aiModel = aiModel;
this.name = name;
this.aiCategoryId = aiCategoryId;
}
public static LabelCategory of(final AiModel aiModel, final String name, final int aiCategoryId) {
return new LabelCategory(aiModel, name, aiCategoryId);
}
}

View File

@ -22,11 +22,10 @@ public class ProjectCategory extends BaseEntity {
private Integer id; private Integer id;
/** /**
* 레이블 카테고리 * Model name
*/ */
@ManyToOne(fetch = FetchType.LAZY) @Column(name = "label_category_name", length = 50)
@JoinColumn(name = "label_category_id", nullable = false) private String labelName;
private LabelCategory labelCategory;
/** /**
* 프로젝트 * 프로젝트
@ -35,12 +34,12 @@ public class ProjectCategory extends BaseEntity {
@JoinColumn(name = "project_id", nullable = false) @JoinColumn(name = "project_id", nullable = false)
private Project project; private Project project;
private ProjectCategory(LabelCategory labelCategory, Project project) { private ProjectCategory(String labelName, Project project) {
this.labelCategory = labelCategory; this.labelName = labelName;
this.project = project; this.project = project;
} }
public static ProjectCategory of(LabelCategory labelCategory, Project project) { public static ProjectCategory of(String labelName, Project project) {
return new ProjectCategory(labelCategory, project); return new ProjectCategory(labelName, project);
} }
} }

View File

@ -0,0 +1,22 @@
package com.worlabel.domain.labelcategory.entity.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Schema(name = "카테고리 응답 DTO", description = "프로젝트 내 카테고리 종류 응답 DTO")
@Getter
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class CategoryResponse {
@Schema(description = "카테고리 ID", example = "1")
private Integer id;
@Schema(description = "라벨링 이름", example = "사람")
private String labelName;
public static CategoryResponse of(final Integer id, final String labelName) {
return new CategoryResponse(id, labelName);
}
}

View File

@ -1,21 +0,0 @@
package com.worlabel.domain.labelcategory.entity.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.util.List;
@Schema(name = "카테고리 요청 DTO", description = "카테고리 생성 및 수정을 위한 요청 DTO")
@Getter
@NoArgsConstructor
@AllArgsConstructor
public class LabelCategoryRequest {
@Schema(description = "Model Id", example = "1")
private Integer modelId;
@Schema(description = "카테고리 Id 리스트", example = "[0,3,6,8]")
private List<Integer> labelCategoryList;
}

View File

@ -1,22 +0,0 @@
package com.worlabel.domain.labelcategory.entity.dto;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
@Schema(name = "카테고리 응답 DTO", description = "카테고리 조회 응답 DTO")
public class LabelCategoryResponse {
@Schema(description = "카테고리 ID", example = "1")
private Integer id;
@Schema(description = "카테고리 이름", example = "Car")
private String name;
public static LabelCategoryResponse from(LabelCategory labelCategory) {
return new LabelCategoryResponse(labelCategory.getId(), labelCategory.getName());
}
}

View File

@ -1,22 +0,0 @@
package com.worlabel.domain.labelcategory.repository;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List;
public interface LabelCategoryRepository extends JpaRepository<LabelCategory, Integer> {
@Query("SELECT l FROM LabelCategory l " +
"WHERE l.aiModel.id = :modelId")
List<LabelCategory> findAllByModelId(@Param("modelId") Integer modelId);
@Query("SELECT l FROM LabelCategory l " +
"WHERE l.aiModel.id = :modelId AND " +
"l.id IN :categoryList")
List<LabelCategory> findAllByIdsAndModelId(@Param("categoryList") List<Integer> labelCategoryList,@Param("modelId") Integer modelId);
}

View File

@ -9,14 +9,5 @@ import java.util.List;
public interface ProjectLabelCategoryRepository extends JpaRepository<ProjectCategory, Integer> { public interface ProjectLabelCategoryRepository extends JpaRepository<ProjectCategory, Integer> {
List<ProjectCategory> findAllByProjectId(Integer projectId);
@Query("SELECT COUNT(pc) >= 1 FROM ProjectCategory pc " +
"WHERE pc.project.id = :projectId AND " +
"pc.labelCategory.name = :categoryName ")
boolean existsByNameAndProjectId(@Param("categoryName") String categoryName , @Param("projectId") int projectId);
@Query("SELECT pc FROM ProjectCategory pc " +
"JOIN FETCH pc.labelCategory " +
"WHERE pc.project.id = :projectId ")
List<ProjectCategory> findAllByProjectId(@Param("projectId") Integer projectId);
} }

View File

@ -1,17 +1,10 @@
package com.worlabel.domain.labelcategory.service; package com.worlabel.domain.labelcategory.service;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.labelcategory.entity.ProjectCategory; import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest; import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository; import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository;
import com.worlabel.domain.participant.entity.PrivilegeType; import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.global.annotation.CheckPrivilege; import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@ -26,48 +19,13 @@ import java.util.List;
public class ProjectLabelCategoryService { public class ProjectLabelCategoryService {
private final ProjectLabelCategoryRepository projectLabelCategoryRepository; private final ProjectLabelCategoryRepository projectLabelCategoryRepository;
private final LabelCategoryRepository labelCategoryRepository;
private final ProjectRepository projectRepository;
@CheckPrivilege(PrivilegeType.EDITOR)
public void createCategory(final Integer projectId, final LabelCategoryRequest categoryRequest) {
Project project = getProject(projectId);
List<LabelCategory> labelCategoryList = labelCategoryRepository.findAllByIdsAndModelId(categoryRequest.getLabelCategoryList(), categoryRequest.getModelId());
List<ProjectCategory> projectCategoryList = labelCategoryList.stream().map(o -> ProjectCategory.of(o, project)).toList();
projectLabelCategoryRepository.saveAll(projectCategoryList);
}
@CheckPrivilege(PrivilegeType.EDITOR)
public void deleteCategory(final int projectId, final int projectCategoryId) {
ProjectCategory projectCategory = getProjectCategory(projectCategoryId);
projectLabelCategoryRepository.delete(projectCategory);
}
@CheckPrivilege(PrivilegeType.VIEWER) @CheckPrivilege(PrivilegeType.VIEWER)
public LabelCategoryResponse getCategoryById(final int projectId, final int categoryId) { public List<CategoryResponse> getCategoryById(final int projectId) {
return LabelCategoryResponse.from(getProjectCategory(categoryId).getLabelCategory()); List<ProjectCategory> categories = projectLabelCategoryRepository.findAllByProjectId(projectId);
}
public boolean existByCategoryName(final int projectId, final String categoryName) { return categories.stream()
return projectLabelCategoryRepository.existsByNameAndProjectId(categoryName, projectId); .map(category -> CategoryResponse.of(category.getId(), category.getLabelName()))
}
@CheckPrivilege(PrivilegeType.VIEWER)
public List<LabelCategoryResponse> getCategoryList(final Integer projectId) {
return projectLabelCategoryRepository.findAllByProjectId(projectId)
.stream()
.map(ProjectCategory::getLabelCategory)
.map(LabelCategoryResponse::from)
.toList(); .toList();
} }
private Project getProject(Integer projectId) {
return projectRepository.findById(projectId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private ProjectCategory getProjectCategory(final Integer categoryId) {
return projectLabelCategoryRepository.findById(categoryId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
} }

View File

@ -1,6 +1,5 @@
package com.worlabel.domain.model.controller; package com.worlabel.domain.model.controller;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.model.entity.dto.AiModelRequest; import com.worlabel.domain.model.entity.dto.AiModelRequest;
import com.worlabel.domain.model.entity.dto.AiModelResponse; import com.worlabel.domain.model.entity.dto.AiModelResponse;
import com.worlabel.domain.model.entity.dto.ModelTrainRequest; import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
@ -39,15 +38,6 @@ public class AiModelController {
return aiModelService.getModelList(projectId); return aiModelService.getModelList(projectId);
} }
@Operation(summary = "특정 모델 카테고리", description = "모델의 카테고리를 조회합니다.")
@SwaggerApiSuccess(description = "카테고리를 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping("/models/{model_id}/categories")
public List<LabelCategoryResponse> getCategories(
@PathVariable("model_id") final Integer modelId) {
return aiModelService.getCategories(modelId);
}
@Operation(summary = "프로젝트 모델 추가", description = "프로젝트에 있는 모델을 추가합니다.") @Operation(summary = "프로젝트 모델 추가", description = "프로젝트에 있는 모델을 추가합니다.")
@SwaggerApiSuccess(description = "프로젝트 모델을 추가합니다.") @SwaggerApiSuccess(description = "프로젝트 모델을 추가합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})

View File

@ -1,7 +1,6 @@
package com.worlabel.domain.model.entity; package com.worlabel.domain.model.entity;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.entity.Project;
import com.worlabel.global.common.BaseEntity; import com.worlabel.global.common.BaseEntity;
import jakarta.persistence.*; import jakarta.persistence.*;
@ -9,9 +8,6 @@ import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@Getter @Getter
@Entity @Entity
@Table(name = "ai_model") @Table(name = "ai_model")
@ -52,12 +48,6 @@ public class AiModel extends BaseEntity {
@JsonIgnore @JsonIgnore
private Project project; private Project project;
/**
* 모델에 속한 카테고리
*/
@OneToMany(mappedBy = "aiModel", fetch = FetchType.LAZY)
private List<LabelCategory> categoryList = new ArrayList<>();
private AiModel(String name, String modelKey, int version, Project project) { private AiModel(String name, String modelKey, int version, Project project) {
this.name = name; this.name = name;
this.modelKey = modelKey; this.modelKey = modelKey;

View File

@ -3,19 +3,13 @@ package com.worlabel.domain.model.service;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken; import com.google.gson.reflect.TypeToken;
import com.worlabel.domain.image.entity.Image; import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository; import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.labelcategory.entity.ProjectCategory; import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
import com.worlabel.domain.model.entity.AiModel; import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.domain.model.entity.dto.*; import com.worlabel.domain.model.entity.dto.*;
import com.worlabel.domain.model.repository.AiModelRepository; import com.worlabel.domain.model.repository.AiModelRepository;
import com.worlabel.domain.participant.entity.PrivilegeType; import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.progress.service.ProgressService; import com.worlabel.domain.progress.service.ProgressService;
import com.worlabel.domain.project.dto.AiDto;
import com.worlabel.domain.project.dto.AiDto.TrainDataInfo; import com.worlabel.domain.project.dto.AiDto.TrainDataInfo;
import com.worlabel.domain.project.dto.AiDto.TrainRequest; import com.worlabel.domain.project.dto.AiDto.TrainRequest;
import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.entity.Project;
@ -25,7 +19,6 @@ import com.worlabel.domain.report.service.ReportService;
import com.worlabel.domain.result.entity.Result; import com.worlabel.domain.result.entity.Result;
import com.worlabel.domain.result.repository.ResultRepository; import com.worlabel.domain.result.repository.ResultRepository;
import com.worlabel.global.annotation.CheckPrivilege; import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.cache.CacheKey;
import com.worlabel.global.exception.CustomException; import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode; import com.worlabel.global.exception.ErrorCode;
import com.worlabel.global.service.AiRequestService; import com.worlabel.global.service.AiRequestService;
@ -38,8 +31,6 @@ import org.springframework.transaction.annotation.Transactional;
import java.lang.reflect.Type; import java.lang.reflect.Type;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -50,7 +41,6 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor @RequiredArgsConstructor
public class AiModelService { public class AiModelService {
private final LabelCategoryRepository labelCategoryRepository;
private final RedisTemplate<String, Object> redisTemplate; private final RedisTemplate<String, Object> redisTemplate;
private final AiModelRepository aiModelRepository; private final AiModelRepository aiModelRepository;
private final ProjectRepository projectRepository; private final ProjectRepository projectRepository;
@ -62,46 +52,6 @@ public class AiModelService {
private final ProgressService progressService; private final ProgressService progressService;
private final ReportService reportService; private final ReportService reportService;
// @PostConstruct
public void loadDefaultModel() {
String url = "model/default";
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
// 1. DefaultResponse의 Key값만 모아서 리스트로 만든다.
List<String> allModelKeys = defaultResponseList.stream()
.map(response -> response.getDefaultAiModelResponse().getModelKey())
.toList();
// 2. 해당 Key값이 DB에 있는지 확인하기 ( 번의 쿼리로)
List<String> existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream()
.map(AiModel::getModelKey)
.toList();
// 3. DB에 없는 Key만 필터링해서 처리
List<DefaultResponse> newModel = defaultResponseList.stream()
.filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey()))
.toList();
// 새롭게 추가된 값을 디비에 저장
List<AiModel> aiModels = new ArrayList<>();
List<LabelCategory> categories = new ArrayList<>();
for (DefaultResponse defaultResponse : newModel) {
DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse();
AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null);
aiModels.add(newAiModel);
List<DefaultLabelCategoryResponse> defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList();
for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) {
categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId()));
}
}
aiModelRepository.saveAll(aiModels);
labelCategoryRepository.saveAll(categories);
}
@Transactional(readOnly = true) @Transactional(readOnly = true)
public List<AiModelResponse> getModelList(final Integer projectId) { public List<AiModelResponse> getModelList(final Integer projectId) {
int progressModelId = progressService.getProgressModelByProjectId(projectId); int progressModelId = progressService.getProgressModelByProjectId(projectId);
@ -131,23 +81,15 @@ public class AiModelService {
return aiModelRepository.findCustomModelById(modelId).orElseThrow(() -> new CustomException(ErrorCode.BAD_REQUEST)); return aiModelRepository.findCustomModelById(modelId).orElseThrow(() -> new CustomException(ErrorCode.BAD_REQUEST));
} }
@Transactional(readOnly = true)
public List<LabelCategoryResponse> getCategories(final Integer modelId) {
List<LabelCategory> categoryList = labelCategoryRepository.findAllByModelId(modelId);
return categoryList.stream()
.map(LabelCategoryResponse::from)
.toList();
}
@CheckPrivilege(PrivilegeType.EDITOR) @CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer projectId, final ModelTrainRequest trainRequest) { public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
// FastAPI 서버로 학습 요청을 전송 // FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId); Project project = getProject(projectId);
AiModel model = getModel(trainRequest.getModelId()); AiModel model = getModel(trainRequest.getModelId());
Map<Integer, Integer> labelMap = project.getCategoryList().stream() Map<String, Integer> labelMap = project.getCategoryList().stream()
.collect(Collectors.toMap( .collect(Collectors.toMap(
category -> category.getLabelCategory().getId(), ProjectCategory::getLabelName,
ProjectCategory::getId ProjectCategory::getId
)); ));

View File

@ -2,10 +2,7 @@ package com.worlabel.domain.project.controller;
import com.worlabel.domain.participant.entity.dto.ParticipantRequest; import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
import com.worlabel.domain.project.dto.AutoModelRequest; import com.worlabel.domain.project.dto.AutoModelRequest;
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse; import com.worlabel.domain.project.entity.dto.*;
import com.worlabel.domain.project.entity.dto.ProjectRequest;
import com.worlabel.domain.project.entity.dto.ProjectResponse;
import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse;
import com.worlabel.domain.project.service.ProjectService; import com.worlabel.domain.project.service.ProjectService;
import com.worlabel.global.annotation.CurrentUser; import com.worlabel.global.annotation.CurrentUser;
import com.worlabel.global.config.swagger.SwaggerApiError; import com.worlabel.global.config.swagger.SwaggerApiError;
@ -38,7 +35,7 @@ public class ProjectController {
public ProjectResponse createProject( public ProjectResponse createProject(
@CurrentUser final Integer memberId, @CurrentUser final Integer memberId,
@PathVariable("workspace_id") final Integer workspaceId, @PathVariable("workspace_id") final Integer workspaceId,
@Valid @RequestBody final ProjectRequest projectRequest) { @Valid @RequestBody final ProjectWithCategoryRequest projectRequest) {
return projectService.createProject(memberId, workspaceId, projectRequest); return projectService.createProject(memberId, workspaceId, projectRequest);
} }

View File

@ -45,7 +45,7 @@ public class AiDto {
private Integer modelId; private Integer modelId;
@JsonProperty("label_map") @JsonProperty("label_map")
private Map<Integer, Integer> labelMap; private Map<String, Integer> labelMap;
@JsonProperty("data") @JsonProperty("data")
private List<TrainDataInfo> data; private List<TrainDataInfo> data;
@ -62,7 +62,7 @@ public class AiDto {
private Optimizer optimizer; private Optimizer optimizer;
public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map<Integer, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) { public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map<String, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
TrainRequest request = new TrainRequest(); TrainRequest request = new TrainRequest();
request.projectId = projectId; request.projectId = projectId;
request.modelId = modelId; request.modelId = modelId;
@ -93,7 +93,7 @@ public class AiDto {
private String modelKey; private String modelKey;
@JsonProperty("label_map") @JsonProperty("label_map")
private HashMap<Integer, Integer> labelMap; private HashMap<String, Integer> labelMap;
@JsonProperty("image_list") @JsonProperty("image_list")
private List<AutoLabelingImageRequest> imageList; private List<AutoLabelingImageRequest> imageList;
@ -104,7 +104,7 @@ public class AiDto {
@JsonProperty("iou_threshold") @JsonProperty("iou_threshold")
private Double iouThreshold; private Double iouThreshold;
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<Integer, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) { public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<String, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45); return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45);
} }
} }

View File

@ -0,0 +1,31 @@
package com.worlabel.domain.project.entity.dto;
import com.worlabel.domain.project.entity.ProjectType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.util.List;
@Schema(name = "프로젝트 + 카테고리 요청 dto", description = "프로젝트 + 카테고리 요청 DTO")
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor
@Getter
public class ProjectWithCategoryRequest {
@Schema(description = "프로젝트 제목", example = "삼성 갤럭시 s23")
@NotEmpty(message = "제목을 입력하세요.")
private String title;
@Schema(description = "카테고리 목록 이름", example = "자동차, 사람")
@NotEmpty(message = "카테고리 목록 이름")
private List<String> categories;
@Schema(description = "프로젝트 유형", example = "classification")
@NotNull(message = "카테고리를 입력하세요.")
private ProjectType projectType;
}

View File

@ -7,6 +7,7 @@ import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus; import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository; import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.ProjectCategory; import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository;
import com.worlabel.domain.member.entity.Member; import com.worlabel.domain.member.entity.Member;
import com.worlabel.domain.member.repository.MemberRepository; import com.worlabel.domain.member.repository.MemberRepository;
import com.worlabel.domain.model.entity.AiModel; import com.worlabel.domain.model.entity.AiModel;
@ -23,10 +24,7 @@ import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest;
import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult; import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult;
import com.worlabel.domain.project.dto.AutoModelRequest; import com.worlabel.domain.project.dto.AutoModelRequest;
import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse; import com.worlabel.domain.project.entity.dto.*;
import com.worlabel.domain.project.entity.dto.ProjectRequest;
import com.worlabel.domain.project.entity.dto.ProjectResponse;
import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse;
import com.worlabel.domain.project.repository.ProjectRepository; import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.domain.workspace.entity.Workspace; import com.worlabel.domain.workspace.entity.Workspace;
import com.worlabel.domain.workspace.repository.WorkspaceRepository; import com.worlabel.domain.workspace.repository.WorkspaceRepository;
@ -58,15 +56,15 @@ public class ProjectService {
private final ProjectRepository projectRepository; private final ProjectRepository projectRepository;
private final MemberRepository memberRepository; private final MemberRepository memberRepository;
private final ProgressService progressService; private final ProgressService progressService;
private final ProjectLabelCategoryRepository projectLabelCategoryRepository;
private final S3UploadService s3UploadService; private final S3UploadService s3UploadService;
private final ImageRepository imageRepository; private final ImageRepository imageRepository;
private final AiRequestService aiService; private final AiRequestService aiService;
private final Gson gson; private final Gson gson;
@Transactional @Transactional
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) { public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectWithCategoryRequest projectRequest) {
Workspace workspace = getWorkspace(memberId, workspaceId); Workspace workspace = getWorkspace(memberId, workspaceId);
Member member = getMember(memberId); Member member = getMember(memberId);
@ -74,6 +72,11 @@ public class ProjectService {
Participant participant = Participant.of(project, member, PrivilegeType.ADMIN); Participant participant = Participant.of(project, member, PrivilegeType.ADMIN);
projectRepository.save(project); projectRepository.save(project);
for (String labelName : projectRequest.getCategories()) {
projectLabelCategoryRepository.save(ProjectCategory.of(labelName, project));
}
participantRepository.save(participant); participantRepository.save(participant);
return ProjectResponse.from(project); return ProjectResponse.from(project);
@ -88,8 +91,8 @@ public class ProjectService {
@Transactional(readOnly = true) @Transactional(readOnly = true)
public List<ProjectWithThumbnailResponse> getProjectsByWorkspaceId(final Integer workspaceId, final Integer memberId, final Integer lastProjectId, final Integer pageSize) { public List<ProjectWithThumbnailResponse> getProjectsByWorkspaceId(final Integer workspaceId, final Integer memberId, final Integer lastProjectId, final Integer pageSize) {
return projectRepository.findProjectsByWorkspaceIdAndMemberIdWithPagination(workspaceId, memberId, lastProjectId, pageSize).stream() return projectRepository.findProjectsByWorkspaceIdAndMemberIdWithPagination(workspaceId, memberId, lastProjectId, pageSize).stream()
.map(project -> ProjectWithThumbnailResponse.from(project, getFirstImageWithProject(project))) .map(project -> ProjectWithThumbnailResponse.from(project, getFirstImageWithProject(project)))
.toList(); .toList();
} }
@Transactional @Transactional
@ -165,7 +168,7 @@ public class ProjectService {
String endPoint = project.getProjectType().getValue() + "/predict"; String endPoint = project.getProjectType().getValue() + "/predict";
List<Image> imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId); List<Image> imageList = imageRepository.findImagesByProjectIdAndPendingOrInProgress(projectId);
if(imageList.isEmpty()){ if (imageList.isEmpty()) {
throw new CustomException(ErrorCode.DATA_NOT_FOUND); throw new CustomException(ErrorCode.DATA_NOT_FOUND);
} }
@ -173,7 +176,7 @@ public class ProjectService {
.map(AutoLabelingImageRequest::of) .map(AutoLabelingImageRequest::of)
.toList(); .toList();
HashMap<Integer, Integer> labelMap = getLabelMap(project); HashMap<String, Integer> labelMap = getLabelMap(project);
AiModel aiModel = getAiModel(request); AiModel aiModel = getAiModel(request);
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList); AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
@ -187,9 +190,9 @@ public class ProjectService {
// TODO: 트랜잭션 설정 // TODO: 트랜잭션 설정
@Transactional @Transactional
public void saveAutoLabelList(final List<AutoLabelingResult> resultList) { public void saveAutoLabelList(final List<AutoLabelingResult> resultList) {
for(AutoLabelingResult result: resultList) { for (AutoLabelingResult result : resultList) {
Image image = getImage(result.getImageId()); Image image = getImage(result.getImageId());
if(image.getStatus() == LabelStatus.SAVE || image.getStatus() == LabelStatus.IN_PROGRESS) continue; if (image.getStatus() == LabelStatus.SAVE || image.getStatus() == LabelStatus.IN_PROGRESS) continue;
String dataPath = image.getDataPath(); String dataPath = image.getDataPath();
s3UploadService.uploadJson(result.getData(), dataPath); s3UploadService.uploadJson(result.getData(), dataPath);
image.updateStatus(LabelStatus.IN_PROGRESS); image.updateStatus(LabelStatus.IN_PROGRESS);
@ -219,21 +222,22 @@ public class ProjectService {
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND)); .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
} }
private HashMap<Integer, Integer> getLabelMap(Project project) { private HashMap<String, Integer> getLabelMap(Project project) {
HashMap<Integer, Integer> labelMap = new HashMap<>(); HashMap<String, Integer> labelMap = new HashMap<>();
List<ProjectCategory> category = project.getCategoryList(); List<ProjectCategory> category = project.getCategoryList();
for (ProjectCategory projectCategory : category) { for (ProjectCategory projectCategory : category) {
int aiId = projectCategory.getLabelCategory().getAiCategoryId(); if (labelMap.containsKey(projectCategory.getLabelName())) {
if (labelMap.containsKey(aiId)) continue; continue;
}
labelMap.put(aiId, projectCategory.getId()); labelMap.put(projectCategory.getLabelName(), projectCategory.getId());
} }
return labelMap; return labelMap;
} }
private Image getImage(final Long imageId){ private Image getImage(final Long imageId) {
return imageRepository.findById(imageId) return imageRepository.findById(imageId)
.orElseThrow(()-> new CustomException(ErrorCode.DATA_NOT_FOUND)); .orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
} }
private Project getProject(final Integer projectId) { private Project getProject(final Integer projectId) {