Merge branch 'be/feat/model' into 'be/develop'
Feat: 모델 관련 API 구현 See merge request s11-s-project/S11P21S002!119
This commit is contained in:
commit
64c549c8bc
@ -10,6 +10,7 @@ import java.util.Optional;
|
|||||||
|
|
||||||
public interface ImageRepository extends JpaRepository<Image, Long> {
|
public interface ImageRepository extends JpaRepository<Image, Long> {
|
||||||
|
|
||||||
|
// todo N + 1 발생할듯
|
||||||
@Query("select i from Image i " +
|
@Query("select i from Image i " +
|
||||||
"where i.folder.project.id = :projectId")
|
"where i.folder.project.id = :projectId")
|
||||||
List<Image> findImagesByProjectId(@Param("projectId") Integer projectId);
|
List<Image> findImagesByProjectId(@Param("projectId") Integer projectId);
|
||||||
|
@ -68,7 +68,14 @@ public class AiModelController {
|
|||||||
aiModelService.renameModel(memberId, projectId, modelId, aiModelRequest);
|
aiModelService.renameModel(memberId, projectId, modelId, aiModelRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: 여기서 모델 학습을 따로 만들어야 할 듯 Project 있는 모델 학습을 여기로 옮겨서 진행
|
@Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.")
|
||||||
// 아마도 필요한 요청 값들은 ModelID
|
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
|
||||||
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
|
@PostMapping("/projects/{project_id}/train")
|
||||||
|
public void trainModel(
|
||||||
|
@CurrentUser final Integer memberId,
|
||||||
|
@PathVariable("project_id") final Integer projectId,
|
||||||
|
@RequestBody final Integer modelId) {
|
||||||
|
aiModelService.train(memberId, projectId, modelId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,4 +18,6 @@ public interface AiModelRepository extends JpaRepository<AiModel, Integer> {
|
|||||||
@Query("SELECT a FROM AiModel a " +
|
@Query("SELECT a FROM AiModel a " +
|
||||||
"WHERE a.project IS NOT NULL AND a.id = :modelId")
|
"WHERE a.project IS NOT NULL AND a.id = :modelId")
|
||||||
Optional<AiModel> findCustomModelById(@Param("modelId") int modelId);
|
Optional<AiModel> findCustomModelById(@Param("modelId") int modelId);
|
||||||
|
|
||||||
|
List<AiModel> findAllByModelKeyIn(List<String> allModelKeys);
|
||||||
}
|
}
|
||||||
|
@ -1,29 +1,38 @@
|
|||||||
package com.worlabel.domain.model.service;
|
package com.worlabel.domain.model.service;
|
||||||
|
|
||||||
|
|
||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
import com.google.gson.reflect.TypeToken;
|
import com.google.gson.reflect.TypeToken;
|
||||||
|
import com.worlabel.domain.image.entity.Image;
|
||||||
|
import com.worlabel.domain.image.entity.LabelStatus;
|
||||||
|
import com.worlabel.domain.image.repository.ImageRepository;
|
||||||
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
import com.worlabel.domain.labelcategory.entity.LabelCategory;
|
||||||
|
import com.worlabel.domain.labelcategory.entity.dto.DefaultLabelCategoryResponse;
|
||||||
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
import com.worlabel.domain.labelcategory.entity.dto.LabelCategoryResponse;
|
||||||
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
import com.worlabel.domain.labelcategory.repository.LabelCategoryRepository;
|
||||||
import com.worlabel.domain.model.entity.AiModel;
|
import com.worlabel.domain.model.entity.AiModel;
|
||||||
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
import com.worlabel.domain.model.entity.dto.AiModelRequest;
|
||||||
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
import com.worlabel.domain.model.entity.dto.AiModelResponse;
|
||||||
|
import com.worlabel.domain.model.entity.dto.DefaultAiModelResponse;
|
||||||
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
import com.worlabel.domain.model.entity.dto.DefaultResponse;
|
||||||
import com.worlabel.domain.model.repository.AiModelRepository;
|
import com.worlabel.domain.model.repository.AiModelRepository;
|
||||||
import com.worlabel.domain.participant.entity.PrivilegeType;
|
import com.worlabel.domain.participant.entity.PrivilegeType;
|
||||||
|
import com.worlabel.domain.project.dto.RequestDto;
|
||||||
import com.worlabel.domain.project.entity.Project;
|
import com.worlabel.domain.project.entity.Project;
|
||||||
import com.worlabel.domain.project.repository.ProjectRepository;
|
import com.worlabel.domain.project.repository.ProjectRepository;
|
||||||
import com.worlabel.global.annotation.CheckPrivilege;
|
import com.worlabel.global.annotation.CheckPrivilege;
|
||||||
import com.worlabel.global.exception.CustomException;
|
import com.worlabel.global.exception.CustomException;
|
||||||
import com.worlabel.global.exception.ErrorCode;
|
import com.worlabel.global.exception.ErrorCode;
|
||||||
import com.worlabel.global.service.AiRequestService;
|
import com.worlabel.global.service.AiRequestService;
|
||||||
|
import jakarta.annotation.PostConstruct;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
import java.lang.reflect.Type;
|
import java.lang.reflect.Type;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import java.time.format.DateTimeFormatter;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@ -35,9 +44,50 @@ public class AiModelService {
|
|||||||
private final AiModelRepository aiModelRepository;
|
private final AiModelRepository aiModelRepository;
|
||||||
private final ProjectRepository projectRepository;
|
private final ProjectRepository projectRepository;
|
||||||
private final LabelCategoryRepository labelCategoryRepository;
|
private final LabelCategoryRepository labelCategoryRepository;
|
||||||
|
private final ImageRepository imageRepository;
|
||||||
private final AiRequestService aiRequestService;
|
private final AiRequestService aiRequestService;
|
||||||
private final Gson gson;
|
private final Gson gson;
|
||||||
|
|
||||||
|
@PostConstruct
|
||||||
|
public void loadDefaultModel() {
|
||||||
|
String url = "model/default";
|
||||||
|
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
|
||||||
|
|
||||||
|
// 1. DefaultResponse의 Key값만 모아서 리스트로 만든다.
|
||||||
|
List<String> allModelKeys = defaultResponseList.stream()
|
||||||
|
.map(response -> response.getDefaultAiModelResponse().getModelKey())
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
// 2. 해당 Key값이 DB에 있는지 확인하기 (한 번의 쿼리로)
|
||||||
|
List<String> existingModelKeys = aiModelRepository.findAllByModelKeyIn(allModelKeys).stream()
|
||||||
|
.map(AiModel::getModelKey)
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
// 3. DB에 없는 Key만 필터링해서 처리
|
||||||
|
List<DefaultResponse> newModel = defaultResponseList.stream()
|
||||||
|
.filter(model -> !existingModelKeys.contains(model.getDefaultAiModelResponse().getModelKey()))
|
||||||
|
.toList();
|
||||||
|
|
||||||
|
|
||||||
|
// 새롭게 추가된 값을 디비에 저장
|
||||||
|
List<AiModel> aiModels = new ArrayList<>();
|
||||||
|
List<LabelCategory> categories = new ArrayList<>();
|
||||||
|
for (DefaultResponse defaultResponse : newModel) {
|
||||||
|
DefaultAiModelResponse defaultAiModelResponse = defaultResponse.getDefaultAiModelResponse();
|
||||||
|
AiModel newAiModel = AiModel.of(defaultAiModelResponse.getName(), defaultAiModelResponse.getModelKey(), 0, null);
|
||||||
|
aiModels.add(newAiModel);
|
||||||
|
|
||||||
|
List<DefaultLabelCategoryResponse> defaultLabelCategoryResponseList = defaultResponse.getDefaultLabelCategoryResponseList();
|
||||||
|
|
||||||
|
for (DefaultLabelCategoryResponse categoryResponse : defaultLabelCategoryResponseList) {
|
||||||
|
categories.add(LabelCategory.of(newAiModel, categoryResponse.getName(), categoryResponse.getAiId()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
aiModelRepository.saveAll(aiModels);
|
||||||
|
labelCategoryRepository.saveAll(categories);
|
||||||
|
}
|
||||||
|
|
||||||
@Transactional(readOnly = true)
|
@Transactional(readOnly = true)
|
||||||
public List<AiModelResponse> getModelList(final Integer projectId) {
|
public List<AiModelResponse> getModelList(final Integer projectId) {
|
||||||
return aiModelRepository.findAllByProjectId(projectId)
|
return aiModelRepository.findAllByProjectId(projectId)
|
||||||
@ -74,21 +124,42 @@ public class AiModelService {
|
|||||||
.toList();
|
.toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
* 해당 Default 모델 불러오기 API 예시
|
public void train(Integer memberId, Integer projectId, Integer modelId) {
|
||||||
|
// TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
|
||||||
|
/*
|
||||||
|
없으면 redis 상태 테이블을 만든다. progressTable
|
||||||
*/
|
*/
|
||||||
// TODO : 스프링이 로딩 후 DefaultModel을 불러온다.
|
|
||||||
public void loadDefaultModel() {
|
|
||||||
String url = "model/default";
|
|
||||||
List<DefaultResponse> defaultResponseList = aiRequestService.getRequest(url, this::converter);
|
|
||||||
|
|
||||||
// TODO: defaultModel 현재 DB에 해당하는지 안하는지 확인하기
|
// FastAPI 서버로 학습 요청을 전송
|
||||||
|
Project project = getProject(projectId);
|
||||||
|
AiModel model = getModel(modelId);
|
||||||
|
List<LabelCategory> labelCategories = labelCategoryRepository.findAllByModelId(modelId);
|
||||||
|
List<Integer> categories = labelCategories.stream()
|
||||||
|
.map(LabelCategory::getAiCategoryId).toList();
|
||||||
|
|
||||||
// TODO : 1.DefaultResponse의 Key값만 모아서 리스트로 만든다.
|
List<Image> images = imageRepository.findImagesByProjectId(projectId);
|
||||||
|
|
||||||
// TODO: 2. 그 중 IN(key...)로 해당되는 Key값 확인하기
|
List<RequestDto.TrainDataInfo> data = images.stream().filter(image -> image.getStatus() == LabelStatus.COMPLETED)
|
||||||
|
.map(image -> new RequestDto.TrainDataInfo(image.getImagePath(), image.getDataPath()))
|
||||||
|
.toList();
|
||||||
|
|
||||||
// TODO: 3. 현재 DB에 없는 Key만 모아서 DB와 CategoryList에 넣어주면 됨
|
String endPoint = project.getProjectType().getValue() + "/train";
|
||||||
|
|
||||||
|
RequestDto.TrainRequest trainRequest = new RequestDto.TrainRequest();
|
||||||
|
trainRequest.setProjectId(projectId);
|
||||||
|
trainRequest.setCategoryId(categories);
|
||||||
|
trainRequest.setData(data);
|
||||||
|
trainRequest.setModelKey(model.getModelKey());
|
||||||
|
|
||||||
|
// FastAPI 서버로 POST 요청 전송
|
||||||
|
String modelKey = aiRequestService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
||||||
|
|
||||||
|
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
|
||||||
|
String currentDateTime = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"));
|
||||||
|
|
||||||
|
AiModel newModel = AiModel.of(currentDateTime, modelKey, model.getVersion() + 1, project);
|
||||||
|
aiModelRepository.save(newModel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -97,7 +168,8 @@ public class AiModelService {
|
|||||||
// TODO: 추후 리팩토링 해야함 이건 예시
|
// TODO: 추후 리팩토링 해야함 이건 예시
|
||||||
private List<DefaultResponse> converter(String data) {
|
private List<DefaultResponse> converter(String data) {
|
||||||
try {
|
try {
|
||||||
Type listType = new TypeToken<List<DefaultResponse>>() {}.getType();
|
Type listType = new TypeToken<List<DefaultResponse>>() {
|
||||||
|
}.getType();
|
||||||
return gson.fromJson(data, listType);
|
return gson.fromJson(data, listType);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.debug("TODO: 추후 리팩토링 해야함 이건 예시");
|
log.debug("TODO: 추후 리팩토링 해야함 이건 예시");
|
||||||
@ -106,7 +178,12 @@ public class AiModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Project getProject(Integer projectId) {
|
private Project getProject(Integer projectId) {
|
||||||
return projectRepository.findById(projectId).orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
return projectRepository.findById(projectId)
|
||||||
|
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private AiModel getModel(Integer modelId) {
|
||||||
|
return aiModelRepository.findById(modelId)
|
||||||
|
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -72,15 +72,15 @@ public class ProjectController {
|
|||||||
return projectService.updateProject(memberId, projectId, projectRequest);
|
return projectService.updateProject(memberId, projectId, projectRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.")
|
// @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다.")
|
||||||
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
|
// @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
|
||||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
// @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
@PostMapping("/projects/{project_id}/train")
|
// @PostMapping("/projects/{project_id}/train")
|
||||||
public void trainModel(
|
// public void trainModel(
|
||||||
@CurrentUser final Integer memberId,
|
// @CurrentUser final Integer memberId,
|
||||||
@PathVariable("project_id") final Integer projectId) {
|
// @PathVariable("project_id") final Integer projectId) {
|
||||||
projectService.train(memberId, projectId);
|
// projectService.train(memberId, projectId);
|
||||||
}
|
// }
|
||||||
|
|
||||||
@Operation(summary = "프로젝트 오토 레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.")
|
@Operation(summary = "프로젝트 오토 레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.")
|
||||||
@SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.")
|
@SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.")
|
||||||
|
@ -8,8 +8,14 @@ import java.util.List;
|
|||||||
public class RequestDto {
|
public class RequestDto {
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class TrainDataInfo {
|
public static class TrainDataInfo {
|
||||||
private String imageUrl;
|
private String imagePath;
|
||||||
|
private String dataPath;
|
||||||
|
|
||||||
|
public TrainDataInfo(String imagePath, String dataPath) {
|
||||||
|
this.imagePath = imagePath;
|
||||||
|
this.dataPath = dataPath;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ -17,8 +23,14 @@ public class RequestDto {
|
|||||||
@JsonProperty("project_id")
|
@JsonProperty("project_id")
|
||||||
private int projectId;
|
private int projectId;
|
||||||
|
|
||||||
|
@JsonProperty("category_id")
|
||||||
|
private List<Integer> categoryId;
|
||||||
|
|
||||||
@JsonProperty("data")
|
@JsonProperty("data")
|
||||||
private List<TrainDataInfo> data;
|
private List<TrainDataInfo> data;
|
||||||
|
|
||||||
|
@JsonProperty("model_key")
|
||||||
|
private String modelKey;
|
||||||
// private int seed; // Optional
|
// private int seed; // Optional
|
||||||
// private float ratio; // Default = 0.8
|
// private float ratio; // Default = 0.8
|
||||||
// private int epochs; // Default = 50
|
// private int epochs; // Default = 50
|
||||||
|
@ -134,26 +134,26 @@ public class ProjectService {
|
|||||||
participantRepository.delete(participant);
|
participantRepository.delete(participant);
|
||||||
}
|
}
|
||||||
|
|
||||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
// @CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
public void train(final Integer memberId, final Integer projectId) {
|
// public void train(final Integer memberId, final Integer projectId) {
|
||||||
// TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
|
// // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다. -> 용수 추후 구현 예정
|
||||||
/*
|
// /*
|
||||||
없으면 redis 상태 테이블을 만든다. progressTable
|
// 없으면 redis 상태 테이블을 만든다. progressTable
|
||||||
*/
|
// */
|
||||||
|
//
|
||||||
// FastAPI 서버로 학습 요청을 전송
|
// // FastAPI 서버로 학습 요청을 전송
|
||||||
Project project = getProject(projectId);
|
// Project project = getProject(projectId);
|
||||||
String endPoint = project.getProjectType().getValue() + "/train";
|
// String endPoint = project.getProjectType().getValue() + "/train";
|
||||||
|
//
|
||||||
TrainRequest trainRequest = new TrainRequest();
|
// TrainRequest trainRequest = new TrainRequest();
|
||||||
trainRequest.setProjectId(projectId);
|
// trainRequest.setProjectId(projectId);
|
||||||
trainRequest.setData(List.of());
|
// trainRequest.setData(List.of());
|
||||||
|
//
|
||||||
// FastAPI 서버로 POST 요청 전송
|
// // FastAPI 서버로 POST 요청 전송
|
||||||
String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
// String modelKey = aiService.postRequest(endPoint, trainRequest, String.class, response -> response);
|
||||||
|
//
|
||||||
// TODO: 모델 생성 후 Default 이름과 Key 값 설정
|
// // TODO: 모델 생성 후 Default 이름과 Key 값 설정
|
||||||
}
|
// }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 프로젝트별 오토 레이블링
|
* 프로젝트별 오토 레이블링
|
||||||
|
Loading…
Reference in New Issue
Block a user