Merge branch 'be/refactor/category' into 'be/develop'

Refactor: 카테고리 수정  - S11P21S002-223

See merge request s11-s-project/S11P21S002!210
This commit is contained in:
홍창기 2024-09-27 11:17:26 +09:00
commit ffe79cc79b
16 changed files with 101 additions and 326 deletions

View File

@ -1,16 +1,13 @@
package com.worlabel.domain.labelcategory.controller;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse;
import com.worlabel.domain.labelcategory.service.ProjectLabelCategoryService;
import com.worlabel.global.annotation.CurrentUser;
import com.worlabel.global.config.swagger.SwaggerApiError;
import com.worlabel.global.config.swagger.SwaggerApiSuccess;
import com.worlabel.global.exception.ErrorCode;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.data.repository.query.Param;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@ -23,39 +20,11 @@ public class CategoryController {
private final ProjectLabelCategoryService categoryService;
@Operation(summary = "프로젝트 레이블 카테고리 선택", description = "프로젝트 레이블 카테고리를 추가합니다.")
@SwaggerApiSuccess(description = "카테고리 성공적으로 추가합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping
public void createFolder(
@CurrentUser final Integer memberId,
@RequestBody final LabelCategoryRequest categoryRequest) {
categoryService.createCategory(memberId, categoryRequest);
}
@Operation(summary = "레이블 카테고리 존재 여부 조회", description = "해당 프로젝트에 같은 레이블 카테고리 이름이 있는지 조회합니다.")
@SwaggerApiSuccess(description = "카테고리 존재 여부를 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping("/exist")
public boolean existByCategoryName(
@PathVariable("project_id") final Integer projectId,
@Param("categoryName") final String categoryName) {
return categoryService.existByCategoryName(projectId, categoryName);
}
@Operation(summary = "프로젝트 레이블 카테고리 리스트 조회", description = "레이블 카테고리 리스트를 조회합니다..")
@SwaggerApiSuccess(description = "카테고리 리스트를 성공적으로 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping
public List<LabelCategoryResponse> getCategoryList(@PathVariable("project_id") final Integer projectId) {
return categoryService.getCategoryList(projectId);
}
@Operation(summary = "카테고리 삭제", description = "카테고리를 삭제합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@SwaggerApiSuccess(description = "카테고리를 성공적으로 삭제합니다.")
@DeleteMapping("/{category_id}")
public void deleteCategoryById(@PathVariable("project_id") final Integer projectId, @PathVariable("category_id") final Integer categoryId) {
categoryService.deleteCategory(projectId, categoryId);
public List<CategoryResponse> getCategoryList(@PathVariable("project_id") final Integer projectId) {
return categoryService.getCategoryById(projectId);
}
}

View File

@ -1,53 +0,0 @@
package com.worlabel.domain.labelcategory.entity;
import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.global.common.BaseEntity;
import jakarta.persistence.*;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.NoArgsConstructor;
@Getter
@Entity
@Table(name = "label_category")
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class LabelCategory extends BaseEntity {
/**
* 레이블 카테고리 PK
*/
@Id
@Column(name = "label_category_id", nullable = false)
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Integer id;
/**
* 속한 모델
*/
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "model_id", nullable = false)
private AiModel aiModel;
/**
* 레이블 카테고리 이름
*/
@Column(name = "label_category_name", nullable = false)
private String name;
/**
* 실제 AI 모델의 ai 카테고리 id
*/
@Column(name = "ai_category_id", nullable = false)
private Integer aiCategoryId;
private LabelCategory(final AiModel aiModel, final String name, final int aiCategoryId) {
this.aiModel = aiModel;
this.name = name;
this.aiCategoryId = aiCategoryId;
}
public static LabelCategory of(final AiModel aiModel, final String name, final int aiCategoryId) {
return new LabelCategory(aiModel, name, aiCategoryId);
}
}

View File

@ -22,11 +22,10 @@ public class ProjectCategory extends BaseEntity {
private Integer id;
/**
* 레이블 카테고리
* Model name
*/
@ManyToOne(fetch = FetchType.LAZY)
@JoinColumn(name = "label_category_id", nullable = false)
private LabelCategory labelCategory;
@Column(name = "label_category_name", length = 50)
private String labelName;
/**
* 프로젝트
@ -35,12 +34,12 @@ public class ProjectCategory extends BaseEntity {
@JoinColumn(name = "project_id", nullable = false)
private Project project;
private ProjectCategory(LabelCategory labelCategory, Project project) {
this.labelCategory = labelCategory;
private ProjectCategory(String labelName, Project project) {
this.labelName = labelName;
this.project = project;
}
public static ProjectCategory of(LabelCategory labelCategory, Project project) {
return new ProjectCategory(labelCategory, project);
public static ProjectCategory of(String labelName, Project project) {
return new ProjectCategory(labelName, project);
}
}

View File

@ -0,0 +1,22 @@
package com.worlabel.domain.labelcategory.entity.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Schema(name = "카테고리 응답 DTO", description = "프로젝트 내 카테고리 종류 응답 DTO")
@Getter
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public class CategoryResponse {
@Schema(description = "카테고리 ID", example = "1")
private Integer id;
@Schema(description = "라벨링 이름", example = "사람")
private String labelName;
public static CategoryResponse of(final Integer id, final String labelName) {
return new CategoryResponse(id, labelName);
}
}

View File

@ -1,21 +0,0 @@
package com.worlabel.domain.labelcategory.entity.dto;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.util.List;
@Schema(name = "카테고리 요청 DTO", description = "카테고리 생성 및 수정을 위한 요청 DTO")
@Getter
@NoArgsConstructor
@AllArgsConstructor
public class LabelCategoryRequest {
@Schema(description = "Model Id", example = "1")
private Integer modelId;
@Schema(description = "카테고리 Id 리스트", example = "[0,3,6,8]")
private List<Integer> labelCategoryList;
}

View File

@ -1,22 +0,0 @@
package com.worlabel.domain.labelcategory.entity.dto;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
@Schema(name = "카테고리 응답 DTO", description = "카테고리 조회 응답 DTO")
public class LabelCategoryResponse {
@Schema(description = "카테고리 ID", example = "1")
private Integer id;
@Schema(description = "카테고리 이름", example = "Car")
private String name;
public static LabelCategoryResponse from(LabelCategory labelCategory) {
return new LabelCategoryResponse(labelCategory.getId(), labelCategory.getName());
}
}

View File

@ -1,22 +0,0 @@
package com.worlabel.domain.labelcategory.repository;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List;
public interface LabelCategoryRepository extends JpaRepository<LabelCategory, Integer> {
@Query("SELECT l FROM LabelCategory l " +
"WHERE l.aiModel.id = :modelId")
List<LabelCategory> findAllByModelId(@Param("modelId") Integer modelId);
@Query("SELECT l FROM LabelCategory l " +
"WHERE l.aiModel.id = :modelId AND " +
"l.id IN :categoryList")
List<LabelCategory> findAllByIdsAndModelId(@Param("categoryList") List<Integer> labelCategoryList,@Param("modelId") Integer modelId);
}

View File

@ -9,14 +9,5 @@ import java.util.List;
public interface ProjectLabelCategoryRepository extends JpaRepository<ProjectCategory, Integer> {
@Query("SELECT COUNT(pc) >= 1 FROM ProjectCategory pc " +
"WHERE pc.project.id = :projectId AND " +
"pc.labelCategory.name = :categoryName ")
boolean existsByNameAndProjectId(@Param("categoryName") String categoryName , @Param("projectId") int projectId);
@Query("SELECT pc FROM ProjectCategory pc " +
"JOIN FETCH pc.labelCategory " +
"WHERE pc.project.id = :projectId ")
List<ProjectCategory> findAllByProjectId(@Param("projectId") Integer projectId);
List<ProjectCategory> findAllByProjectId(Integer projectId);
}

View File

@ -1,17 +1,10 @@
package com.worlabel.domain.labelcategory.service;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse;
import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository;
import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@ -26,48 +19,13 @@ import java.util.List;
public class ProjectLabelCategoryService {
private final ProjectLabelCategoryRepository projectLabelCategoryRepository;
private final LabelCategoryRepository labelCategoryRepository;
private final ProjectRepository projectRepository;
@CheckPrivilege(PrivilegeType.EDITOR)
public void createCategory(final Integer projectId, final LabelCategoryRequest categoryRequest) {
Project project = getProject(projectId);
List<LabelCategory> labelCategoryList = labelCategoryRepository.findAllByIdsAndModelId(categoryRequest.getLabelCategoryList(), categoryRequest.getModelId());
List<ProjectCategory> projectCategoryList = labelCategoryList.stream().map(o -> ProjectCategory.of(o, project)).toList();
projectLabelCategoryRepository.saveAll(projectCategoryList);
}
@CheckPrivilege(PrivilegeType.EDITOR)
public void deleteCategory(final int projectId, final int projectCategoryId) {
ProjectCategory projectCategory = getProjectCategory(projectCategoryId);
projectLabelCategoryRepository.delete(projectCategory);
}
@CheckPrivilege(PrivilegeType.VIEWER)
public LabelCategoryResponse getCategoryById(final int projectId, final int categoryId) {
return LabelCategoryResponse.from(getProjectCategory(categoryId).getLabelCategory());
}
public List<CategoryResponse> getCategoryById(final int projectId) {
List<ProjectCategory> categories = projectLabelCategoryRepository.findAllByProjectId(projectId);
public boolean existByCategoryName(final int projectId, final String categoryName) {
return projectLabelCategoryRepository.existsByNameAndProjectId(categoryName, projectId);
}
@CheckPrivilege(PrivilegeType.VIEWER)
public List<LabelCategoryResponse> getCategoryList(final Integer projectId) {
return projectLabelCategoryRepository.findAllByProjectId(projectId)
.stream()
.map(ProjectCategory::getLabelCategory)
.map(LabelCategoryResponse::from)
return categories.stream()
.map(category -> CategoryResponse.of(category.getId(), category.getLabelName()))
.toList();
}
private Project getProject(Integer projectId) {
return projectRepository.findById(projectId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private ProjectCategory getProjectCategory(final Integer categoryId) {
return projectLabelCategoryRepository.findById(categoryId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
}

View File

@ -1,6 +1,5 @@
package com.worlabel.domain.model.controller;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.model.entity.dto.AiModelRequest;
import com.worlabel.domain.model.entity.dto.AiModelResponse;
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
@ -39,15 +38,6 @@ public class AiModelController {
return aiModelService.getModelList(projectId);
}
@Operation(summary = "특정 모델 카테고리", description = "모델의 카테고리를 조회합니다.")
@SwaggerApiSuccess(description = "카테고리를 조회합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@GetMapping("/models/{model_id}/categories")
public List<LabelCategoryResponse> getCategories(
@PathVariable("model_id") final Integer modelId) {
return aiModelService.getCategories(modelId);
}
@Operation(summary = "프로젝트 모델 추가", description = "프로젝트에 있는 모델을 추가합니다.")
@SwaggerApiSuccess(description = "프로젝트 모델을 추가합니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})

View File

@ -1,7 +1,6 @@
package com.worlabel.domain.model.entity;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.global.common.BaseEntity;
import jakarta.persistence.*;
@ -9,9 +8,6 @@ import lombok.AccessLevel;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@Getter
@Entity
@Table(name = "ai_model")
@ -52,12 +48,6 @@ public class AiModel extends BaseEntity {
@JsonIgnore
private Project project;
/**
* 모델에 속한 카테고리
*/
@OneToMany(mappedBy = "aiModel", fetch = FetchType.LAZY)
private List<LabelCategory> categoryList = new ArrayList<>();
private AiModel(String name, String modelKey, int version, Project project) {
this.name = name;
this.modelKey = modelKey;

View File

@ -3,19 +3,13 @@ package com.worlabel.domain.model.service;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.LabelCategory;
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
import com.worlabel.domain.model.entity.AiModel;
import com.worlabel.domain.model.entity.dto.*;
import com.worlabel.domain.model.repository.AiModelRepository;
import com.worlabel.domain.participant.entity.PrivilegeType;
import com.worlabel.domain.progress.service.ProgressService;
import com.worlabel.domain.project.dto.AiDto;
import com.worlabel.domain.project.dto.AiDto.TrainDataInfo;
import com.worlabel.domain.project.dto.AiDto.TrainRequest;
import com.worlabel.domain.project.entity.Project;
@ -25,7 +19,6 @@ import com.worlabel.domain.report.service.ReportService;
import com.worlabel.domain.result.entity.Result;
import com.worlabel.domain.result.repository.ResultRepository;
import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.cache.CacheKey;
import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode;
import com.worlabel.global.service.AiRequestService;
@ -38,8 +31,6 @@ import org.springframework.transaction.annotation.Transactional;
import java.lang.reflect.Type;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@ -50,7 +41,6 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor
public class AiModelService {
private final LabelCategoryRepository labelCategoryRepository;
private final RedisTemplate<String, Object> redisTemplate;
private final AiModelRepository aiModelRepository;
private final ProjectRepository projectRepository;
@ -62,46 +52,6 @@ public class AiModelService {
private final ProgressService progressService;
private final ReportService reportService;
// @PostConstruct
public void loadDefaultModel() {
String url = "model/default";
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
// 1. DefaultResponse의 Key값만 모아서 리스트로 만든다.
List<String> allModelKeys = defaultResponseList.stream()
.map(response -> response.getDefaultAiModelResponse().getModelKey())
.toList();
// 2. 해당 Key값이 DB에 있는지 확인하기 ( 번의 쿼리로)
List<String> existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream()
.map(AiModel::getModelKey)
.toList();
// 3. DB에 없는 Key만 필터링해서 처리
List<DefaultResponse> newModel = defaultResponseList.stream()
.filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey()))
.toList();
// 새롭게 추가된 값을 디비에 저장
List<AiModel> aiModels = new ArrayList<>();
List<LabelCategory> categories = new ArrayList<>();
for (DefaultResponse defaultResponse : newModel) {
DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse();
AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null);
aiModels.add(newAiModel);
List<DefaultLabelCategoryResponse> defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList();
for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) {
categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId()));
}
}
aiModelRepository.saveAll(aiModels);
labelCategoryRepository.saveAll(categories);
}
@Transactional(readOnly = true)
public List<AiModelResponse> getModelList(final Integer projectId) {
int progressModelId = progressService.getProgressModelByProjectId(projectId);
@ -131,23 +81,15 @@ public class AiModelService {
return aiModelRepository.findCustomModelById(modelId).orElseThrow(() -> new CustomException(ErrorCode.BAD_REQUEST));
}
@Transactional(readOnly = true)
public List<LabelCategoryResponse> getCategories(final Integer modelId) {
List<LabelCategory> categoryList = labelCategoryRepository.findAllByModelId(modelId);
return categoryList.stream()
.map(LabelCategoryResponse::from)
.toList();
}
@CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
// FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId);
AiModel model = getModel(trainRequest.getModelId());
Map<Integer, Integer> labelMap = project.getCategoryList().stream()
Map<String, Integer> labelMap = project.getCategoryList().stream()
.collect(Collectors.toMap(
category -> category.getLabelCategory().getId(),
ProjectCategory::getLabelName,
ProjectCategory::getId
));

View File

@ -2,10 +2,7 @@ package com.worlabel.domain.project.controller;
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
import com.worlabel.domain.project.dto.AutoModelRequest;
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
import com.worlabel.domain.project.entity.dto.ProjectRequest;
import com.worlabel.domain.project.entity.dto.ProjectResponse;
import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse;
import com.worlabel.domain.project.entity.dto.*;
import com.worlabel.domain.project.service.ProjectService;
import com.worlabel.global.annotation.CurrentUser;
import com.worlabel.global.config.swagger.SwaggerApiError;
@ -38,7 +35,7 @@ public class ProjectController {
public ProjectResponse createProject(
@CurrentUser final Integer memberId,
@PathVariable("workspace_id") final Integer workspaceId,
@Valid @RequestBody final ProjectRequest projectRequest) {
@Valid @RequestBody final ProjectWithCategoryRequest projectRequest) {
return projectService.createProject(memberId, workspaceId, projectRequest);
}

View File

@ -45,7 +45,7 @@ public class AiDto {
private Integer modelId;
@JsonProperty("label_map")
private Map<Integer, Integer> labelMap;
private Map<String, Integer> labelMap;
@JsonProperty("data")
private List<TrainDataInfo> data;
@ -62,7 +62,7 @@ public class AiDto {
private Optimizer optimizer;
public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map<Integer, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map<String, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
TrainRequest request = new TrainRequest();
request.projectId = projectId;
request.modelId = modelId;
@ -93,7 +93,7 @@ public class AiDto {
private String modelKey;
@JsonProperty("label_map")
private HashMap<Integer, Integer> labelMap;
private HashMap<String, Integer> labelMap;
@JsonProperty("image_list")
private List<AutoLabelingImageRequest> imageList;
@ -104,7 +104,7 @@ public class AiDto {
@JsonProperty("iou_threshold")
private Double iouThreshold;
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<Integer, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<String, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45);
}
}

View File

@ -0,0 +1,31 @@
package com.worlabel.domain.project.entity.dto;
import com.worlabel.domain.project.entity.ProjectType;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import java.util.List;
@Schema(name = "프로젝트 + 카테고리 요청 dto", description = "프로젝트 + 카테고리 요청 DTO")
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor
@Getter
public class ProjectWithCategoryRequest {
@Schema(description = "프로젝트 제목", example = "삼성 갤럭시 s23")
@NotEmpty(message = "제목을 입력하세요.")
private String title;
@Schema(description = "카테고리 목록 이름", example = "자동차, 사람")
@NotEmpty(message = "카테고리 목록 이름")
private List<String> categories;
@Schema(description = "프로젝트 유형", example = "classification")
@NotNull(message = "카테고리를 입력하세요.")
private ProjectType projectType;
}

View File

@ -7,6 +7,7 @@ import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository;
import com.worlabel.domain.member.entity.Member;
import com.worlabel.domain.member.repository.MemberRepository;
import com.worlabel.domain.model.entity.AiModel;
@ -23,10 +24,7 @@ import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest;
import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult;
import com.worlabel.domain.project.dto.AutoModelRequest;
import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
import com.worlabel.domain.project.entity.dto.ProjectRequest;
import com.worlabel.domain.project.entity.dto.ProjectResponse;
import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse;
import com.worlabel.domain.project.entity.dto.*;
import com.worlabel.domain.project.repository.ProjectRepository;
import com.worlabel.domain.workspace.entity.Workspace;
import com.worlabel.domain.workspace.repository.WorkspaceRepository;
@ -58,15 +56,15 @@ public class ProjectService {
private final ProjectRepository projectRepository;
private final MemberRepository memberRepository;
private final ProgressService progressService;
private final ProjectLabelCategoryRepository projectLabelCategoryRepository;
private final S3UploadService s3UploadService;
private final ImageRepository imageRepository;
private final AiRequestService aiService;
private final Gson gson;
@Transactional
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectWithCategoryRequest projectRequest) {
Workspace workspace = getWorkspace(memberId, workspaceId);
Member member = getMember(memberId);
@ -74,6 +72,11 @@ public class ProjectService {
Participant participant = Participant.of(project, member, PrivilegeType.ADMIN);
projectRepository.save(project);
for (String labelName : projectRequest.getCategories()) {
projectLabelCategoryRepository.save(ProjectCategory.of(labelName, project));
}
participantRepository.save(participant);
return ProjectResponse.from(project);
@ -173,7 +176,7 @@ public class ProjectService {
.map(AutoLabelingImageRequest::of)
.toList();
HashMap<Integer, Integer> labelMap = getLabelMap(project);
HashMap<String, Integer> labelMap = getLabelMap(project);
AiModel aiModel = getAiModel(request);
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
@ -219,15 +222,16 @@ public class ProjectService {
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private HashMap<Integer, Integer> getLabelMap(Project project) {
HashMap<Integer, Integer> labelMap = new HashMap<>();
private HashMap<String, Integer> getLabelMap(Project project) {
HashMap<String, Integer> labelMap = new HashMap<>();
List<ProjectCategory> category = project.getCategoryList();
for (ProjectCategory projectCategory : category) {
int aiId = projectCategory.getLabelCategory().getAiCategoryId();
if (labelMap.containsKey(aiId)) continue;
labelMap.put(aiId, projectCategory.getId());
if (labelMap.containsKey(projectCategory.getLabelName())) {
continue;
}
labelMap.put(projectCategory.getLabelName(), projectCategory.getId());
}
return labelMap;
}