Merge branch 'be/refactor/category' into 'be/develop'
Refactor: 카테고리 수정 - S11P21S002-223 See merge request s11-s-project/S11P21S002!210
This commit is contained in:
commit
ffe79cc79b
@ -1,16 +1,13 @@
|
||||
package com.worlabel.domain.labelcategory.controller;
|
||||
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.service.ProjectLabelCategoryService;
|
||||
import com.worlabel.global.annotation.CurrentUser;
|
||||
import com.worlabel.global.config.swagger.SwaggerApiError;
|
||||
import com.worlabel.global.config.swagger.SwaggerApiSuccess;
|
||||
import com.worlabel.global.exception.ErrorCode;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import org.springframework.data.repository.query.Param;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
@ -23,39 +20,11 @@ public class CategoryController {
|
||||
|
||||
private final ProjectLabelCategoryService categoryService;
|
||||
|
||||
@Operation(summary = "프로젝트 레이블 카테고리 선택", description = "프로젝트 레이블 카테고리를 추가합니다.")
|
||||
@SwaggerApiSuccess(description = "카테고리 성공적으로 추가합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@PostMapping
|
||||
public void createFolder(
|
||||
@CurrentUser final Integer memberId,
|
||||
@RequestBody final LabelCategoryRequest categoryRequest) {
|
||||
categoryService.createCategory(memberId, categoryRequest);
|
||||
}
|
||||
|
||||
@Operation(summary = "레이블 카테고리 존재 여부 조회", description = "해당 프로젝트에 같은 레이블 카테고리 이름이 있는지 조회합니다.")
|
||||
@SwaggerApiSuccess(description = "카테고리 존재 여부를 조회합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@GetMapping("/exist")
|
||||
public boolean existByCategoryName(
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@Param("categoryName") final String categoryName) {
|
||||
return categoryService.existByCategoryName(projectId, categoryName);
|
||||
}
|
||||
|
||||
@Operation(summary = "프로젝트 레이블 카테고리 리스트 조회", description = "레이블 카테고리 리스트를 조회합니다..")
|
||||
@SwaggerApiSuccess(description = "카테고리 리스트를 성공적으로 조회합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@GetMapping
|
||||
public List<LabelCategoryResponse> getCategoryList(@PathVariable("project_id") final Integer projectId) {
|
||||
return categoryService.getCategoryList(projectId);
|
||||
}
|
||||
|
||||
@Operation(summary = "카테고리 삭제", description = "카테고리를 삭제합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@SwaggerApiSuccess(description = "카테고리를 성공적으로 삭제합니다.")
|
||||
@DeleteMapping("/{category_id}")
|
||||
public void deleteCategoryById(@PathVariable("project_id") final Integer projectId, @PathVariable("category_id") final Integer categoryId) {
|
||||
categoryService.deleteCategory(projectId, categoryId);
|
||||
public List<CategoryResponse> getCategoryList(@PathVariable("project_id") final Integer projectId) {
|
||||
return categoryService.getCategoryById(projectId);
|
||||
}
|
||||
}
|
||||
|
@ -1,53 +0,0 @@
|
||||
package com.worlabel.domain.labelcategory.entity;
|
||||
|
||||
|
||||
import com.worlabel.domain.model.entity.AiModel;
|
||||
import com.worlabel.global.common.BaseEntity;
|
||||
import jakarta.persistence.*;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Getter
|
||||
@Entity
|
||||
@Table(name = "label_category")
|
||||
@NoArgsConstructor(access = AccessLevel.PROTECTED)
|
||||
public class LabelCategory extends BaseEntity {
|
||||
|
||||
/**
|
||||
* 레이블 카테고리 PK
|
||||
*/
|
||||
@Id
|
||||
@Column(name = "label_category_id", nullable = false)
|
||||
@GeneratedValue(strategy = GenerationType.IDENTITY)
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
* 속한 모델
|
||||
*/
|
||||
@ManyToOne(fetch = FetchType.LAZY)
|
||||
@JoinColumn(name = "model_id", nullable = false)
|
||||
private AiModel aiModel;
|
||||
|
||||
/**
|
||||
* 레이블 카테고리 이름
|
||||
*/
|
||||
@Column(name = "label_category_name", nullable = false)
|
||||
private String name;
|
||||
|
||||
/**
|
||||
* 실제 AI 모델의 ai 카테고리 id
|
||||
*/
|
||||
@Column(name = "ai_category_id", nullable = false)
|
||||
private Integer aiCategoryId;
|
||||
|
||||
private LabelCategory(final AiModel aiModel, final String name, final int aiCategoryId) {
|
||||
this.aiModel = aiModel;
|
||||
this.name = name;
|
||||
this.aiCategoryId = aiCategoryId;
|
||||
}
|
||||
|
||||
public static LabelCategory of(final AiModel aiModel, final String name, final int aiCategoryId) {
|
||||
return new LabelCategory(aiModel, name, aiCategoryId);
|
||||
}
|
||||
}
|
@ -22,11 +22,10 @@ public class ProjectCategory extends BaseEntity {
|
||||
private Integer id;
|
||||
|
||||
/**
|
||||
* 레이블 카테고리
|
||||
* Model name
|
||||
*/
|
||||
@ManyToOne(fetch = FetchType.LAZY)
|
||||
@JoinColumn(name = "label_category_id", nullable = false)
|
||||
private LabelCategory labelCategory;
|
||||
@Column(name = "label_category_name", length = 50)
|
||||
private String labelName;
|
||||
|
||||
/**
|
||||
* 프로젝트
|
||||
@ -35,12 +34,12 @@ public class ProjectCategory extends BaseEntity {
|
||||
@JoinColumn(name = "project_id", nullable = false)
|
||||
private Project project;
|
||||
|
||||
private ProjectCategory(LabelCategory labelCategory, Project project) {
|
||||
this.labelCategory = labelCategory;
|
||||
private ProjectCategory(String labelName, Project project) {
|
||||
this.labelName = labelName;
|
||||
this.project = project;
|
||||
}
|
||||
|
||||
public static ProjectCategory of(LabelCategory labelCategory, Project project) {
|
||||
return new ProjectCategory(labelCategory, project);
|
||||
public static ProjectCategory of(String labelName, Project project) {
|
||||
return new ProjectCategory(labelName, project);
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,22 @@
|
||||
package com.worlabel.domain.labelcategory.entity.dto;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Schema(name = "카테고리 응답 DTO", description = "프로젝트 내 카테고리 종류 응답 DTO")
|
||||
@Getter
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public class CategoryResponse {
|
||||
|
||||
@Schema(description = "카테고리 ID", example = "1")
|
||||
private Integer id;
|
||||
|
||||
@Schema(description = "라벨링 이름", example = "사람")
|
||||
private String labelName;
|
||||
|
||||
public static CategoryResponse of(final Integer id, final String labelName) {
|
||||
return new CategoryResponse(id, labelName);
|
||||
}
|
||||
}
|
@ -1,21 +0,0 @@
|
||||
package com.worlabel.domain.labelcategory.entity.dto;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Schema(name = "카테고리 요청 DTO", description = "카테고리 생성 및 수정을 위한 요청 DTO")
|
||||
@Getter
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class LabelCategoryRequest {
|
||||
|
||||
@Schema(description = "Model Id", example = "1")
|
||||
private Integer modelId;
|
||||
|
||||
@Schema(description = "카테고리 Id 리스트", example = "[0,3,6,8]")
|
||||
private List<Integer> labelCategoryList;
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
package com.worlabel.domain.labelcategory.entity.dto;
|
||||
|
||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
@Schema(name = "카테고리 응답 DTO", description = "카테고리 조회 응답 DTO")
|
||||
public class LabelCategoryResponse {
|
||||
|
||||
@Schema(description = "카테고리 ID", example = "1")
|
||||
private Integer id;
|
||||
|
||||
@Schema(description = "카테고리 이름", example = "Car")
|
||||
private String name;
|
||||
|
||||
public static LabelCategoryResponse from(LabelCategory labelCategory) {
|
||||
return new LabelCategoryResponse(labelCategory.getId(), labelCategory.getName());
|
||||
}
|
||||
}
|
@ -1,22 +0,0 @@
|
||||
package com.worlabel.domain.labelcategory.repository;
|
||||
|
||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.data.jpa.repository.Query;
|
||||
import org.springframework.data.repository.query.Param;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface LabelCategoryRepository extends JpaRepository<LabelCategory, Integer> {
|
||||
|
||||
@Query("SELECT l FROM LabelCategory l " +
|
||||
"WHERE l.aiModel.id = :modelId")
|
||||
List<LabelCategory> findAllByModelId(@Param("modelId") Integer modelId);
|
||||
|
||||
|
||||
@Query("SELECT l FROM LabelCategory l " +
|
||||
"WHERE l.aiModel.id = :modelId AND " +
|
||||
"l.id IN :categoryList")
|
||||
List<LabelCategory> findAllByIdsAndModelId(@Param("categoryList") List<Integer> labelCategoryList,@Param("modelId") Integer modelId);
|
||||
|
||||
}
|
@ -9,14 +9,5 @@ import java.util.List;
|
||||
|
||||
public interface ProjectLabelCategoryRepository extends JpaRepository<ProjectCategory, Integer> {
|
||||
|
||||
|
||||
@Query("SELECT COUNT(pc) >= 1 FROM ProjectCategory pc " +
|
||||
"WHERE pc.project.id = :projectId AND " +
|
||||
"pc.labelCategory.name = :categoryName ")
|
||||
boolean existsByNameAndProjectId(@Param("categoryName") String categoryName , @Param("projectId") int projectId);
|
||||
|
||||
@Query("SELECT pc FROM ProjectCategory pc " +
|
||||
"JOIN FETCH pc.labelCategory " +
|
||||
"WHERE pc.project.id = :projectId ")
|
||||
List<ProjectCategory> findAllByProjectId(@Param("projectId") Integer projectId);
|
||||
List<ProjectCategory> findAllByProjectId(Integer projectId);
|
||||
}
|
||||
|
@ -1,17 +1,10 @@
|
||||
package com.worlabel.domain.labelcategory.service;
|
||||
|
||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryRequest;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.CategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository;
|
||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||
import com.worlabel.global.annotation.CheckPrivilege;
|
||||
import com.worlabel.global.exception.CustomException;
|
||||
import com.worlabel.global.exception.ErrorCode;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
@ -26,48 +19,13 @@ import java.util.List;
|
||||
public class ProjectLabelCategoryService {
|
||||
|
||||
private final ProjectLabelCategoryRepository projectLabelCategoryRepository;
|
||||
private final LabelCategoryRepository labelCategoryRepository;
|
||||
private final ProjectRepository projectRepository;
|
||||
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void createCategory(final Integer projectId, final LabelCategoryRequest categoryRequest) {
|
||||
Project project = getProject(projectId);
|
||||
List<LabelCategory> labelCategoryList = labelCategoryRepository.findAllByIdsAndModelId(categoryRequest.getLabelCategoryList(), categoryRequest.getModelId());
|
||||
List<ProjectCategory> projectCategoryList = labelCategoryList.stream().map(o -> ProjectCategory.of(o, project)).toList();
|
||||
projectLabelCategoryRepository.saveAll(projectCategoryList);
|
||||
}
|
||||
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void deleteCategory(final int projectId, final int projectCategoryId) {
|
||||
ProjectCategory projectCategory = getProjectCategory(projectCategoryId);
|
||||
projectLabelCategoryRepository.delete(projectCategory);
|
||||
}
|
||||
|
||||
@CheckPrivilege(PrivilegeType.VIEWER)
|
||||
public LabelCategoryResponse getCategoryById(final int projectId, final int categoryId) {
|
||||
return LabelCategoryResponse.from(getProjectCategory(categoryId).getLabelCategory());
|
||||
}
|
||||
public List<CategoryResponse> getCategoryById(final int projectId) {
|
||||
List<ProjectCategory> categories = projectLabelCategoryRepository.findAllByProjectId(projectId);
|
||||
|
||||
public boolean existByCategoryName(final int projectId, final String categoryName) {
|
||||
return projectLabelCategoryRepository.existsByNameAndProjectId(categoryName, projectId);
|
||||
}
|
||||
|
||||
@CheckPrivilege(PrivilegeType.VIEWER)
|
||||
public List<LabelCategoryResponse> getCategoryList(final Integer projectId) {
|
||||
return projectLabelCategoryRepository.findAllByProjectId(projectId)
|
||||
.stream()
|
||||
.map(ProjectCategory::getLabelCategory)
|
||||
.map(LabelCategoryResponse::from)
|
||||
return categories.stream()
|
||||
.map(category -> CategoryResponse.of(category.getId(), category.getLabelName()))
|
||||
.toList();
|
||||
}
|
||||
|
||||
private Project getProject(Integer projectId) {
|
||||
return projectRepository.findById(projectId)
|
||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||
}
|
||||
|
||||
private ProjectCategory getProjectCategory(final Integer categoryId) {
|
||||
return projectLabelCategoryRepository.findById(categoryId)
|
||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,5 @@
|
||||
package com.worlabel.domain.model.controller;
|
||||
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
||||
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
||||
import com.worlabel.domain.model.entity.dto.ModelTrainRequest;
|
||||
@ -39,15 +38,6 @@ public class AiModelController {
|
||||
return aiModelService.getModelList(projectId);
|
||||
}
|
||||
|
||||
@Operation(summary = "특정 모델 카테고리", description = "모델의 카테고리를 조회합니다.")
|
||||
@SwaggerApiSuccess(description = "카테고리를 조회합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@GetMapping("/models/{model_id}/categories")
|
||||
public List<LabelCategoryResponse> getCategories(
|
||||
@PathVariable("model_id") final Integer modelId) {
|
||||
return aiModelService.getCategories(modelId);
|
||||
}
|
||||
|
||||
@Operation(summary = "프로젝트 모델 추가", description = "프로젝트에 있는 모델을 추가합니다.")
|
||||
@SwaggerApiSuccess(description = "프로젝트 모델을 추가합니다.")
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
|
@ -1,7 +1,6 @@
|
||||
package com.worlabel.domain.model.entity;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
import com.worlabel.global.common.BaseEntity;
|
||||
import jakarta.persistence.*;
|
||||
@ -9,9 +8,6 @@ import lombok.AccessLevel;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@Entity
|
||||
@Table(name = "ai_model")
|
||||
@ -52,12 +48,6 @@ public class AiModel extends BaseEntity {
|
||||
@JsonIgnore
|
||||
private Project project;
|
||||
|
||||
/**
|
||||
* 모델에 속한 카테고리
|
||||
*/
|
||||
@OneToMany(mappedBy = "aiModel", fetch = FetchType.LAZY)
|
||||
private List<LabelCategory> categoryList = new ArrayList<>();
|
||||
|
||||
private AiModel(String name, String modelKey, int version, Project project) {
|
||||
this.name = name;
|
||||
this.modelKey = modelKey;
|
||||
|
@ -3,19 +3,13 @@ package com.worlabel.domain.model.service;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.image.entity.LabelStatus;
|
||||
import com.worlabel.domain.image.repository.ImageRepository;
|
||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
||||
import com.worlabel.domain.model.entity.AiModel;
|
||||
import com.worlabel.domain.model.entity.dto.*;
|
||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||
import com.worlabel.domain.progress.service.ProgressService;
|
||||
import com.worlabel.domain.project.dto.AiDto;
|
||||
import com.worlabel.domain.project.dto.AiDto.TrainDataInfo;
|
||||
import com.worlabel.domain.project.dto.AiDto.TrainRequest;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
@ -25,7 +19,6 @@ import com.worlabel.domain.report.service.ReportService;
|
||||
import com.worlabel.domain.result.entity.Result;
|
||||
import com.worlabel.domain.result.repository.ResultRepository;
|
||||
import com.worlabel.global.annotation.CheckPrivilege;
|
||||
import com.worlabel.global.cache.CacheKey;
|
||||
import com.worlabel.global.exception.CustomException;
|
||||
import com.worlabel.global.exception.ErrorCode;
|
||||
import com.worlabel.global.service.AiRequestService;
|
||||
@ -38,8 +31,6 @@ import org.springframework.transaction.annotation.Transactional;
|
||||
import java.lang.reflect.Type;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
@ -50,7 +41,6 @@ import java.util.stream.Collectors;
|
||||
@RequiredArgsConstructor
|
||||
public class AiModelService {
|
||||
|
||||
private final LabelCategoryRepository labelCategoryRepository;
|
||||
private final RedisTemplate<String, Object> redisTemplate;
|
||||
private final AiModelRepository aiModelRepository;
|
||||
private final ProjectRepository projectRepository;
|
||||
@ -62,46 +52,6 @@ public class AiModelService {
|
||||
private final ProgressService progressService;
|
||||
private final ReportService reportService;
|
||||
|
||||
// @PostConstruct
|
||||
public void loadDefaultModel() {
|
||||
String url = "model/default";
|
||||
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
|
||||
|
||||
// 1. DefaultResponse의 Key값만 모아서 리스트로 만든다.
|
||||
List<String> allModelKeys = defaultResponseList.stream()
|
||||
.map(response -> response.getDefaultAiModelResponse().getModelKey())
|
||||
.toList();
|
||||
|
||||
// 2. 해당 Key값이 DB에 있는지 확인하기 (한 번의 쿼리로)
|
||||
List<String> existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream()
|
||||
.map(AiModel::getModelKey)
|
||||
.toList();
|
||||
|
||||
// 3. DB에 없는 Key만 필터링해서 처리
|
||||
List<DefaultResponse> newModel = defaultResponseList.stream()
|
||||
.filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey()))
|
||||
.toList();
|
||||
|
||||
|
||||
// 새롭게 추가된 값을 디비에 저장
|
||||
List<AiModel> aiModels = new ArrayList<>();
|
||||
List<LabelCategory> categories = new ArrayList<>();
|
||||
for (DefaultResponse defaultResponse : newModel) {
|
||||
DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse();
|
||||
AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null);
|
||||
aiModels.add(newAiModel);
|
||||
|
||||
List<DefaultLabelCategoryResponse> defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList();
|
||||
|
||||
for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) {
|
||||
categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId()));
|
||||
}
|
||||
}
|
||||
|
||||
aiModelRepository.saveAll(aiModels);
|
||||
labelCategoryRepository.saveAll(categories);
|
||||
}
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public List<AiModelResponse> getModelList(final Integer projectId) {
|
||||
int progressModelId = progressService.getProgressModelByProjectId(projectId);
|
||||
@ -131,23 +81,15 @@ public class AiModelService {
|
||||
return aiModelRepository.findCustomModelById(modelId).orElseThrow(() -> new CustomException(ErrorCode.BAD_REQUEST));
|
||||
}
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public List<LabelCategoryResponse> getCategories(final Integer modelId) {
|
||||
List<LabelCategory> categoryList = labelCategoryRepository.findAllByModelId(modelId);
|
||||
return categoryList.stream()
|
||||
.map(LabelCategoryResponse::from)
|
||||
.toList();
|
||||
}
|
||||
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||
// FastAPI 서버로 학습 요청을 전송
|
||||
Project project = getProject(projectId);
|
||||
AiModel model = getModel(trainRequest.getModelId());
|
||||
|
||||
Map<Integer, Integer> labelMap = project.getCategoryList().stream()
|
||||
Map<String, Integer> labelMap = project.getCategoryList().stream()
|
||||
.collect(Collectors.toMap(
|
||||
category -> category.getLabelCategory().getId(),
|
||||
ProjectCategory::getLabelName,
|
||||
ProjectCategory::getId
|
||||
));
|
||||
|
||||
|
@ -2,10 +2,7 @@ package com.worlabel.domain.project.controller;
|
||||
|
||||
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
|
||||
import com.worlabel.domain.project.dto.AutoModelRequest;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectResponse;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse;
|
||||
import com.worlabel.domain.project.entity.dto.*;
|
||||
import com.worlabel.domain.project.service.ProjectService;
|
||||
import com.worlabel.global.annotation.CurrentUser;
|
||||
import com.worlabel.global.config.swagger.SwaggerApiError;
|
||||
@ -38,7 +35,7 @@ public class ProjectController {
|
||||
public ProjectResponse createProject(
|
||||
@CurrentUser final Integer memberId,
|
||||
@PathVariable("workspace_id") final Integer workspaceId,
|
||||
@Valid @RequestBody final ProjectRequest projectRequest) {
|
||||
@Valid @RequestBody final ProjectWithCategoryRequest projectRequest) {
|
||||
return projectService.createProject(memberId, workspaceId, projectRequest);
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ public class AiDto {
|
||||
private Integer modelId;
|
||||
|
||||
@JsonProperty("label_map")
|
||||
private Map<Integer, Integer> labelMap;
|
||||
private Map<String, Integer> labelMap;
|
||||
|
||||
@JsonProperty("data")
|
||||
private List<TrainDataInfo> data;
|
||||
@ -62,7 +62,7 @@ public class AiDto {
|
||||
|
||||
private Optimizer optimizer;
|
||||
|
||||
public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map<Integer, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
|
||||
public static TrainRequest of(final Integer projectId, final Integer modelId, final String modelKey, final Map<String, Integer> labelMap, final List<TrainDataInfo> data, final ModelTrainRequest trainRequest) {
|
||||
TrainRequest request = new TrainRequest();
|
||||
request.projectId = projectId;
|
||||
request.modelId = modelId;
|
||||
@ -93,7 +93,7 @@ public class AiDto {
|
||||
private String modelKey;
|
||||
|
||||
@JsonProperty("label_map")
|
||||
private HashMap<Integer, Integer> labelMap;
|
||||
private HashMap<String, Integer> labelMap;
|
||||
|
||||
@JsonProperty("image_list")
|
||||
private List<AutoLabelingImageRequest> imageList;
|
||||
@ -104,7 +104,7 @@ public class AiDto {
|
||||
@JsonProperty("iou_threshold")
|
||||
private Double iouThreshold;
|
||||
|
||||
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<Integer, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
|
||||
public static AutoLabelingRequest of(final Integer projectId, final String modelKey, final HashMap<String, Integer> labelMap, final List<AutoLabelingImageRequest> imageList) {
|
||||
return new AutoLabelingRequest(projectId, modelKey, labelMap, imageList, 0.25, 0.45);
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,31 @@
|
||||
package com.worlabel.domain.project.entity.dto;
|
||||
|
||||
import com.worlabel.domain.project.entity.ProjectType;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import jakarta.validation.constraints.NotEmpty;
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
import lombok.AccessLevel;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Schema(name = "프로젝트 + 카테고리 요청 dto", description = "프로젝트 + 카테고리 요청 DTO")
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public class ProjectWithCategoryRequest {
|
||||
|
||||
@Schema(description = "프로젝트 제목", example = "삼성 갤럭시 s23")
|
||||
@NotEmpty(message = "제목을 입력하세요.")
|
||||
private String title;
|
||||
|
||||
@Schema(description = "카테고리 목록 이름", example = "자동차, 사람")
|
||||
@NotEmpty(message = "카테고리 목록 이름")
|
||||
private List<String> categories;
|
||||
|
||||
@Schema(description = "프로젝트 유형", example = "classification")
|
||||
@NotNull(message = "카테고리를 입력하세요.")
|
||||
private ProjectType projectType;
|
||||
}
|
@ -7,6 +7,7 @@ import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.image.entity.LabelStatus;
|
||||
import com.worlabel.domain.image.repository.ImageRepository;
|
||||
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||
import com.worlabel.domain.labelcategory.repository.ProjectLabelCategoryRepository;
|
||||
import com.worlabel.domain.member.entity.Member;
|
||||
import com.worlabel.domain.member.repository.MemberRepository;
|
||||
import com.worlabel.domain.model.entity.AiModel;
|
||||
@ -23,10 +24,7 @@ import com.worlabel.domain.project.dto.AiDto.AutoLabelingRequest;
|
||||
import com.worlabel.domain.project.dto.AiDto.AutoLabelingResult;
|
||||
import com.worlabel.domain.project.dto.AutoModelRequest;
|
||||
import com.worlabel.domain.project.entity.Project;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectMemberResponse;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectResponse;
|
||||
import com.worlabel.domain.project.entity.dto.ProjectWithThumbnailResponse;
|
||||
import com.worlabel.domain.project.entity.dto.*;
|
||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||
import com.worlabel.domain.workspace.entity.Workspace;
|
||||
import com.worlabel.domain.workspace.repository.WorkspaceRepository;
|
||||
@ -58,15 +56,15 @@ public class ProjectService {
|
||||
private final ProjectRepository projectRepository;
|
||||
private final MemberRepository memberRepository;
|
||||
private final ProgressService progressService;
|
||||
private final ProjectLabelCategoryRepository projectLabelCategoryRepository;
|
||||
private final S3UploadService s3UploadService;
|
||||
private final ImageRepository imageRepository;
|
||||
private final AiRequestService aiService;
|
||||
|
||||
|
||||
private final Gson gson;
|
||||
|
||||
@Transactional
|
||||
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
|
||||
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectWithCategoryRequest projectRequest) {
|
||||
Workspace workspace = getWorkspace(memberId, workspaceId);
|
||||
Member member = getMember(memberId);
|
||||
|
||||
@ -74,6 +72,11 @@ public class ProjectService {
|
||||
Participant participant = Participant.of(project, member, PrivilegeType.ADMIN);
|
||||
|
||||
projectRepository.save(project);
|
||||
|
||||
for (String labelName : projectRequest.getCategories()) {
|
||||
projectLabelCategoryRepository.save(ProjectCategory.of(labelName, project));
|
||||
}
|
||||
|
||||
participantRepository.save(participant);
|
||||
|
||||
return ProjectResponse.from(project);
|
||||
@ -173,7 +176,7 @@ public class ProjectService {
|
||||
.map(AutoLabelingImageRequest::of)
|
||||
.toList();
|
||||
|
||||
HashMap<Integer, Integer> labelMap = getLabelMap(project);
|
||||
HashMap<String, Integer> labelMap = getLabelMap(project);
|
||||
|
||||
AiModel aiModel = getAiModel(request);
|
||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
|
||||
@ -219,15 +222,16 @@ public class ProjectService {
|
||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||
}
|
||||
|
||||
private HashMap<Integer, Integer> getLabelMap(Project project) {
|
||||
HashMap<Integer, Integer> labelMap = new HashMap<>();
|
||||
private HashMap<String, Integer> getLabelMap(Project project) {
|
||||
HashMap<String, Integer> labelMap = new HashMap<>();
|
||||
List<ProjectCategory> category = project.getCategoryList();
|
||||
for (ProjectCategory projectCategory : category) {
|
||||
int aiId = projectCategory.getLabelCategory().getAiCategoryId();
|
||||
if (labelMap.containsKey(aiId)) continue;
|
||||
|
||||
labelMap.put(aiId, projectCategory.getId());
|
||||
if (labelMap.containsKey(projectCategory.getLabelName())) {
|
||||
continue;
|
||||
}
|
||||
labelMap.put(projectCategory.getLabelName(), projectCategory.getId());
|
||||
}
|
||||
|
||||
return labelMap;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user