Feat: WebSocket 연결

This commit is contained in:
김용수 2024-09-12 17:03:10 +09:00
parent c7d613365b
commit 6c46b64ce2
12 changed files with 73 additions and 62 deletions

View File

@ -9,9 +9,6 @@ import lombok.AccessLevel;
import lombok.Getter; import lombok.Getter;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.List;
@Getter @Getter
@Entity @Entity
@Table(name = "project_image") @Table(name = "project_image")
@ -49,7 +46,7 @@ public class Image extends BaseEntity {
*/ */
@Column(name = "status", nullable = false) @Column(name = "status", nullable = false)
@Enumerated(EnumType.STRING) @Enumerated(EnumType.STRING)
private LabelStatus status = LabelStatus.PENDING; private LabelStatus status = LabelStatus.Pending;
/** /**
* 속한 폴더 * 속한 폴더
@ -62,7 +59,7 @@ public class Image extends BaseEntity {
/** /**
* 이미지에 연결된 레이블 * 이미지에 연결된 레이블
*/ */
@OneToOne(mappedBy = "image", cascade = CascadeType.ALL, orphanRemoval = true) @OneToOne(mappedBy = "image", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
private Label label; private Label label;
private Image(final String imageTitle, final String imageUrl, final Integer order, final Folder folder) { private Image(final String imageTitle, final String imageUrl, final Integer order, final Folder folder) {

View File

@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
public enum LabelStatus { public enum LabelStatus {
PENDING, Pending,
IN_PROGRESS, IN_PROGRESS,
NEED_REVIEW, NEED_REVIEW,
COMPLETED; COMPLETED;

View File

@ -12,10 +12,8 @@ public interface ImageRepository extends JpaRepository<Image, Long> {
Optional<Image> findByIdAndFolderId(Long imageId, Integer folderId); Optional<Image> findByIdAndFolderId(Long imageId, Integer folderId);
// TODO: N + 1
@Query("select i from Image i " + @Query("select i from Image i " +
"join fetch i.folder f " + "join fetch i.label l " +
"join fetch f.project p " + "where i.folder.project.id = :projectId")
"where p.id = :projectId")
List<Image> findImagesByProjectId(@Param("projectId") Integer projectId); List<Image> findImagesByProjectId(@Param("projectId") Integer projectId);
} }

View File

@ -26,7 +26,6 @@ public class LabelController {
private final LabelService labelService; private final LabelService labelService;
@Operation(summary = "프로젝트 단위 오토레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.") @Operation(summary = "프로젝트 단위 오토레이블링", description = "해당 프로젝트 이미지를 오토레이블링합니다.")
@SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.") @SwaggerApiSuccess(description = "해당 프로젝트가 오토 레이블링 됩니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@ -51,7 +50,4 @@ public class LabelController {
labelService.save(imageId); labelService.save(imageId);
return SuccessResponse.empty(); return SuccessResponse.empty();
} }
} }

View File

@ -34,14 +34,6 @@ public class Label extends BaseEntity {
@JoinColumn(name = "image_id") @JoinColumn(name = "image_id")
private Image image; private Image image;
/**
* 속한 카테고리
* TODO: 레이블 카테고리에 속한걸 찾는데에 Json파일에 담기 때문에 카테고리는 Label Entity에 없어도 같음
*/
// @ManyToOne(fetch = FetchType.LAZY)
// @JoinColumn(name = "label_category_id")
// private LabelCategory labelCategory;
public static Label of(String jsonUrl, Image image) { public static Label of(String jsonUrl, Image image) {
Label label = new Label(); Label label = new Label();
label.url = jsonUrl; label.url = jsonUrl;
@ -49,4 +41,7 @@ public class Label extends BaseEntity {
return label; return label;
} }
public void changeUrl(String newUrl){
}
} }

View File

@ -8,4 +8,6 @@ import java.util.Optional;
public interface LabelRepository extends JpaRepository<Label, Long> { public interface LabelRepository extends JpaRepository<Label, Long> {
Optional<Label> findByImageId(Long imageId); Optional<Label> findByImageId(Long imageId);
boolean existsByImageId(Long imageId);
} }

View File

@ -59,20 +59,21 @@ public class LabelService {
List<Image> imageList = imageRepository.findImagesByProjectId(projectId); List<Image> imageList = imageRepository.findImagesByProjectId(projectId);
List<ImageRequest> imageRequestList = imageList.stream().map(ImageRequest::of).toList(); List<ImageRequest> imageRequestList = imageList.stream().map(ImageRequest::of).toList();
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList); AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
sendRequestToApi(autoLabelingRequest, projectType.getValue(), projectId);
}
List<AutoLabelingResponse> autoLabelingResponseList = sendRequestToApi(autoLabelingRequest, projectType.getValue(), projectId); public void saveLabel(final AutoLabelingResponse autoLabelingResponse) {
for (int index = 0; index < autoLabelingResponseList.size(); index++) { String uploadUrl = s3UploadService.uploadJson(autoLabelingResponse.getData(), autoLabelingResponse.getImageUrl());
AutoLabelingResponse response = autoLabelingResponseList.get(index); Image image = imageRepository.findById(autoLabelingResponse.getImageId())
Image image = imageList.get(index); .orElseThrow(() -> new CustomException(ErrorCode.IMAGE_NOT_FOUND));
String uploadedUrl = s3UploadService.uploadJson(response.getData(), response.getImageUrl());
Label label = labelRepository.findByImageId(response.getImageId()) if (!labelRepository.existsByImageId(autoLabelingResponse.getImageId())) {
.orElseGet(() -> Label.of(uploadedUrl, image)); Label label = Label.of(uploadUrl, image);
labelRepository.save(label); labelRepository.save(label);
} }
} }
private List<AutoLabelingResponse> sendRequestToApi(AutoLabelingRequest autoLabelingRequest, String apiEndpoint, int projectId) { private void sendRequestToApi(AutoLabelingRequest autoLabelingRequest, String apiEndpoint, int projectId) {
String url = createApiUrl(apiEndpoint); String url = createApiUrl(apiEndpoint);
// RestTemplate을 동적으로 생성하여 사용 // RestTemplate을 동적으로 생성하여 사용
@ -94,9 +95,7 @@ public class LabelService {
String responseBody = Optional.ofNullable(response.getBody()) String responseBody = Optional.ofNullable(response.getBody())
.orElseThrow(() -> new CustomException(ErrorCode.AI_SERVER_ERROR)); .orElseThrow(() -> new CustomException(ErrorCode.AI_SERVER_ERROR));
log.info("AI 서버 응답 -> {}", response.getBody()); // return parseAutoLabelingResponseList(responseBody);
return parseAutoLabelingResponseList(responseBody);
} catch (Exception e) { } catch (Exception e) {
log.error("AI 서버 요청 중 오류 발생: ", e); log.error("AI 서버 요청 중 오류 발생: ", e);
throw new CustomException(ErrorCode.AI_SERVER_ERROR); throw new CustomException(ErrorCode.AI_SERVER_ERROR);
@ -107,17 +106,16 @@ public class LabelService {
JsonElement jsonElement = JsonParser.parseString(responseBody); JsonElement jsonElement = JsonParser.parseString(responseBody);
List<AutoLabelingResponse> autoLabelingResponseList = new ArrayList<>(); List<AutoLabelingResponse> autoLabelingResponseList = new ArrayList<>();
for (JsonElement element : jsonElement.getAsJsonArray()) { for (JsonElement element : jsonElement.getAsJsonArray()) {
AutoLabelingResponse response = parseAutoLabelingResponse(element); AutoLabelingResponse response = parseAutoLabelingResponse(element.getAsJsonObject());
autoLabelingResponseList.add(response); autoLabelingResponseList.add(response);
} }
return autoLabelingResponseList; return autoLabelingResponseList;
} }
/** /**
* jsonElement -> AutoLabelingResponse * jsonObject -> AutoLabelingResponse
*/ */
private AutoLabelingResponse parseAutoLabelingResponse(JsonElement element) { public AutoLabelingResponse parseAutoLabelingResponse(JsonObject jsonObject) {
JsonObject jsonObject = element.getAsJsonObject();
Long imageId = jsonObject.get("image_id").getAsLong(); Long imageId = jsonObject.get("image_id").getAsLong();
String imageUrl = jsonObject.get("image_url").getAsString(); String imageUrl = jsonObject.get("image_url").getAsString();
JsonObject data = jsonObject.get("data").getAsJsonObject(); JsonObject data = jsonObject.get("data").getAsJsonObject();

View File

@ -26,6 +26,8 @@ public interface ParticipantRepository extends JpaRepository<Participant, Intege
@Param("projectId") Integer projectId); @Param("projectId") Integer projectId);
Optional<Participant> findByMemberIdAndProjectId(Integer memberId, Integer projectId); Optional<Participant> findByMemberIdAndProjectId(Integer memberId, Integer projectId);
} }

View File

@ -1,5 +1,11 @@
package com.worlabel.domain.progress; package com.worlabel.domain.progress;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.worlabel.domain.label.entity.dto.AutoLabelingResponse;
import com.worlabel.domain.label.service.LabelService;
import com.worlabel.global.service.S3UploadService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.messaging.handler.annotation.MessageMapping; import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.SendTo; import org.springframework.messaging.handler.annotation.SendTo;
@ -7,8 +13,13 @@ import org.springframework.stereotype.Controller;
@Slf4j @Slf4j
@Controller @Controller
@RequiredArgsConstructor
public class ProgressController { public class ProgressController {
private final LabelService labelService;
private final S3UploadService s3UploadService;
private final Gson gson;
@MessageMapping("/ai/train/progress") @MessageMapping("/ai/train/progress")
@SendTo("/topic/progress") @SendTo("/topic/progress")
public String handleTrainingProgress(String message) { public String handleTrainingProgress(String message) {
@ -16,4 +27,16 @@ public class ProgressController {
log.debug("Received message: {}", message); log.debug("Received message: {}", message);
return message; return message;
} }
@MessageMapping("/ai/predict/progress")
@SendTo("/topic/progress")
public String handlePredictProgress(String message) {
JsonObject jsonObject = gson.fromJson(message, JsonObject.class);
int progress = jsonObject.get("progress").getAsInt();
AutoLabelingResponse autoLabelingResponse = labelService.parseAutoLabelingResponse(jsonObject.getAsJsonObject("result"));
labelService.saveLabel(autoLabelingResponse);
log.debug("오토 레이블링 진행률 : {}%",progress);
return String.valueOf(progress);
}
} }

View File

@ -83,8 +83,6 @@ public class ProjectController {
public BaseResponse<Void> trainModel( public BaseResponse<Void> trainModel(
@CurrentUser final Integer memberId, @CurrentUser final Integer memberId,
@PathVariable("project_id") final Integer projectId) { @PathVariable("project_id") final Integer projectId) {
log.debug("훈련 요청 ");
projectService.train(memberId, projectId); projectService.train(memberId, projectId);
return SuccessResponse.empty(); return SuccessResponse.empty();
} }

View File

@ -8,7 +8,6 @@ import com.worlabel.domain.participant.entity.WorkspaceParticipant;
import com.worlabel.domain.participant.entity.dto.ParticipantRequest; import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
import com.worlabel.domain.participant.repository.ParticipantRepository; import com.worlabel.domain.participant.repository.ParticipantRepository;
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository; import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
import com.worlabel.domain.project.dto.RequestDto;
import com.worlabel.domain.project.dto.RequestDto.TrainRequest; import com.worlabel.domain.project.dto.RequestDto.TrainRequest;
import com.worlabel.domain.project.entity.Project; import com.worlabel.domain.project.entity.Project;
import com.worlabel.domain.project.entity.dto.ProjectRequest; import com.worlabel.domain.project.entity.dto.ProjectRequest;
@ -18,26 +17,16 @@ import com.worlabel.domain.workspace.entity.Workspace;
import com.worlabel.domain.workspace.repository.WorkspaceRepository; import com.worlabel.domain.workspace.repository.WorkspaceRepository;
import com.worlabel.global.exception.CustomException; import com.worlabel.global.exception.CustomException;
import com.worlabel.global.exception.ErrorCode; import com.worlabel.global.exception.ErrorCode;
import com.worlabel.global.service.AIWebSocketClient;
import lombok.RequiredArgsConstructor; import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHttpHeaders;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.CompletableFuture;
@Slf4j @Slf4j
@Service @Service
@ -52,6 +41,12 @@ public class ProjectService {
private final WorkspaceParticipantRepository workspaceParticipantRepository; private final WorkspaceParticipantRepository workspaceParticipantRepository;
private final RestTemplate restTemplate; private final RestTemplate restTemplate;
/**
* AI SERVER 주소
*/
@Value("${ai.server}")
private String aiServer;
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) { public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
Workspace workspace = getWorkspace(memberId, workspaceId); Workspace workspace = getWorkspace(memberId, workspaceId);
Member member = getMember(memberId); Member member = getMember(memberId);
@ -128,14 +123,17 @@ public class ProjectService {
participantRepository.delete(participant); participantRepository.delete(participant);
} }
@Async public void train(final Integer memberId, final Integer projectId) {
public CompletableFuture<Void> train(final Integer memberId, final Integer projectId) {
// 멤버 권한 체크 // 멤버 권한 체크
checkEditorParticipant(memberId, projectId);
// 레디스 train 테이블에 존재하는지 확인 // TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다.
/*
없으면 redis 상태 테이블을 만든다. progressTable
*/
// FastAPI 서버로 학습 요청을 전송 // FastAPI 서버로 학습 요청을 전송
String url = "http://localhost:8000/api/detection/train"; String url = aiServer + "/detection/train";
TrainRequest trainRequest = new TrainRequest(); TrainRequest trainRequest = new TrainRequest();
trainRequest.setProjectId(projectId); trainRequest.setProjectId(projectId);
@ -145,12 +143,11 @@ public class ProjectService {
try { try {
ResponseEntity<String> result = restTemplate.postForEntity(url, trainRequest, String.class); ResponseEntity<String> result = restTemplate.postForEntity(url, trainRequest, String.class);
log.debug("응답 결과 {} ",result); log.debug("응답 결과 {} ",result);
System.out.println("FastAPI 서버에 학습 요청을 성공적으로 전송했습니다. Project ID: " + projectId); log.debug("FastAPI 서버에 학습 요청을 성공적으로 전송했습니다. Project ID: {}", projectId);
} catch (Exception e) { } catch (Exception e) {
System.err.println("FastAPI 서버에 학습 요청을 전송하는 중 오류가 발생했습니다: " + e.getMessage()); log.error("FastAPI 서버에 학습 요청을 전송하는 중 오류가 발생했습니다 {}",e.getMessage());
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
} }
return CompletableFuture.completedFuture(null);
} }
private Workspace getWorkspace(final Integer memberId, final Integer workspaceId) { private Workspace getWorkspace(final Integer memberId, final Integer workspaceId) {
@ -168,6 +165,12 @@ public class ProjectService {
.orElseThrow(() -> new CustomException(ErrorCode.PROJECT_NOT_FOUND)); .orElseThrow(() -> new CustomException(ErrorCode.PROJECT_NOT_FOUND));
} }
private void checkEditorParticipant(final Integer memberId, final Integer projectId) {
if(participantRepository.doesParticipantUnauthorizedExistByMemberIdAndProjectId(memberId,projectId)){
throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED);
}
}
private void checkExistParticipant(final Integer memberId, final Integer projectId) { private void checkExistParticipant(final Integer memberId, final Integer projectId) {
if (!participantRepository.existsByMemberIdAndProjectId(memberId, projectId)) { if (!participantRepository.existsByMemberIdAndProjectId(memberId, projectId)) {
throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED); throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED);

View File

@ -50,8 +50,7 @@ public class S3UploadService {
// String targetUrl = projectId + "/" + title + ".json"; // S3에 업로드할 대상 URL // String targetUrl = projectId + "/" + title + ".json"; // S3에 업로드할 대상 URL
String targetUrl = removeExtension(getKeyFromImageAddress(imageUrl)) + ".json"; String targetUrl = removeExtension(getKeyFromImageAddress(imageUrl)) + ".json";
log.debug("주소 {}", targetUrl); // log.debug("주소 {}", targetUrl);
try { try {
byte[] jsonBytes = json.getBytes(StandardCharsets.UTF_8); byte[] jsonBytes = json.getBytes(StandardCharsets.UTF_8);
ObjectMetadata metadata = new ObjectMetadata(); ObjectMetadata metadata = new ObjectMetadata();
@ -65,7 +64,7 @@ public class S3UploadService {
} }
URL uploadedUrl = amazonS3.getUrl(bucket, targetUrl); URL uploadedUrl = amazonS3.getUrl(bucket, targetUrl);
log.debug("Uploaded JSON URL: {}", uploadedUrl); // log.debug("Uploaded JSON URL: {}", uploadedUrl);
return uploadedUrl.toString(); // 업로드된 파일의 URL 반환 return uploadedUrl.toString(); // 업로드된 파일의 URL 반환
} catch (Exception e) { } catch (Exception e) {
log.error("JSON 업로드 중 오류 발생: ", e); log.error("JSON 업로드 중 오류 발생: ", e);