Feat: 오토레이블링, 모델 학습 완료 알림 전송

This commit is contained in:
김용수 2024-09-27 14:06:33 +09:00
parent ffe79cc79b
commit ecca6d3876
10 changed files with 54 additions and 19 deletions

View File

@ -3,6 +3,7 @@ package com.worlabel.domain.alarm.service;
import com.worlabel.domain.alarm.entity.Alarm;
import com.worlabel.domain.alarm.entity.Alarm.AlarmType;
import com.worlabel.domain.alarm.repository.AlarmCacheRepository;
import com.worlabel.global.service.FcmService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@ -15,9 +16,11 @@ import java.util.List;
public class AlarmService {
private final AlarmCacheRepository alarmCacheRepository;
private final FcmService fcmService;
public void save(int memberId, AlarmType type) {
alarmCacheRepository.save(memberId, type);
fcmService.send(memberId, type.toString());
}
public List<Alarm> getAlarmList(int memberId){

View File

@ -110,7 +110,7 @@ public class AuthController {
@PostMapping("/test")
public void testSend(@CurrentUser final Integer memberId) {
String token = fcmCacheRepository.getToken(memberId);
fcmService.testSend(token, "test알림입니다.");
fcmService.testSend(token);
}
private static String parseRefreshCookie(HttpServletRequest request) {

View File

@ -64,9 +64,10 @@ public class AiModelController {
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping("/projects/{project_id}/train")
public void trainModel(
@CurrentUser final Integer memberId,
@PathVariable("project_id") final Integer projectId,
@RequestBody final ModelTrainRequest trainRequest) {
log.debug("모델 학습 요청 {}", trainRequest);
aiModelService.train(projectId, trainRequest);
aiModelService.train(memberId, projectId, trainRequest);
}
}

View File

@ -2,6 +2,8 @@ package com.worlabel.domain.model.service;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.worlabel.domain.alarm.entity.Alarm;
import com.worlabel.domain.alarm.service.AlarmService;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.repository.ImageRepository;
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
@ -41,16 +43,16 @@ import java.util.stream.Collectors;
@RequiredArgsConstructor
public class AiModelService {
private final RedisTemplate<String, Object> redisTemplate;
private final AiModelRepository aiModelRepository;
private final ProjectRepository projectRepository;
private final AiRequestService aiRequestService;
private final ImageRepository imageRepository;
private final ResultRepository resultRepository;
private final ProjectService projectService;
private final Gson gson;
private final AiRequestService aiRequestService;
private final ProgressService progressService;
private final ImageRepository imageRepository;
private final ReportService reportService;
private final AlarmService alarmService;
private final Gson gson;
@Transactional(readOnly = true)
public List<AiModelResponse> getModelList(final Integer projectId) {
@ -82,7 +84,7 @@ public class AiModelService {
}
@CheckPrivilege(PrivilegeType.EDITOR)
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) {
// FastAPI 서버로 학습 요청을 전송
Project project = getProject(projectId);
AiModel model = getModel(trainRequest.getModelId());
@ -122,6 +124,8 @@ public class AiModelService {
// 레디스 정보 DB에 저장
reportService.changeReport(project.getId(), model.getId(), newModel);
alarmService.save(memberId, Alarm.AlarmType.TRAIN);
}
private TrainResponse converterTrain(String data) {

View File

@ -19,7 +19,7 @@ public class ProgressService {
public void predictProgressCheck(final int projectId) {
if (progressCacheRepository.predictProgressCheck(projectId)) {
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
throw new CustomException(ErrorCode.AI_IN_PROGRESS, "해당 프로젝트 오토레이블링 진행 중");
}
}

View File

@ -74,9 +74,10 @@ public class ProjectController {
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@PostMapping("/projects/{project_id}/auto")
public void autoLabeling(
@CurrentUser final Integer memberId,
@PathVariable("project_id") final Integer projectId,
@RequestBody final AutoModelRequest request) {
projectService.autoLabeling(projectId, request);
projectService.autoLabeling(memberId, projectId, request);
}
@Operation(summary = "프로젝트 삭제", description = "프로젝트를 삭제합니다.")

View File

@ -125,10 +125,10 @@ public class AiDto {
}
}
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@Getter
@ToString
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@AllArgsConstructor(access = AccessLevel.PRIVATE)
public static class AutoLabelingResult {
@SerializedName("image_id")
@ -137,4 +137,5 @@ public class AiDto {
@SerializedName("data")
private String data;
}
}

View File

@ -26,6 +26,6 @@ public class ProjectWithCategoryRequest {
private List<String> categories;
@Schema(description = "프로젝트 유형", example = "classification")
@NotNull(message = "카테고리를 입력하세요.")
@NotNull(message = "프로젝트 타입을 선택해주세요.")
private ProjectType projectType;
}

View File

@ -3,6 +3,8 @@ package com.worlabel.domain.project.service;
import com.google.gson.Gson;
import com.google.gson.JsonSyntaxException;
import com.google.gson.reflect.TypeToken;
import com.worlabel.domain.alarm.entity.Alarm;
import com.worlabel.domain.alarm.service.AlarmService;
import com.worlabel.domain.image.entity.Image;
import com.worlabel.domain.image.entity.LabelStatus;
import com.worlabel.domain.image.repository.ImageRepository;
@ -32,6 +34,7 @@ import com.worlabel.global.annotation.CheckPrivilege;
import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode;
import com.worlabel.global.service.AiRequestService;
import com.worlabel.global.service.FcmService;
import com.worlabel.global.service.S3UploadService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -62,6 +65,8 @@ public class ProjectService {
private final AiRequestService aiService;
private final Gson gson;
private final FcmService fcmService;
private final AlarmService alarmService;
@Transactional
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectWithCategoryRequest projectRequest) {
@ -163,7 +168,9 @@ public class ProjectService {
* 프로젝트별 오토 레이블링
*/
@CheckPrivilege(PrivilegeType.EDITOR)
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
progressService.predictProgressCheck(projectId);
Project project = getProject(projectId);
String endPoint = project.getProjectType().getValue() + "/predict";
@ -182,8 +189,13 @@ public class ProjectService {
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
log.debug("요청 {}", autoLabelingRequest);
// progressService.registerPredictProgress(projectId);
progressService.registerPredictProgress(projectId);
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
log.debug("완료 후 삭제:{}", list);
alarmService.save(memberId, Alarm.AlarmType.PREDICT);
progressService.removePredictProgress(projectId);
saveAutoLabelList(list);
}

View File

@ -2,25 +2,38 @@ package com.worlabel.global.service;
import com.google.firebase.messaging.FirebaseMessaging;
import com.google.firebase.messaging.Message;
import com.worlabel.domain.alarm.entity.Alarm;
import com.worlabel.domain.alarm.entity.Alarm.AlarmType;
import com.worlabel.domain.alarm.service.AlarmService;
import com.worlabel.domain.auth.repository.FcmCacheRepository;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Objects;
@Slf4j
@Service
@RequiredArgsConstructor
public class FcmService {
private final FirebaseMessaging firebaseMessaging;
private final FcmCacheRepository fcmCacheRepository;
public void testSend(String targetToken, String message){
sendNotification(targetToken, "testTitle", "testBody");
public void testSend(String targetToken){
sendNotification(targetToken, "testBody");
}
private void sendNotification(String targetToken, String title, String body){
public void send(Integer memberId, String data) {
String token = fcmCacheRepository.getToken(memberId);
if(Objects.nonNull(token)){
sendNotification(token, data);
}
}
private void sendNotification(String targetToken, String body){
Message message = Message.builder()
.setToken(targetToken)
.putData("title",title)
.putData("body",body)
.build();
try {