Feat: 오토레이블링, 모델 학습 완료 알림 전송
This commit is contained in:
parent
ffe79cc79b
commit
ecca6d3876
@ -3,6 +3,7 @@ package com.worlabel.domain.alarm.service;
|
||||
import com.worlabel.domain.alarm.entity.Alarm;
|
||||
import com.worlabel.domain.alarm.entity.Alarm.AlarmType;
|
||||
import com.worlabel.domain.alarm.repository.AlarmCacheRepository;
|
||||
import com.worlabel.global.service.FcmService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
@ -15,9 +16,11 @@ import java.util.List;
|
||||
public class AlarmService {
|
||||
|
||||
private final AlarmCacheRepository alarmCacheRepository;
|
||||
private final FcmService fcmService;
|
||||
|
||||
public void save(int memberId, AlarmType type) {
|
||||
alarmCacheRepository.save(memberId, type);
|
||||
fcmService.send(memberId, type.toString());
|
||||
}
|
||||
|
||||
public List<Alarm> getAlarmList(int memberId){
|
||||
|
@ -110,7 +110,7 @@ public class AuthController {
|
||||
@PostMapping("/test")
|
||||
public void testSend(@CurrentUser final Integer memberId) {
|
||||
String token = fcmCacheRepository.getToken(memberId);
|
||||
fcmService.testSend(token, "test알림입니다.");
|
||||
fcmService.testSend(token);
|
||||
}
|
||||
|
||||
private static String parseRefreshCookie(HttpServletRequest request) {
|
||||
|
@ -64,9 +64,10 @@ public class AiModelController {
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@PostMapping("/projects/{project_id}/train")
|
||||
public void trainModel(
|
||||
@CurrentUser final Integer memberId,
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@RequestBody final ModelTrainRequest trainRequest) {
|
||||
log.debug("모델 학습 요청 {}", trainRequest);
|
||||
aiModelService.train(projectId, trainRequest);
|
||||
aiModelService.train(memberId, projectId, trainRequest);
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,8 @@ package com.worlabel.domain.model.service;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import com.worlabel.domain.alarm.entity.Alarm;
|
||||
import com.worlabel.domain.alarm.service.AlarmService;
|
||||
import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.image.repository.ImageRepository;
|
||||
import com.worlabel.domain.labelcategory.entity.ProjectCategory;
|
||||
@ -41,16 +43,16 @@ import java.util.stream.Collectors;
|
||||
@RequiredArgsConstructor
|
||||
public class AiModelService {
|
||||
|
||||
private final RedisTemplate<String, Object> redisTemplate;
|
||||
private final AiModelRepository aiModelRepository;
|
||||
private final ProjectRepository projectRepository;
|
||||
private final AiRequestService aiRequestService;
|
||||
private final ImageRepository imageRepository;
|
||||
private final ResultRepository resultRepository;
|
||||
private final ProjectService projectService;
|
||||
private final Gson gson;
|
||||
private final AiRequestService aiRequestService;
|
||||
private final ProgressService progressService;
|
||||
private final ImageRepository imageRepository;
|
||||
private final ReportService reportService;
|
||||
private final AlarmService alarmService;
|
||||
|
||||
private final Gson gson;
|
||||
|
||||
@Transactional(readOnly = true)
|
||||
public List<AiModelResponse> getModelList(final Integer projectId) {
|
||||
@ -82,7 +84,7 @@ public class AiModelService {
|
||||
}
|
||||
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void train(final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||
public void train(final Integer memberId, final Integer projectId, final ModelTrainRequest trainRequest) {
|
||||
// FastAPI 서버로 학습 요청을 전송
|
||||
Project project = getProject(projectId);
|
||||
AiModel model = getModel(trainRequest.getModelId());
|
||||
@ -122,6 +124,8 @@ public class AiModelService {
|
||||
|
||||
// 레디스 정보 DB에 저장
|
||||
reportService.changeReport(project.getId(), model.getId(), newModel);
|
||||
|
||||
alarmService.save(memberId, Alarm.AlarmType.TRAIN);
|
||||
}
|
||||
|
||||
private TrainResponse converterTrain(String data) {
|
||||
|
@ -19,7 +19,7 @@ public class ProgressService {
|
||||
|
||||
public void predictProgressCheck(final int projectId) {
|
||||
if (progressCacheRepository.predictProgressCheck(projectId)) {
|
||||
throw new CustomException(ErrorCode.AI_IN_PROGRESS);
|
||||
throw new CustomException(ErrorCode.AI_IN_PROGRESS, "해당 프로젝트 오토레이블링 진행 중");
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,9 +74,10 @@ public class ProjectController {
|
||||
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||
@PostMapping("/projects/{project_id}/auto")
|
||||
public void autoLabeling(
|
||||
@CurrentUser final Integer memberId,
|
||||
@PathVariable("project_id") final Integer projectId,
|
||||
@RequestBody final AutoModelRequest request) {
|
||||
projectService.autoLabeling(projectId, request);
|
||||
projectService.autoLabeling(memberId, projectId, request);
|
||||
}
|
||||
|
||||
@Operation(summary = "프로젝트 삭제", description = "프로젝트를 삭제합니다.")
|
||||
|
@ -125,10 +125,10 @@ public class AiDto {
|
||||
}
|
||||
}
|
||||
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@Getter
|
||||
@ToString
|
||||
@NoArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
||||
public static class AutoLabelingResult {
|
||||
|
||||
@SerializedName("image_id")
|
||||
@ -137,4 +137,5 @@ public class AiDto {
|
||||
@SerializedName("data")
|
||||
private String data;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -26,6 +26,6 @@ public class ProjectWithCategoryRequest {
|
||||
private List<String> categories;
|
||||
|
||||
@Schema(description = "프로젝트 유형", example = "classification")
|
||||
@NotNull(message = "카테고리를 입력하세요.")
|
||||
@NotNull(message = "프로젝트 타입을 선택해주세요.")
|
||||
private ProjectType projectType;
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ package com.worlabel.domain.project.service;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonSyntaxException;
|
||||
import com.google.gson.reflect.TypeToken;
|
||||
import com.worlabel.domain.alarm.entity.Alarm;
|
||||
import com.worlabel.domain.alarm.service.AlarmService;
|
||||
import com.worlabel.domain.image.entity.Image;
|
||||
import com.worlabel.domain.image.entity.LabelStatus;
|
||||
import com.worlabel.domain.image.repository.ImageRepository;
|
||||
@ -32,6 +34,7 @@ import com.worlabel.global.annotation.CheckPrivilege;
|
||||
import com.worlabel.global.exception.CustomException;
|
||||
import com.worlabel.global.exception.ErrorCode;
|
||||
import com.worlabel.global.service.AiRequestService;
|
||||
import com.worlabel.global.service.FcmService;
|
||||
import com.worlabel.global.service.S3UploadService;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@ -62,6 +65,8 @@ public class ProjectService {
|
||||
private final AiRequestService aiService;
|
||||
|
||||
private final Gson gson;
|
||||
private final FcmService fcmService;
|
||||
private final AlarmService alarmService;
|
||||
|
||||
@Transactional
|
||||
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectWithCategoryRequest projectRequest) {
|
||||
@ -163,7 +168,9 @@ public class ProjectService {
|
||||
* 프로젝트별 오토 레이블링
|
||||
*/
|
||||
@CheckPrivilege(PrivilegeType.EDITOR)
|
||||
public void autoLabeling(final Integer projectId, final AutoModelRequest request) {
|
||||
public void autoLabeling(final Integer memberId, final Integer projectId, final AutoModelRequest request) {
|
||||
progressService.predictProgressCheck(projectId);
|
||||
|
||||
Project project = getProject(projectId);
|
||||
String endPoint = project.getProjectType().getValue() + "/predict";
|
||||
|
||||
@ -182,8 +189,13 @@ public class ProjectService {
|
||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, aiModel.getModelKey(), labelMap, imageRequestList);
|
||||
|
||||
log.debug("요청 {}", autoLabelingRequest);
|
||||
// progressService.registerPredictProgress(projectId);
|
||||
progressService.registerPredictProgress(projectId);
|
||||
List<AutoLabelingResult> list = aiService.postRequest(endPoint, autoLabelingRequest, List.class, this::converter);
|
||||
log.debug("완료 후 삭제:{}", list);
|
||||
|
||||
alarmService.save(memberId, Alarm.AlarmType.PREDICT);
|
||||
progressService.removePredictProgress(projectId);
|
||||
|
||||
saveAutoLabelList(list);
|
||||
}
|
||||
|
||||
|
@ -2,25 +2,38 @@ package com.worlabel.global.service;
|
||||
|
||||
import com.google.firebase.messaging.FirebaseMessaging;
|
||||
import com.google.firebase.messaging.Message;
|
||||
import com.worlabel.domain.alarm.entity.Alarm;
|
||||
import com.worlabel.domain.alarm.entity.Alarm.AlarmType;
|
||||
import com.worlabel.domain.alarm.service.AlarmService;
|
||||
import com.worlabel.domain.auth.repository.FcmCacheRepository;
|
||||
import lombok.RequiredArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@RequiredArgsConstructor
|
||||
public class FcmService {
|
||||
|
||||
private final FirebaseMessaging firebaseMessaging;
|
||||
private final FcmCacheRepository fcmCacheRepository;
|
||||
|
||||
public void testSend(String targetToken, String message){
|
||||
sendNotification(targetToken, "testTitle", "testBody");
|
||||
public void testSend(String targetToken){
|
||||
sendNotification(targetToken, "testBody");
|
||||
}
|
||||
|
||||
private void sendNotification(String targetToken, String title, String body){
|
||||
public void send(Integer memberId, String data) {
|
||||
String token = fcmCacheRepository.getToken(memberId);
|
||||
if(Objects.nonNull(token)){
|
||||
sendNotification(token, data);
|
||||
}
|
||||
}
|
||||
|
||||
private void sendNotification(String targetToken, String body){
|
||||
Message message = Message.builder()
|
||||
.setToken(targetToken)
|
||||
.putData("title",title)
|
||||
.putData("body",body)
|
||||
.build();
|
||||
try {
|
||||
|
Loading…
Reference in New Issue
Block a user