Merge branch 'be/feat/Modeldownload' into 'be/develop'
Feat: AI Model Download API 구현 See merge request s11-s-project/S11P21S002!237
This commit is contained in:
commit
81e9f81d8b
@ -32,10 +32,11 @@ public class ImageController {
|
|||||||
@Operation(summary = "이미지 목록 업로드", description = "이미지 목록을 업로드합니다.")
|
@Operation(summary = "이미지 목록 업로드", description = "이미지 목록을 업로드합니다.")
|
||||||
@SwaggerApiError({ErrorCode.BAD_REQUEST, ErrorCode.NOT_AUTHOR, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.BAD_REQUEST, ErrorCode.NOT_AUTHOR, ErrorCode.SERVER_ERROR})
|
||||||
public void uploadImage(
|
public void uploadImage(
|
||||||
|
@CurrentUser final Integer memberId,
|
||||||
@PathVariable("project_id") final Integer projectId,
|
@PathVariable("project_id") final Integer projectId,
|
||||||
@PathVariable("folder_id") final Integer folderId,
|
@PathVariable("folder_id") final Integer folderId,
|
||||||
@Parameter(name = "폴더에 추가 할 이미지 리스트", description = "MultiPartFile을 imageList로 추가해준다.", example = "") @RequestPart final List<MultipartFile> imageList) {
|
@Parameter(name = "폴더에 추가 할 이미지 리스트", description = "MultiPartFile을 imageList로 추가해준다.", example = "") @RequestPart final List<MultipartFile> imageList) {
|
||||||
imageService.uploadImageList(imageList, folderId, projectId);
|
imageService.uploadImageList(imageList, folderId, projectId, memberId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostMapping("/folders/{folder_id}/images/zip")
|
@PostMapping("/folders/{folder_id}/images/zip")
|
||||||
@ -43,10 +44,11 @@ public class ImageController {
|
|||||||
@Operation(summary = "압축 폴더 업로드", description = "압축 폴더 내 폴더와 이미지 파일을 업로드합니다.")
|
@Operation(summary = "압축 폴더 업로드", description = "압축 폴더 내 폴더와 이미지 파일을 업로드합니다.")
|
||||||
@SwaggerApiError({ErrorCode.BAD_REQUEST, ErrorCode.NOT_AUTHOR, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.BAD_REQUEST, ErrorCode.NOT_AUTHOR, ErrorCode.SERVER_ERROR})
|
||||||
public void uploadFolder(
|
public void uploadFolder(
|
||||||
|
@CurrentUser final Integer memberId,
|
||||||
@Parameter(name = "압축 폴더", description = "압축 폴더를 추가해준다.", example = "") @RequestPart final MultipartFile folderZip,
|
@Parameter(name = "압축 폴더", description = "압축 폴더를 추가해준다.", example = "") @RequestPart final MultipartFile folderZip,
|
||||||
@PathVariable("project_id") final Integer projectId,
|
@PathVariable("project_id") final Integer projectId,
|
||||||
@PathVariable("folder_id") final Integer folderId) throws IOException {
|
@PathVariable("folder_id") final Integer folderId) throws IOException {
|
||||||
imageService.uploadFolderWithImages(folderZip, projectId, folderId);
|
imageService.uploadFolderWithImages(folderZip, projectId, folderId, memberId);
|
||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/folders/{folder_id}/images/{image_id}")
|
@GetMapping("/folders/{folder_id}/images/{image_id}")
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
package com.worlabel.domain.image.service;
|
package com.worlabel.domain.image.service;
|
||||||
|
|
||||||
|
import com.worlabel.domain.alarm.entity.Alarm;
|
||||||
|
import com.worlabel.domain.alarm.service.AlarmService;
|
||||||
import com.worlabel.domain.folder.entity.Folder;
|
import com.worlabel.domain.folder.entity.Folder;
|
||||||
import com.worlabel.domain.folder.repository.FolderRepository;
|
import com.worlabel.domain.folder.repository.FolderRepository;
|
||||||
import com.worlabel.domain.image.entity.Image;
|
import com.worlabel.domain.image.entity.Image;
|
||||||
@ -12,6 +14,7 @@ import com.worlabel.domain.project.repository.ProjectRepository;
|
|||||||
import com.worlabel.global.annotation.CheckPrivilege;
|
import com.worlabel.global.annotation.CheckPrivilege;
|
||||||
import com.worlabel.global.exception.CustomException;
|
import com.worlabel.global.exception.CustomException;
|
||||||
import com.worlabel.global.exception.ErrorCode;
|
import com.worlabel.global.exception.ErrorCode;
|
||||||
|
import com.worlabel.global.service.FcmService;
|
||||||
import com.worlabel.global.service.S3UploadService;
|
import com.worlabel.global.service.S3UploadService;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@ -48,12 +51,13 @@ public class ImageService {
|
|||||||
private final FolderRepository folderRepository;
|
private final FolderRepository folderRepository;
|
||||||
private final S3UploadService s3UploadService;
|
private final S3UploadService s3UploadService;
|
||||||
private final ImageRepository imageRepository;
|
private final ImageRepository imageRepository;
|
||||||
|
private final AlarmService alarmService;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 이미지 리스트 업로드
|
* 이미지 리스트 업로드
|
||||||
*/
|
*/
|
||||||
@CheckPrivilege(value = PrivilegeType.EDITOR)
|
@CheckPrivilege(value = PrivilegeType.EDITOR)
|
||||||
public void uploadImageList(final List<MultipartFile> imageList, final Integer folderId, final Integer projectId) {
|
public void uploadImageList(final List<MultipartFile> imageList, final Integer folderId, final Integer projectId, final Integer memberId) {
|
||||||
Folder folder = getOrCreateFolder(folderId, projectId);
|
Folder folder = getOrCreateFolder(folderId, projectId);
|
||||||
|
|
||||||
log.debug("folder Id {}, Project Id {}", folder.getId(), folder.getProject().getId());
|
log.debug("folder Id {}, Project Id {}", folder.getId(), folder.getProject().getId());
|
||||||
@ -83,6 +87,7 @@ public class ImageService {
|
|||||||
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
|
CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
|
||||||
|
|
||||||
long after = System.currentTimeMillis();
|
long after = System.currentTimeMillis();
|
||||||
|
alarmService.save(memberId, Alarm.AlarmType.IMAGE);
|
||||||
log.debug("업로드 완료 - 경과시간 {}", ((double) after - prev) / 1000);
|
log.debug("업로드 완료 - 경과시간 {}", ((double) after - prev) / 1000);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -90,7 +95,7 @@ public class ImageService {
|
|||||||
* Zip 파일 처리 메서드
|
* Zip 파일 처리 메서드
|
||||||
*/
|
*/
|
||||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
public void uploadFolderWithImages(final MultipartFile zipFile, final Integer projectId, final Integer folderId) throws IOException {
|
public void uploadFolderWithImages(final MultipartFile zipFile, final Integer projectId, final Integer folderId, final Integer memberId) throws IOException {
|
||||||
log.debug("파일 크기: {}, 기존 파일 이름: {} ", zipFile.getSize(), zipFile.getOriginalFilename());
|
log.debug("파일 크기: {}, 기존 파일 이름: {} ", zipFile.getSize(), zipFile.getOriginalFilename());
|
||||||
|
|
||||||
Path tmpDir = null;
|
Path tmpDir = null;
|
||||||
@ -109,6 +114,8 @@ public class ImageService {
|
|||||||
|
|
||||||
unzip(zipFile, tmpDir.toString());
|
unzip(zipFile, tmpDir.toString());
|
||||||
processFolderRecursively(tmpDir.toFile(), rootFolder, project);
|
processFolderRecursively(tmpDir.toFile(), rootFolder, project);
|
||||||
|
|
||||||
|
alarmService.save(memberId, Alarm.AlarmType.IMAGE);
|
||||||
} finally {
|
} finally {
|
||||||
if (tmpDir != null) {
|
if (tmpDir != null) {
|
||||||
deleteDirectoryRecursively(tmpDir);
|
deleteDirectoryRecursively(tmpDir);
|
||||||
|
@ -15,6 +15,8 @@ import io.swagger.v3.oas.annotations.tags.Tag;
|
|||||||
import jakarta.validation.Valid;
|
import jakarta.validation.Valid;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@ -70,4 +72,15 @@ public class AiModelController {
|
|||||||
log.debug("모델 학습 요청 {}", trainRequest);
|
log.debug("모델 학습 요청 {}", trainRequest);
|
||||||
aiModelService.train(memberId, projectId, trainRequest);
|
aiModelService.train(memberId, projectId, trainRequest);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "프로젝트 모델 다운로드", description = "프로젝트 모델을 다운로드합니다.")
|
||||||
|
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 다운로드됩니다.")
|
||||||
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
|
@PostMapping("/projects/{project_id}/models/{model_id}/download")
|
||||||
|
public ResponseEntity<Resource> modelDownload(
|
||||||
|
@PathVariable("project_id") final Integer projectId,
|
||||||
|
@PathVariable("model_id") final Integer modelId) {
|
||||||
|
log.debug("다운로드 요청 projectId : {} modelId : {}", projectId, modelId);
|
||||||
|
return aiModelService.modelDownload(projectId, modelId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,9 @@ import com.worlabel.global.exception.ErrorCode;
|
|||||||
import com.worlabel.global.service.AiRequestService;
|
import com.worlabel.global.service.AiRequestService;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
import org.springframework.data.redis.core.RedisTemplate;
|
import org.springframework.data.redis.core.RedisTemplate;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
|
||||||
@ -86,28 +88,18 @@ public class AiModelService {
|
|||||||
|
|
||||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) {
|
public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||||
// FastAPI 서버로 학습 요청을 전송
|
progressService.trainProgressCheck(projectId, trainRequest.getModelId());
|
||||||
|
|
||||||
|
try {
|
||||||
|
progressService.registerTrainProgress(projectId, trainRequest.getModelId());
|
||||||
|
|
||||||
Project project = getProject(projectId);
|
Project project = getProject(projectId);
|
||||||
AiModel model = getModel(trainRequest.getModelId());
|
AiModel model = getModel(trainRequest.getModelId());
|
||||||
|
TrainRequest aiRequest = getTrainRequest(trainRequest, project, model);
|
||||||
Map<String, Integer> labelMap = project.getCategoryList().stream()
|
|
||||||
.collect(Collectors.toMap(
|
|
||||||
ProjectCategory::getLabelName,
|
|
||||||
ProjectCategory::getId
|
|
||||||
));
|
|
||||||
|
|
||||||
List<Image> images = imageRepository.findImagesByProjectIdAndCompleted(projectId);
|
|
||||||
|
|
||||||
List<TrainDataInfo> data = images.stream()
|
|
||||||
.map(TrainDataInfo::of)
|
|
||||||
.toList();
|
|
||||||
|
|
||||||
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getId(), model.getModelKey(), labelMap, data, trainRequest);
|
|
||||||
|
|
||||||
String endPoint = project.getProjectType().getValue() + "/train";
|
|
||||||
|
|
||||||
// FastAPI 서버로 POST 요청 전송
|
// FastAPI 서버로 POST 요청 전송
|
||||||
log.debug("요청 DTO :{}", aiRequest);
|
log.debug("요청 DTO :{}", aiRequest);
|
||||||
|
String endPoint = project.getProjectType().getValue() + "/train";
|
||||||
TrainResponse trainResponse = aiRequestService.postRequest(endPoint, aiRequest, TrainResponse.class, this::converterTrain);
|
TrainResponse trainResponse = aiRequestService.postRequest(endPoint, aiRequest, TrainResponse.class, this::converterTrain);
|
||||||
|
|
||||||
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
|
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
|
||||||
@ -115,18 +107,66 @@ public class AiModelService {
|
|||||||
int newVersion = model.getVersion() + 1;
|
int newVersion = model.getVersion() + 1;
|
||||||
String newName = currentDateTime + String.format("%03d", newVersion);
|
String newName = currentDateTime + String.format("%03d", newVersion);
|
||||||
|
|
||||||
|
// 새로운 모델 저장
|
||||||
AiModel newModel = AiModel.of(newName, trainResponse.getModelKey(), newVersion, project);
|
AiModel newModel = AiModel.of(newName, trainResponse.getModelKey(), newVersion, project);
|
||||||
|
|
||||||
aiModelRepository.save(newModel);
|
aiModelRepository.save(newModel);
|
||||||
|
|
||||||
|
// 결과 저장
|
||||||
Result result = Result.of(newModel, trainResponse, trainRequest);
|
Result result = Result.of(newModel, trainResponse, trainRequest);
|
||||||
|
|
||||||
resultRepository.save(result);
|
resultRepository.save(result);
|
||||||
|
|
||||||
// 레디스 정보 DB에 저장
|
// 레디스 정보 DB에 저장
|
||||||
reportService.changeReport(project.getId(), model.getId(), newModel);
|
reportService.changeReport(project.getId(), model.getId(), newModel);
|
||||||
|
|
||||||
|
// 알람 전송
|
||||||
alarmService.save(memberId, Alarm.AlarmType.TRAIN);
|
alarmService.save(memberId, Alarm.AlarmType.TRAIN);
|
||||||
|
} finally {
|
||||||
|
progressService.removeTrainProgress(projectId, trainRequest.getModelId());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||||
|
public ResponseEntity<Resource> modelDownload(final Integer projectId,final Integer modelId) {
|
||||||
|
AiModel model = getModel(modelId);
|
||||||
|
String modelKey = model.getModelKey();
|
||||||
|
|
||||||
|
String endPoint = "/models/download";
|
||||||
|
endPoint += "?modelKey=" + modelKey;
|
||||||
|
|
||||||
|
ResponseEntity<Resource> fileRequest = aiRequestService.getFileRequest(endPoint);
|
||||||
|
}
|
||||||
|
|
||||||
|
public TrainRequest getTrainRequest(final ModelTrainRequest trainRequest, final Project project, final AiModel model) {
|
||||||
|
Map<String, Integer> labelMap = getLabelMap(project);
|
||||||
|
List<TrainDataInfo> data = getTrainDataInfoList(project.getId());
|
||||||
|
return TrainRequest.of(project.getId(), model.getId(), model.getModelKey(), labelMap, data, trainRequest);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 레이블 맵 만들기
|
||||||
|
private Map<String, Integer> getLabelMap(final Project project) {
|
||||||
|
return project.getCategoryList().stream()
|
||||||
|
.collect(Collectors.toMap(
|
||||||
|
ProjectCategory::getLabelName,
|
||||||
|
ProjectCategory::getId
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Transactional(readOnly = true)
|
||||||
|
public List<TrainDataInfo> getTrainDataInfoList(final Integer projectId) {
|
||||||
|
return imageRepository.findImagesByProjectIdAndCompleted(projectId)
|
||||||
|
.stream()
|
||||||
|
.map(TrainDataInfo::of)
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
|
private Project getProject(Integer projectId) {
|
||||||
|
return projectRepository.findById(projectId)
|
||||||
|
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||||
|
}
|
||||||
|
|
||||||
|
private AiModel getModel(Integer modelId) {
|
||||||
|
return aiModelRepository.findById(modelId)
|
||||||
|
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
||||||
}
|
}
|
||||||
|
|
||||||
private TrainResponse converterTrain(String data) {
|
private TrainResponse converterTrain(String data) {
|
||||||
@ -139,27 +179,4 @@ public class AiModelService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Json -> List<DefaultResponse>
|
|
||||||
*/
|
|
||||||
// TODO: 추후 리팩토링 해야함 이건 예시
|
|
||||||
private List<DefaultResponse> converter(String data) {
|
|
||||||
try {
|
|
||||||
Type listType = new TypeToken<List<DefaultResponse>>() {
|
|
||||||
}.getType();
|
|
||||||
return gson.fromJson(data, listType);
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new CustomException(ErrorCode.BAD_REQUEST);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private Project getProject(Integer projectId) {
|
|
||||||
return projectRepository.findById(projectId)
|
|
||||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
|
||||||
}
|
|
||||||
|
|
||||||
private AiModel getModel(Integer modelId) {
|
|
||||||
return aiModelRepository.findById(modelId)
|
|
||||||
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -171,6 +171,8 @@ public class ProjectService {
|
|||||||
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
|
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
|
||||||
progressService.predictProgressCheck(projectId);
|
progressService.predictProgressCheck(projectId);
|
||||||
|
|
||||||
|
try {
|
||||||
|
progressService.registerPredictProgress(projectId);
|
||||||
Project project = getProject(projectId);
|
Project project = getProject(projectId);
|
||||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||||
|
|
||||||
@ -189,17 +191,17 @@ public class ProjectService {
|
|||||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
|
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
|
||||||
|
|
||||||
log.debug("요청 {}", autoLabelingRequest);
|
log.debug("요청 {}", autoLabelingRequest);
|
||||||
progressService.registerPredictProgress(projectId);
|
|
||||||
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
|
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
|
||||||
log.debug("완료 후 삭제:{}", list);
|
saveAutoLabelList(list);
|
||||||
|
|
||||||
alarmService.save(memberId, Alarm.AlarmType.PREDICT);
|
alarmService.save(memberId, Alarm.AlarmType.PREDICT);
|
||||||
|
} finally {
|
||||||
progressService.removePredictProgress(projectId);
|
progressService.removePredictProgress(projectId);
|
||||||
|
|
||||||
saveAutoLabelList(list);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: 트랜잭션 설정
|
}
|
||||||
|
|
||||||
@Transactional
|
@Transactional
|
||||||
public void saveAutoLabelList(final List<AutoLabelingResult> resultList) {
|
public void saveAutoLabelList(final List<AutoLabelingResult> resultList) {
|
||||||
for (AutoLabelingResult result : resultList) {
|
for (AutoLabelingResult result : resultList) {
|
||||||
|
@ -5,6 +5,7 @@ import com.worlabel.global.exception.ErrorCode;
|
|||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.core.io.Resource;
|
||||||
import org.springframework.http.HttpEntity;
|
import org.springframework.http.HttpEntity;
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
@ -52,6 +53,14 @@ public class AiRequestService {
|
|||||||
return converter.apply(data);
|
return converter.apply(data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ResponseEntity<Resource> getFileRequest(String endPoint) {
|
||||||
|
String url = createApiUrl(endPoint);
|
||||||
|
HttpEntity<Void> request = new HttpEntity<>(createJsonHeaders());
|
||||||
|
|
||||||
|
ResponseEntity<Resource> exchange = restTemplate.exchange(url, HttpMethod.GET, request, Resource.class);
|
||||||
|
return exchange;
|
||||||
|
}
|
||||||
|
|
||||||
// 응답이 없는 요청인 경우 예 : 오토 레이블링 요청
|
// 응답이 없는 요청인 경우 예 : 오토 레이블링 요청
|
||||||
private <T> void sendVoidRequest(String url, HttpEntity<T> requestEntity) {
|
private <T> void sendVoidRequest(String url, HttpEntity<T> requestEntity) {
|
||||||
try {
|
try {
|
||||||
|
Loading…
Reference in New Issue
Block a user