Merge branch 'be/feat/websocket' into 'be/develop'
Feat: WebSocket 연결 See merge request s11-s-project/S11P21S002!64
This commit is contained in:
commit
3caf024254
@ -69,6 +69,9 @@ dependencies {
|
|||||||
testImplementation 'org.junit.jupiter:junit-jupiter:5.7.1'
|
testImplementation 'org.junit.jupiter:junit-jupiter:5.7.1'
|
||||||
testImplementation 'org.mockito:mockito-core:3.9.0'
|
testImplementation 'org.mockito:mockito-core:3.9.0'
|
||||||
testImplementation 'org.mockito:mockito-junit-jupiter:3.9.0'
|
testImplementation 'org.mockito:mockito-junit-jupiter:3.9.0'
|
||||||
|
|
||||||
|
// WebSocket
|
||||||
|
implementation 'org.springframework.boot:spring-boot-starter-websocket'
|
||||||
}
|
}
|
||||||
|
|
||||||
tasks.named('test') {
|
tasks.named('test') {
|
||||||
|
@ -45,7 +45,7 @@ public class FolderResponse {
|
|||||||
|
|
||||||
public static FolderResponse fromWithNeedReview(final Folder folder) {
|
public static FolderResponse fromWithNeedReview(final Folder folder) {
|
||||||
List<ImageResponse> images = folder.getImageList().stream()
|
List<ImageResponse> images = folder.getImageList().stream()
|
||||||
.filter(image -> image.getStatus() == LabelStatus.NEED_REVIEW)
|
.filter(image -> image.getStatus() == LabelStatus.REVIEW_REQUEST)
|
||||||
.map(ImageResponse::from)
|
.map(ImageResponse::from)
|
||||||
.toList();
|
.toList();
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@ package com.worlabel.domain.folder.repository;
|
|||||||
|
|
||||||
import com.worlabel.domain.folder.entity.Folder;
|
import com.worlabel.domain.folder.entity.Folder;
|
||||||
import org.springframework.data.jpa.repository.JpaRepository;
|
import org.springframework.data.jpa.repository.JpaRepository;
|
||||||
|
import org.springframework.data.jpa.repository.Query;
|
||||||
|
import org.springframework.data.repository.query.Param;
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@ -10,9 +12,20 @@ import java.util.Optional;
|
|||||||
@Repository
|
@Repository
|
||||||
public interface FolderRepository extends JpaRepository<Folder, Integer> {
|
public interface FolderRepository extends JpaRepository<Folder, Integer> {
|
||||||
|
|
||||||
List<Folder> findAllByProjectIdAndParentIsNull(Integer projectId);
|
@Query("SELECT f FROM Folder f " +
|
||||||
|
"LEFT JOIN FETCH f.imageList i " +
|
||||||
|
"LEFT JOIN FETCH i.label " +
|
||||||
|
"WHERE f.project.id = :projectId " +
|
||||||
|
"AND f.parent IS NULL ")
|
||||||
|
List<Folder> findAllByProjectIdAndParentIsNull(@Param("projectId") Integer projectId);
|
||||||
|
|
||||||
|
@Query("SELECT f FROM Folder f " +
|
||||||
|
"LEFT JOIN FETCH f.imageList i " +
|
||||||
|
"LEFT JOIN FETCH i.label " +
|
||||||
|
"WHERE f.project.id = :projectId " +
|
||||||
|
"AND f.id = :folderId")
|
||||||
|
Optional<Folder> findAllByProjectIdAndId(@Param("projectId") Integer projectId, @Param("folderId") Integer folderId);
|
||||||
|
|
||||||
Optional<Folder> findAllByProjectIdAndId(Integer projectId, Integer folderId);
|
|
||||||
|
|
||||||
boolean existsByIdAndProjectId(Integer folderId, Integer projectId);
|
boolean existsByIdAndProjectId(Integer folderId, Integer projectId);
|
||||||
}
|
}
|
@ -9,9 +9,6 @@ import lombok.AccessLevel;
|
|||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Entity
|
@Entity
|
||||||
@Table(name = "project_image")
|
@Table(name = "project_image")
|
||||||
@ -62,7 +59,7 @@ public class Image extends BaseEntity {
|
|||||||
/**
|
/**
|
||||||
* 이미지에 연결된 레이블
|
* 이미지에 연결된 레이블
|
||||||
*/
|
*/
|
||||||
@OneToOne(mappedBy = "image", cascade = CascadeType.ALL, orphanRemoval = true)
|
@OneToOne(fetch = FetchType.LAZY, mappedBy = "image", cascade = CascadeType.ALL, orphanRemoval = true)
|
||||||
private Label label;
|
private Label label;
|
||||||
|
|
||||||
private Image(final String imageTitle, final String imageUrl, final Integer order, final Folder folder) {
|
private Image(final String imageTitle, final String imageUrl, final Integer order, final Folder folder) {
|
||||||
|
@ -6,7 +6,8 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
|||||||
public enum LabelStatus {
|
public enum LabelStatus {
|
||||||
PENDING,
|
PENDING,
|
||||||
IN_PROGRESS,
|
IN_PROGRESS,
|
||||||
NEED_REVIEW,
|
SAVE,
|
||||||
|
REVIEW_REQUEST,
|
||||||
COMPLETED;
|
COMPLETED;
|
||||||
|
|
||||||
// 입력 값을 enum 값과 일치시키기 위해 대소문자 구분 없이 변환
|
// 입력 값을 enum 값과 일치시키기 위해 대소문자 구분 없이 변환
|
||||||
|
@ -12,10 +12,8 @@ public interface ImageRepository extends JpaRepository<Image, Long> {
|
|||||||
|
|
||||||
Optional<Image> findByIdAndFolderId(Long imageId, Integer folderId);
|
Optional<Image> findByIdAndFolderId(Long imageId, Integer folderId);
|
||||||
|
|
||||||
// TODO: N + 1
|
|
||||||
@Query("select i from Image i " +
|
@Query("select i from Image i " +
|
||||||
"join fetch i.folder f " +
|
"join fetch i.label l " +
|
||||||
"join fetch f.project p " +
|
"where i.folder.project.id = :projectId")
|
||||||
"where p.id = :projectId")
|
|
||||||
List<Image> findImagesByProjectId(@Param("projectId") Integer projectId);
|
List<Image> findImagesByProjectId(@Param("projectId") Integer projectId);
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,4 @@ public class LabelController {
|
|||||||
labelService.save(imageId);
|
labelService.save(imageId);
|
||||||
return SuccessResponse.empty();
|
return SuccessResponse.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -34,19 +34,10 @@ public class Label extends BaseEntity {
|
|||||||
@JoinColumn(name = "image_id")
|
@JoinColumn(name = "image_id")
|
||||||
private Image image;
|
private Image image;
|
||||||
|
|
||||||
/**
|
|
||||||
* 속한 카테고리
|
|
||||||
* TODO: 한 레이블 카테고리에 속한걸 찾는데에 Json파일에 담기 때문에 카테고리는 Label Entity에 없어도 될 것 같음
|
|
||||||
*/
|
|
||||||
// @ManyToOne(fetch = FetchType.LAZY)
|
|
||||||
// @JoinColumn(name = "label_category_id")
|
|
||||||
// private LabelCategory labelCategory;
|
|
||||||
|
|
||||||
public static Label of(String jsonUrl, Image image) {
|
public static Label of(String jsonUrl, Image image) {
|
||||||
Label label = new Label();
|
Label label = new Label();
|
||||||
label.url = jsonUrl;
|
label.url = jsonUrl;
|
||||||
label.image = image;
|
label.image = image;
|
||||||
return label;
|
return label;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -8,4 +8,6 @@ import java.util.Optional;
|
|||||||
public interface LabelRepository extends JpaRepository<Label, Long> {
|
public interface LabelRepository extends JpaRepository<Label, Long> {
|
||||||
|
|
||||||
Optional<Label> findByImageId(Long imageId);
|
Optional<Label> findByImageId(Long imageId);
|
||||||
|
|
||||||
|
boolean existsByImageId(Long imageId);
|
||||||
}
|
}
|
||||||
|
@ -59,20 +59,21 @@ public class LabelService {
|
|||||||
List<Image> imageList = imageRepository.findImagesByProjectId(projectId);
|
List<Image> imageList = imageRepository.findImagesByProjectId(projectId);
|
||||||
List<ImageRequest> imageRequestList = imageList.stream().map(ImageRequest::of).toList();
|
List<ImageRequest> imageRequestList = imageList.stream().map(ImageRequest::of).toList();
|
||||||
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
|
AutoLabelingRequest autoLabelingRequest = AutoLabelingRequest.of(projectId, imageRequestList);
|
||||||
|
sendRequestToApi(autoLabelingRequest, projectType.getValue(), projectId);
|
||||||
|
}
|
||||||
|
|
||||||
List<AutoLabelingResponse> autoLabelingResponseList = sendRequestToApi(autoLabelingRequest, projectType.getValue(), projectId);
|
public void saveLabel(final AutoLabelingResponse autoLabelingResponse) {
|
||||||
for (int index = 0; index < autoLabelingResponseList.size(); index++) {
|
String uploadUrl = s3UploadService.uploadJson(autoLabelingResponse.getData(), autoLabelingResponse.getImageUrl());
|
||||||
AutoLabelingResponse response = autoLabelingResponseList.get(index);
|
Image image = imageRepository.findById(autoLabelingResponse.getImageId())
|
||||||
Image image = imageList.get(index);
|
.orElseThrow(() -> new CustomException(ErrorCode.IMAGE_NOT_FOUND));
|
||||||
String uploadedUrl = s3UploadService.uploadJson(response.getData(), response.getImageUrl());
|
|
||||||
|
|
||||||
Label label = labelRepository.findByImageId(response.getImageId())
|
if (!labelRepository.existsByImageId(autoLabelingResponse.getImageId())) {
|
||||||
.orElseGet(() -> Label.of(uploadedUrl, image));
|
Label label = Label.of(uploadUrl, image);
|
||||||
labelRepository.save(label);
|
labelRepository.save(label);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<AutoLabelingResponse> sendRequestToApi(AutoLabelingRequest autoLabelingRequest, String apiEndpoint, int projectId) {
|
private void sendRequestToApi(AutoLabelingRequest autoLabelingRequest, String apiEndpoint, int projectId) {
|
||||||
String url = createApiUrl(apiEndpoint);
|
String url = createApiUrl(apiEndpoint);
|
||||||
|
|
||||||
// RestTemplate을 동적으로 생성하여 사용
|
// RestTemplate을 동적으로 생성하여 사용
|
||||||
@ -94,9 +95,7 @@ public class LabelService {
|
|||||||
String responseBody = Optional.ofNullable(response.getBody())
|
String responseBody = Optional.ofNullable(response.getBody())
|
||||||
.orElseThrow(() -> new CustomException(ErrorCode.AI_SERVER_ERROR));
|
.orElseThrow(() -> new CustomException(ErrorCode.AI_SERVER_ERROR));
|
||||||
|
|
||||||
log.info("AI 서버 응답 -> {}", response.getBody());
|
// return parseAutoLabelingResponseList(responseBody);
|
||||||
|
|
||||||
return parseAutoLabelingResponseList(responseBody);
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("AI 서버 요청 중 오류 발생: ", e);
|
log.error("AI 서버 요청 중 오류 발생: ", e);
|
||||||
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
|
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
|
||||||
@ -107,17 +106,16 @@ public class LabelService {
|
|||||||
JsonElement jsonElement = JsonParser.parseString(responseBody);
|
JsonElement jsonElement = JsonParser.parseString(responseBody);
|
||||||
List<AutoLabelingResponse> autoLabelingResponseList = new ArrayList<>();
|
List<AutoLabelingResponse> autoLabelingResponseList = new ArrayList<>();
|
||||||
for (JsonElement element : jsonElement.getAsJsonArray()) {
|
for (JsonElement element : jsonElement.getAsJsonArray()) {
|
||||||
AutoLabelingResponse response = parseAutoLabelingResponse(element);
|
AutoLabelingResponse response = parseAutoLabelingResponse(element.getAsJsonObject());
|
||||||
autoLabelingResponseList.add(response);
|
autoLabelingResponseList.add(response);
|
||||||
}
|
}
|
||||||
return autoLabelingResponseList;
|
return autoLabelingResponseList;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* jsonElement -> AutoLabelingResponse
|
* jsonObject -> AutoLabelingResponse
|
||||||
*/
|
*/
|
||||||
private AutoLabelingResponse parseAutoLabelingResponse(JsonElement element) {
|
public AutoLabelingResponse parseAutoLabelingResponse(JsonObject jsonObject) {
|
||||||
JsonObject jsonObject = element.getAsJsonObject();
|
|
||||||
Long imageId = jsonObject.get("image_id").getAsLong();
|
Long imageId = jsonObject.get("image_id").getAsLong();
|
||||||
String imageUrl = jsonObject.get("image_url").getAsString();
|
String imageUrl = jsonObject.get("image_url").getAsString();
|
||||||
JsonObject data = jsonObject.get("data").getAsJsonObject();
|
JsonObject data = jsonObject.get("data").getAsJsonObject();
|
||||||
|
@ -26,6 +26,8 @@ public interface ParticipantRepository extends JpaRepository<Participant, Intege
|
|||||||
@Param("projectId") Integer projectId);
|
@Param("projectId") Integer projectId);
|
||||||
|
|
||||||
Optional<Participant> findByMemberIdAndProjectId(Integer memberId, Integer projectId);
|
Optional<Participant> findByMemberIdAndProjectId(Integer memberId, Integer projectId);
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -0,0 +1,42 @@
|
|||||||
|
package com.worlabel.domain.progress;
|
||||||
|
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import com.google.gson.JsonObject;
|
||||||
|
import com.worlabel.domain.label.entity.dto.AutoLabelingResponse;
|
||||||
|
import com.worlabel.domain.label.service.LabelService;
|
||||||
|
import com.worlabel.global.service.S3UploadService;
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.messaging.handler.annotation.MessageMapping;
|
||||||
|
import org.springframework.messaging.handler.annotation.SendTo;
|
||||||
|
import org.springframework.stereotype.Controller;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Controller
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class ProgressController {
|
||||||
|
|
||||||
|
private final LabelService labelService;
|
||||||
|
private final S3UploadService s3UploadService;
|
||||||
|
private final Gson gson;
|
||||||
|
|
||||||
|
@MessageMapping("/ai/train/progress")
|
||||||
|
@SendTo("/topic/progress")
|
||||||
|
public String handleTrainingProgress(String message) {
|
||||||
|
// FastAPI에서 전송한 학습 진행 상황 메시지를 처리하고 클라이언트로 전달
|
||||||
|
log.debug("Received message: {}", message);
|
||||||
|
return message;
|
||||||
|
}
|
||||||
|
|
||||||
|
@MessageMapping("/ai/predict/progress")
|
||||||
|
@SendTo("/topic/progress")
|
||||||
|
public String handlePredictProgress(String message) {
|
||||||
|
JsonObject jsonObject = gson.fromJson(message, JsonObject.class);
|
||||||
|
|
||||||
|
int progress = jsonObject.get("progress").getAsInt();
|
||||||
|
AutoLabelingResponse autoLabelingResponse = labelService.parseAutoLabelingResponse(jsonObject.getAsJsonObject("result"));
|
||||||
|
labelService.saveLabel(autoLabelingResponse);
|
||||||
|
log.debug("오토 레이블링 진행률 : {}%",progress);
|
||||||
|
return String.valueOf(progress);
|
||||||
|
}
|
||||||
|
}
|
@ -16,10 +16,12 @@ import io.swagger.v3.oas.annotations.Parameter;
|
|||||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
import jakarta.validation.Valid;
|
import jakarta.validation.Valid;
|
||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.web.bind.annotation.*;
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
@Tag(name = "프로젝트 관련 API")
|
@Tag(name = "프로젝트 관련 API")
|
||||||
@RestController
|
@RestController
|
||||||
@RequestMapping("/api")
|
@RequestMapping("/api")
|
||||||
@ -74,6 +76,17 @@ public class ProjectController {
|
|||||||
return SuccessResponse.of(project);
|
return SuccessResponse.of(project);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다..")
|
||||||
|
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
|
||||||
|
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
|
||||||
|
@PostMapping("/projects/{project_id}/train")
|
||||||
|
public BaseResponse<Void> trainModel(
|
||||||
|
@CurrentUser final Integer memberId,
|
||||||
|
@PathVariable("project_id") final Integer projectId) {
|
||||||
|
projectService.train(memberId, projectId);
|
||||||
|
return SuccessResponse.empty();
|
||||||
|
}
|
||||||
|
|
||||||
@Operation(summary = "프로젝트 삭제", description = "프로젝트를 삭제합니다.")
|
@Operation(summary = "프로젝트 삭제", description = "프로젝트를 삭제합니다.")
|
||||||
@SwaggerApiSuccess(description = "프로젝트를 성공적으로 삭제합니다.")
|
@SwaggerApiSuccess(description = "프로젝트를 성공적으로 삭제합니다.")
|
||||||
@SwaggerApiError({ErrorCode.PROJECT_NOT_FOUND, ErrorCode.PARTICIPANT_UNAUTHORIZED, ErrorCode.SERVER_ERROR})
|
@SwaggerApiError({ErrorCode.PROJECT_NOT_FOUND, ErrorCode.PARTICIPANT_UNAUTHORIZED, ErrorCode.SERVER_ERROR})
|
||||||
@ -120,4 +133,6 @@ public class ProjectController {
|
|||||||
return SuccessResponse.empty();
|
return SuccessResponse.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
package com.worlabel.domain.project.dto;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class RequestDto {
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class TrainDataInfo {
|
||||||
|
private String imageUrl;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public static class TrainRequest {
|
||||||
|
@JsonProperty("project_id")
|
||||||
|
private int projectId;
|
||||||
|
|
||||||
|
@JsonProperty("data")
|
||||||
|
private List<TrainDataInfo> data;
|
||||||
|
// private int seed; // Optional
|
||||||
|
// private float ratio; // Default = 0.8
|
||||||
|
// private int epochs; // Default = 50
|
||||||
|
// private float batch; // Default = -1
|
||||||
|
|
||||||
|
// Getters and Setters
|
||||||
|
}
|
||||||
|
}
|
@ -8,6 +8,7 @@ import com.worlabel.domain.participant.entity.WorkspaceParticipant;
|
|||||||
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
|
import com.worlabel.domain.participant.entity.dto.ParticipantRequest;
|
||||||
import com.worlabel.domain.participant.repository.ParticipantRepository;
|
import com.worlabel.domain.participant.repository.ParticipantRepository;
|
||||||
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
|
import com.worlabel.domain.participant.repository.WorkspaceParticipantRepository;
|
||||||
|
import com.worlabel.domain.project.dto.RequestDto.TrainRequest;
|
||||||
import com.worlabel.domain.project.entity.Project;
|
import com.worlabel.domain.project.entity.Project;
|
||||||
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
import com.worlabel.domain.project.entity.dto.ProjectRequest;
|
||||||
import com.worlabel.domain.project.entity.dto.ProjectResponse;
|
import com.worlabel.domain.project.entity.dto.ProjectResponse;
|
||||||
@ -19,8 +20,10 @@ import com.worlabel.global.exception.ErrorCode;
|
|||||||
import lombok.RequiredArgsConstructor;
|
import lombok.RequiredArgsConstructor;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@ -36,7 +39,13 @@ public class ProjectService {
|
|||||||
private final ParticipantRepository participantRepository;
|
private final ParticipantRepository participantRepository;
|
||||||
private final MemberRepository memberRepository;
|
private final MemberRepository memberRepository;
|
||||||
private final WorkspaceParticipantRepository workspaceParticipantRepository;
|
private final WorkspaceParticipantRepository workspaceParticipantRepository;
|
||||||
|
private final RestTemplate restTemplate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI SERVER 주소
|
||||||
|
*/
|
||||||
|
@Value("${ai.server}")
|
||||||
|
private String aiServer;
|
||||||
|
|
||||||
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
|
public ProjectResponse createProject(final Integer memberId, final Integer workspaceId, final ProjectRequest projectRequest) {
|
||||||
Workspace workspace = getWorkspace(memberId, workspaceId);
|
Workspace workspace = getWorkspace(memberId, workspaceId);
|
||||||
@ -114,6 +123,33 @@ public class ProjectService {
|
|||||||
participantRepository.delete(participant);
|
participantRepository.delete(participant);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void train(final Integer memberId, final Integer projectId) {
|
||||||
|
// 멤버 권한 체크
|
||||||
|
checkEditorParticipant(memberId, projectId);
|
||||||
|
|
||||||
|
// TODO: 레디스 train 테이블에 존재하는지 확인 -> 이미 있으면 있다고 예외를 던져준다.
|
||||||
|
/*
|
||||||
|
없으면 redis 상태 테이블을 만든다. progressTable
|
||||||
|
*/
|
||||||
|
|
||||||
|
// FastAPI 서버로 학습 요청을 전송
|
||||||
|
String url = aiServer + "/detection/train";
|
||||||
|
|
||||||
|
TrainRequest trainRequest = new TrainRequest();
|
||||||
|
trainRequest.setProjectId(projectId);
|
||||||
|
trainRequest.setData(List.of());
|
||||||
|
|
||||||
|
// FastAPI 서버로 POST 요청 전송
|
||||||
|
try {
|
||||||
|
ResponseEntity<String> result = restTemplate.postForEntity(url, trainRequest, String.class);
|
||||||
|
log.debug("응답 결과 {} ",result);
|
||||||
|
log.debug("FastAPI 서버에 학습 요청을 성공적으로 전송했습니다. Project ID: {}", projectId);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("FastAPI 서버에 학습 요청을 전송하는 중 오류가 발생했습니다 {}",e.getMessage());
|
||||||
|
throw new CustomException(ErrorCode.AI_SERVER_ERROR);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private Workspace getWorkspace(final Integer memberId, final Integer workspaceId) {
|
private Workspace getWorkspace(final Integer memberId, final Integer workspaceId) {
|
||||||
return workspaceRepository.findByMemberIdAndId(memberId, workspaceId)
|
return workspaceRepository.findByMemberIdAndId(memberId, workspaceId)
|
||||||
.orElseThrow(() -> new CustomException(ErrorCode.WORKSPACE_NOT_FOUND));
|
.orElseThrow(() -> new CustomException(ErrorCode.WORKSPACE_NOT_FOUND));
|
||||||
@ -129,6 +165,12 @@ public class ProjectService {
|
|||||||
.orElseThrow(() -> new CustomException(ErrorCode.PROJECT_NOT_FOUND));
|
.orElseThrow(() -> new CustomException(ErrorCode.PROJECT_NOT_FOUND));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void checkEditorParticipant(final Integer memberId, final Integer projectId) {
|
||||||
|
if(participantRepository.doesParticipantUnauthorizedExistByMemberIdAndProjectId(memberId,projectId)){
|
||||||
|
throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void checkExistParticipant(final Integer memberId, final Integer projectId) {
|
private void checkExistParticipant(final Integer memberId, final Integer projectId) {
|
||||||
if (!participantRepository.existsByMemberIdAndProjectId(memberId, projectId)) {
|
if (!participantRepository.existsByMemberIdAndProjectId(memberId, projectId)) {
|
||||||
throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED);
|
throw new CustomException(ErrorCode.PARTICIPANT_UNAUTHORIZED);
|
||||||
@ -152,5 +194,6 @@ public class ProjectService {
|
|||||||
throw new CustomException(ErrorCode.PARTICIPANT_BAD_REQUEST);
|
throw new CustomException(ErrorCode.PARTICIPANT_BAD_REQUEST);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,10 @@
|
|||||||
|
package com.worlabel.global.config;
|
||||||
|
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.scheduling.annotation.EnableAsync;
|
||||||
|
|
||||||
|
@EnableAsync
|
||||||
|
@Configuration
|
||||||
|
public class AsyncConfig {
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
package com.worlabel.global.config;
|
||||||
|
|
||||||
|
public class CustomWebSocketConfig1 {
|
||||||
|
}
|
@ -48,7 +48,7 @@ public class SecurityConfig {
|
|||||||
.formLogin((auth) -> auth.disable());
|
.formLogin((auth) -> auth.disable());
|
||||||
|
|
||||||
// 세션 설정 비활성화
|
// 세션 설정 비활성화
|
||||||
http.sessionManagement((session)->session
|
http.sessionManagement((session) -> session
|
||||||
.sessionCreationPolicy(SessionCreationPolicy.STATELESS));
|
.sessionCreationPolicy(SessionCreationPolicy.STATELESS));
|
||||||
|
|
||||||
// CORS 설정
|
// CORS 설정
|
||||||
@ -63,8 +63,8 @@ public class SecurityConfig {
|
|||||||
|
|
||||||
// 경로별 인가 작업
|
// 경로별 인가 작업
|
||||||
http
|
http
|
||||||
.authorizeHttpRequests(auth->auth
|
.authorizeHttpRequests(auth -> auth
|
||||||
.requestMatchers("/swagger", "/swagger-ui.html", "/swagger-ui/**", "/api-docs", "/api-docs/**", "/v3/api-docs/**").permitAll()
|
.requestMatchers("/swagger", "/swagger-ui.html", "/swagger-ui/**", "/api-docs", "/api-docs/**", "/v3/api-docs/**", "/ws/**").permitAll()
|
||||||
.requestMatchers("/api/auth/reissue").permitAll()
|
.requestMatchers("/api/auth/reissue").permitAll()
|
||||||
.anyRequest().authenticated()
|
.anyRequest().authenticated()
|
||||||
// .anyRequest().permitAll()
|
// .anyRequest().permitAll()
|
||||||
@ -80,9 +80,6 @@ public class SecurityConfig {
|
|||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// JWT 필터 추가
|
// JWT 필터 추가
|
||||||
http
|
http
|
||||||
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);
|
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);
|
||||||
@ -93,8 +90,8 @@ public class SecurityConfig {
|
|||||||
public CorsConfigurationSource corsConfigurationSource() {
|
public CorsConfigurationSource corsConfigurationSource() {
|
||||||
CorsConfiguration configuration = new CorsConfiguration();
|
CorsConfiguration configuration = new CorsConfiguration();
|
||||||
configuration.setAllowCredentials(true);
|
configuration.setAllowCredentials(true);
|
||||||
configuration.setAllowedOrigins(List.of(frontend)); // 프론트엔드 URL 사용
|
configuration.setAllowedOrigins(List.of(frontend, "http://localhost:5173")); // 프론트엔드 URL 사용
|
||||||
configuration.setAllowedMethods(List.of("GET","POST","PUT","PATCH","DELETE","OPTIONS"));
|
configuration.setAllowedMethods(List.of("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"));
|
||||||
configuration.setAllowedHeaders(List.of("*"));
|
configuration.setAllowedHeaders(List.of("*"));
|
||||||
configuration.setMaxAge(3600L);
|
configuration.setMaxAge(3600L);
|
||||||
|
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
package com.worlabel.global.config;
|
||||||
|
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
||||||
|
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
|
||||||
|
import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
||||||
|
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
@EnableWebSocketMessageBroker
|
||||||
|
public class WebSocketConfig implements WebSocketMessageBrokerConfigurer {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void configureMessageBroker(MessageBrokerRegistry config) {
|
||||||
|
// 메시지 브로커 설정: 클라이언트가 구독할 수 있는 경로 지정
|
||||||
|
config.enableSimpleBroker("/topic");
|
||||||
|
// 클라이언트가 메시지를 보낼 때 사용하는 경로의 접두사
|
||||||
|
config.setApplicationDestinationPrefixes("/app");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void registerStompEndpoints(StompEndpointRegistry registry) {
|
||||||
|
registry.addEndpoint("/ws")
|
||||||
|
.setAllowedOrigins("*");
|
||||||
|
}
|
||||||
|
}
|
@ -30,6 +30,7 @@ public class JwtAuthenticationFilter extends OncePerRequestFilter {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
|
||||||
|
log.debug(request.getRequestURI());
|
||||||
String token = resolveToken(request);
|
String token = resolveToken(request);
|
||||||
log.debug("token {}", token);
|
log.debug("token {}", token);
|
||||||
try {
|
try {
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
package com.worlabel.global.handler;
|
||||||
|
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.web.socket.TextMessage;
|
||||||
|
import org.springframework.web.socket.WebSocketSession;
|
||||||
|
import org.springframework.web.socket.handler.TextWebSocketHandler;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class CustomWebSocketHandler extends TextWebSocketHandler {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
||||||
|
// FastAPI에서 받은 메세지
|
||||||
|
log.debug("FastAPI로 부터 받은 메세지 : {}", message.getPayload());
|
||||||
|
|
||||||
|
// 클라이언트로 보내기
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,74 @@
|
|||||||
|
package com.worlabel.global.service;
|
||||||
|
|
||||||
|
import lombok.RequiredArgsConstructor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.web.socket.TextMessage;
|
||||||
|
import org.springframework.web.socket.WebSocketHttpHeaders;
|
||||||
|
import org.springframework.web.socket.WebSocketSession;
|
||||||
|
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
|
||||||
|
import org.springframework.web.socket.handler.TextWebSocketHandler;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
@RequiredArgsConstructor
|
||||||
|
public class AIWebSocketClient {
|
||||||
|
|
||||||
|
private WebSocketSession session;
|
||||||
|
|
||||||
|
// WebSocket 연결 설정
|
||||||
|
public void connect(String url) {
|
||||||
|
try {
|
||||||
|
StandardWebSocketClient client = new StandardWebSocketClient();
|
||||||
|
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||||
|
client.doHandshake(new WebSocketHandler(), headers, URI.create(url)).get();
|
||||||
|
log.info("Connected to WebSocket at {}", url);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Failed to connect to WebSocket: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket 메시지 전송
|
||||||
|
public void sendMessage(String message) {
|
||||||
|
try {
|
||||||
|
if (session != null && session.isOpen()) {
|
||||||
|
session.sendMessage(new TextMessage(message));
|
||||||
|
log.info("Sent message: {}", message);
|
||||||
|
} else {
|
||||||
|
log.warn("WebSocket session is not open. Unable to send message.");
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Failed to send message: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// WebSocket 연결 종료
|
||||||
|
public void close() {
|
||||||
|
try {
|
||||||
|
if (session != null && session.isOpen()) {
|
||||||
|
session.close();
|
||||||
|
log.info("WebSocket connection closed.");
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Failed to close WebSocket session: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket 핸들러 정의
|
||||||
|
private class WebSocketHandler extends TextWebSocketHandler {
|
||||||
|
@Override
|
||||||
|
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||||
|
AIWebSocketClient.this.session = session;
|
||||||
|
log.info("WebSocket connection established.");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
|
||||||
|
log.info("Received message: {}", message.getPayload());
|
||||||
|
// 여기서 메시지를 처리하는 로직을 추가할 수 있습니다.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -50,8 +50,7 @@ public class S3UploadService {
|
|||||||
// String targetUrl = projectId + "/" + title + ".json"; // S3에 업로드할 대상 URL
|
// String targetUrl = projectId + "/" + title + ".json"; // S3에 업로드할 대상 URL
|
||||||
String targetUrl = removeExtension(getKeyFromImageAddress(imageUrl)) + ".json";
|
String targetUrl = removeExtension(getKeyFromImageAddress(imageUrl)) + ".json";
|
||||||
|
|
||||||
log.debug("주소 {}", targetUrl);
|
// log.debug("주소 {}", targetUrl);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
byte[] jsonBytes = json.getBytes(StandardCharsets.UTF_8);
|
byte[] jsonBytes = json.getBytes(StandardCharsets.UTF_8);
|
||||||
ObjectMetadata metadata = new ObjectMetadata();
|
ObjectMetadata metadata = new ObjectMetadata();
|
||||||
@ -65,7 +64,7 @@ public class S3UploadService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
URL uploadedUrl = amazonS3.getUrl(bucket, targetUrl);
|
URL uploadedUrl = amazonS3.getUrl(bucket, targetUrl);
|
||||||
log.debug("Uploaded JSON URL: {}", uploadedUrl);
|
// log.debug("Uploaded JSON URL: {}", uploadedUrl);
|
||||||
return uploadedUrl.toString(); // 업로드된 파일의 URL 반환
|
return uploadedUrl.toString(); // 업로드된 파일의 URL 반환
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("JSON 업로드 중 오류 발생: ", e);
|
log.error("JSON 업로드 중 오류 발생: ", e);
|
||||||
|
Loading…
Reference in New Issue
Block a user