Feat: AI Model Download API 구현

This commit is contained in:
김용수 2024-09-30 00:16:51 +09:00
parent f4c26ec6c8
commit 43c36919b2
3 changed files with 93 additions and 54 deletions

View File

@ -15,6 +15,8 @@ import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@ -70,4 +72,15 @@ public class AiModelController {
log.debug("모델 학습 요청 {}", trainRequest);
aiModelService.train(memberId, projectId, trainRequest);
}
@Operation(summary = "프로젝트 모델 다운로드", description = "프로젝트 모델을 다운로드합니다.")
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 다운로드됩니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping("/projects/{project_id}/models/{model_id}/download")
public ResponseEntity<Resource> modelDownload(
@PathVariable("project_id") final Integer projectId,
@PathVariable("model_id") final Integer modelId) {
log.debug("다운로드 요청 projectId : {} modelId : {}", projectId, modelId);
return aiModelService.modelDownload(projectId, modelId);
}
}

View File

@ -26,7 +26,9 @@ import com.worlabel.global.exception.ErrorCode;
import com.worlabel.global.service.AiRequestService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.Resource;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@ -86,28 +88,18 @@ public class AiModelService {
@CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) {
// FastAPI 서버로 학습 요청을 전송
progressService.trainProgressCheck(projectId, trainRequest.getModelId());
try {
progressService.registerTrainProgress(projectId, trainRequest.getModelId());
Project project = getProject(projectId);
AiModel model = getModel(trainRequest.getModelId());
Map<String, Integer> labelMap = project.getCategoryList().stream()
.collect(Collectors.toMap(
ProjectCategory::getLabelName,
ProjectCategory::getId
));
List<Image> images = imageRepository.findImagesByProjectIdAndCompleted(projectId);
List<TrainDataInfo> data = images.stream()
.map(TrainDataInfo::of)
.toList();
TrainRequest aiRequest = TrainRequest.of(project.getId(), model.getId(), model.getModelKey(), labelMap, data, trainRequest);
String endPoint = project.getProjectType().getValue() + "/train";
TrainRequest aiRequest = getTrainRequest(trainRequest, project, model);
// FastAPI 서버로 POST 요청 전송
log.debug("요청 DTO :{}", aiRequest);
String endPoint = project.getProjectType().getValue() + "/train";
TrainResponse trainResponse = aiRequestService.postRequest(endPoint, aiRequest, TrainResponse.class, this::converterTrain);
// 가져온 modelKey -> version 업된 모델 다시 새롭게 저장
@ -115,18 +107,66 @@ public class AiModelService {
int newVersion = model.getVersion() + 1;
String newName = currentDateTime + String.format("%03d", newVersion);
// 새로운 모델 저장
AiModel newModel = AiModel.of(newName, trainResponse.getModelKey(), newVersion, project);
aiModelRepository.save(newModel);
// 결과 저장
Result result = Result.of(newModel, trainResponse, trainRequest);
resultRepository.save(result);
// 레디스 정보 DB에 저장
reportService.changeReport(project.getId(), model.getId(), newModel);
// 알람 전송
alarmService.save(memberId, Alarm.AlarmType.TRAIN);
} finally {
progressService.removeTrainProgress(projectId, trainRequest.getModelId());
}
}
@CheckPrivilege(PrivilegeType.EDITOR)
public ResponseEntity<Resource> modelDownload(final Integer projectId,final Integer modelId) {
AiModel model = getModel(modelId);
String modelKey = model.getModelKey();
String endPoint = "/models/download";
endPoint += "?modelKey=" + modelKey;
ResponseEntity<Resource> fileRequest = aiRequestService.getFileRequest(endPoint);
}
public TrainRequest getTrainRequest(final ModelTrainRequest trainRequest, final Project project, final AiModel model) {
Map<String, Integer> labelMap = getLabelMap(project);
List<TrainDataInfo> data = getTrainDataInfoList(project.getId());
return TrainRequest.of(project.getId(), model.getId(), model.getModelKey(), labelMap, data, trainRequest);
}
// 레이블 만들기
private Map<String, Integer> getLabelMap(final Project project) {
return project.getCategoryList().stream()
.collect(Collectors.toMap(
ProjectCategory::getLabelName,
ProjectCategory::getId
));
}
@Transactional(readOnly = true)
public List<TrainDataInfo> getTrainDataInfoList(final Integer projectId) {
return imageRepository.findImagesByProjectIdAndCompleted(projectId)
.stream()
.map(TrainDataInfo::of)
.toList();
}
private Project getProject(Integer projectId) {
return projectRepository.findById(projectId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private AiModel getModel(Integer modelId) {
return aiModelRepository.findById(modelId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private TrainResponse converterTrain(String data) {
@ -139,27 +179,4 @@ public class AiModelService {
}
}
/**
* Json -> List<DefaultResponse>
*/
// TODO: 추후 리팩토링 해야함 이건 예시
private List<DefaultResponse> converter(String data) {
try {
Type listType = new TypeToken<List<DefaultResponse>>() {
}.getType();
return gson.fromJson(data, listType);
} catch (Exception e) {
throw new CustomException(ErrorCode.BAD_REQUEST);
}
}
private Project getProject(Integer projectId) {
return projectRepository.findById(projectId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
private AiModel getModel(Integer modelId) {
return aiModelRepository.findById(modelId)
.orElseThrow(() -> new CustomException(ErrorCode.DATA_NOT_FOUND));
}
}

View File

@ -5,6 +5,7 @@ import com.worlabel.global.exception.ErrorCode;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@ -52,6 +53,14 @@ public class AiRequestService {
return converter.apply(data);
}
public ResponseEntity<Resource> getFileRequest(String endPoint) {
String url = createApiUrl(endPoint);
HttpEntity<Void> request = new HttpEntity<>(createJsonHeaders());
ResponseEntity<Resource> exchange = restTemplate.exchange(url, HttpMethod.GET, request, Resource.class);
return exchange;
}
// 응답이 없는 요청인 경우 : 오토 레이블링 요청
private <T> void sendVoidRequest(String url, HttpEntity<T> requestEntity) {
try {