From ba219fa20801c610ce2de56021cec6a363d3ab3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EA=B9=80=EC=9A=A9=EC=88=98?= Date: Wed, 11 Sep 2024 17:15:06 +0900 Subject: [PATCH] =?UTF-8?q?Feat:=20=EC=9B=B9=EC=86=8C=EC=BC=93=20=EC=84=A4?= =?UTF-8?q?=EC=A0=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/build.gradle | 3 ++ .../project/controller/ProjectController.java | 13 +++++++ .../project/service/ProjectService.java | 9 +++++ .../worlabel/global/config/CorsMvcConfig.java | 2 +- .../worlabel/global/config/RedisConfig.java | 19 ++++++++++ .../global/config/SecurityConfig.java | 23 ++++++------ .../global/config/WebSocketConfig.java | 25 +++++++++++++ .../handler/CustomWebSocketHandler.java | 35 +++++++++++++++++++ .../service/RedisMessageSubscriber.java | 35 +++++++++++++++++++ 9 files changed, 150 insertions(+), 14 deletions(-) create mode 100644 backend/src/main/java/com/worlabel/global/config/WebSocketConfig.java create mode 100644 backend/src/main/java/com/worlabel/global/handler/CustomWebSocketHandler.java create mode 100644 backend/src/main/java/com/worlabel/global/service/RedisMessageSubscriber.java diff --git a/backend/build.gradle b/backend/build.gradle index 0c0b514..14778a9 100644 --- a/backend/build.gradle +++ b/backend/build.gradle @@ -69,6 +69,9 @@ dependencies { testImplementation 'org.junit.jupiter:junit-jupiter:5.7.1' testImplementation 'org.mockito:mockito-core:3.9.0' testImplementation 'org.mockito:mockito-junit-jupiter:3.9.0' + + // WebSocket + implementation 'org.springframework.boot:spring-boot-starter-websocket' } tasks.named('test') { diff --git a/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java b/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java index da3701b..152883d 100644 --- a/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java +++ b/backend/src/main/java/com/worlabel/domain/project/controller/ProjectController.java @@ -74,6 +74,17 @@ public class ProjectController { return SuccessResponse.of(project); } + @Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다..") + @SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.") + @SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR}) + @DeleteMapping("/projects/{project_id}/train") + public BaseResponse trainModel( + @CurrentUser final Integer memberId, + @PathVariable("project_id") final Integer projectId) { + projectService.train(memberId, projectId); + return SuccessResponse.empty(); + } + @Operation(summary = "프로젝트 삭제", description = "프로젝트를 삭제합니다.") @SwaggerApiSuccess(description = "프로젝트를 성공적으로 삭제합니다.") @SwaggerApiError({ErrorCode.PROJECT_NOT_FOUND, ErrorCode.PARTICIPANT_UNAUTHORIZED, ErrorCode.SERVER_ERROR}) @@ -120,4 +131,6 @@ public class ProjectController { return SuccessResponse.empty(); } + + } diff --git a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java index 3efbaa2..43fde3b 100644 --- a/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java +++ b/backend/src/main/java/com/worlabel/domain/project/service/ProjectService.java @@ -114,6 +114,14 @@ public class ProjectService { participantRepository.delete(participant); } + public void train(final Integer memberId,final Integer projectId) { + // 멤버 권한 체크 + + // 레디스 train 테이블에 존재하는지 확인 + + // AI서버와 웹 소켓 연결 + } + private Workspace getWorkspace(final Integer memberId, final Integer workspaceId) { return workspaceRepository.findByMemberIdAndId(memberId, workspaceId) .orElseThrow(() -> new CustomException(ErrorCode.WORKSPACE_NOT_FOUND)); @@ -152,5 +160,6 @@ public class ProjectService { throw new CustomException(ErrorCode.PARTICIPANT_BAD_REQUEST); } } + } diff --git a/backend/src/main/java/com/worlabel/global/config/CorsMvcConfig.java b/backend/src/main/java/com/worlabel/global/config/CorsMvcConfig.java index d138275..add5b16 100644 --- a/backend/src/main/java/com/worlabel/global/config/CorsMvcConfig.java +++ b/backend/src/main/java/com/worlabel/global/config/CorsMvcConfig.java @@ -15,7 +15,7 @@ public class CorsMvcConfig implements WebMvcConfigurer { public void addCorsMappings(CorsRegistry registry) { registry.addMapping("/**") .exposedHeaders("Set-Cookie") - .allowedOrigins(frontend) // application.yml에서 가져온 값 사용 + .allowedOrigins(frontend, "http://localhost:5173") // application.yml에서 가져온 값 사용 .allowCredentials(true); } } diff --git a/backend/src/main/java/com/worlabel/global/config/RedisConfig.java b/backend/src/main/java/com/worlabel/global/config/RedisConfig.java index c546bdc..3c0cde3 100644 --- a/backend/src/main/java/com/worlabel/global/config/RedisConfig.java +++ b/backend/src/main/java/com/worlabel/global/config/RedisConfig.java @@ -6,6 +6,10 @@ import org.springframework.context.annotation.Configuration; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory; import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.listener.ChannelTopic; +import org.springframework.data.redis.listener.PatternTopic; +import org.springframework.data.redis.listener.RedisMessageListenerContainer; +import org.springframework.data.redis.listener.adapter.MessageListenerAdapter; import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; import org.springframework.data.redis.serializer.StringRedisSerializer; @@ -42,4 +46,19 @@ public class RedisConfig { return redisTemplate; } + + @Bean + public RedisMessageListenerContainer redisContainer(RedisConnectionFactory connectionFactory, + MessageListenerAdapter listenerAdapter) { + + RedisMessageListenerContainer container = new RedisMessageListenerContainer(); + container.setConnectionFactory(connectionFactory); + container.addMessageListener(listenerAdapter, new ChannelTopic("/ai/train")); + return container; + } + + @Bean + public MessageListenerAdapter listenerAdapter(MessageListenerAdapter listenerAdapter) { + return new MessageListenerAdapter(listenerAdapter, "onMessage"); + } } diff --git a/backend/src/main/java/com/worlabel/global/config/SecurityConfig.java b/backend/src/main/java/com/worlabel/global/config/SecurityConfig.java index 9c37c6d..5fbeb33 100644 --- a/backend/src/main/java/com/worlabel/global/config/SecurityConfig.java +++ b/backend/src/main/java/com/worlabel/global/config/SecurityConfig.java @@ -48,7 +48,7 @@ public class SecurityConfig { .formLogin((auth) -> auth.disable()); // 세션 설정 비활성화 - http.sessionManagement((session)->session + http.sessionManagement((session) -> session .sessionCreationPolicy(SessionCreationPolicy.STATELESS)); // CORS 설정 @@ -57,16 +57,16 @@ public class SecurityConfig { http .exceptionHandling(configurer -> configurer - .authenticationEntryPoint(authenticationEntryPoint) - .accessDeniedHandler(authenticationDeniedHandler) - ); + .authenticationEntryPoint(authenticationEntryPoint) + .accessDeniedHandler(authenticationDeniedHandler) + ); // 경로별 인가 작업 http - .authorizeHttpRequests(auth->auth + .authorizeHttpRequests(auth -> auth .requestMatchers("/swagger", "/swagger-ui.html", "/swagger-ui/**", "/api-docs", "/api-docs/**", "/v3/api-docs/**").permitAll() - .requestMatchers("/api/auth/reissue").permitAll() - .anyRequest().authenticated() + .requestMatchers("/api/auth/reissue").permitAll() + .anyRequest().authenticated() // .anyRequest().permitAll() ); @@ -74,15 +74,12 @@ public class SecurityConfig { http .oauth2Login(oauth2 -> oauth2 .authorizationEndpoint(authorization -> authorization.baseUri("/api/login/oauth2/authorization")) - .redirectionEndpoint(redirection -> redirection.baseUri("/api/login/oauth2/code/*")) + .redirectionEndpoint(redirection -> redirection.baseUri("/api/login/oauth2/code/*")) .userInfoEndpoint(userInfo -> userInfo.userService(customOAuth2UserService)) .successHandler(oAuth2SuccessHandler) ); - - - // JWT 필터 추가 http .addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class); @@ -93,8 +90,8 @@ public class SecurityConfig { public CorsConfigurationSource corsConfigurationSource() { CorsConfiguration configuration = new CorsConfiguration(); configuration.setAllowCredentials(true); - configuration.setAllowedOrigins(List.of(frontend)); // 프론트엔드 URL 사용 - configuration.setAllowedMethods(List.of("GET","POST","PUT","PATCH","DELETE","OPTIONS")); + configuration.setAllowedOrigins(List.of(frontend, "http://localhost:5173")); // 프론트엔드 URL 사용 + configuration.setAllowedMethods(List.of("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")); configuration.setAllowedHeaders(List.of("*")); configuration.setMaxAge(3600L); diff --git a/backend/src/main/java/com/worlabel/global/config/WebSocketConfig.java b/backend/src/main/java/com/worlabel/global/config/WebSocketConfig.java new file mode 100644 index 0000000..432a474 --- /dev/null +++ b/backend/src/main/java/com/worlabel/global/config/WebSocketConfig.java @@ -0,0 +1,25 @@ +package com.worlabel.global.config; + +import com.worlabel.global.handler.CustomWebSocketHandler; +import lombok.RequiredArgsConstructor; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.socket.config.annotation.EnableWebSocket; +import org.springframework.web.socket.config.annotation.WebSocketConfigurer; +import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; + +@Configuration +@EnableWebSocket +@RequiredArgsConstructor +public class WebSocketConfig implements WebSocketConfigurer { + + private final CustomWebSocketHandler webSocketHandler; + + @Override + public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { + registry + .addHandler(webSocketHandler, "/ws") + .setAllowedOrigins("*"); + } + +} diff --git a/backend/src/main/java/com/worlabel/global/handler/CustomWebSocketHandler.java b/backend/src/main/java/com/worlabel/global/handler/CustomWebSocketHandler.java new file mode 100644 index 0000000..712ab26 --- /dev/null +++ b/backend/src/main/java/com/worlabel/global/handler/CustomWebSocketHandler.java @@ -0,0 +1,35 @@ +package com.worlabel.global.handler; + +import com.worlabel.global.service.RedisMessageSubscriber; +import lombok.RequiredArgsConstructor; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.lang.NonNull; +import org.springframework.stereotype.Component; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.TextWebSocketHandler; + +@Component +@RequiredArgsConstructor +public class CustomWebSocketHandler extends TextWebSocketHandler { + + private final RedisTemplate redisTemplate; + private final RedisMessageSubscriber redisMessageSubscriber; + + @Override + public void afterConnectionEstablished(@NonNull WebSocketSession session) { + redisMessageSubscriber.addSession(session); + } + + @Override + public void afterConnectionClosed(@NonNull WebSocketSession session,@NonNull CloseStatus status) throws Exception { + redisMessageSubscriber.removeSession(session); + } + + @Override + protected void handleTextMessage(@NonNull WebSocketSession session, TextMessage message) { + // Redis 메시지 발행 + redisTemplate.convertAndSend("/ai/train", message.getPayload()); + } +} diff --git a/backend/src/main/java/com/worlabel/global/service/RedisMessageSubscriber.java b/backend/src/main/java/com/worlabel/global/service/RedisMessageSubscriber.java new file mode 100644 index 0000000..79621cd --- /dev/null +++ b/backend/src/main/java/com/worlabel/global/service/RedisMessageSubscriber.java @@ -0,0 +1,35 @@ +package com.worlabel.global.service; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; + +import java.io.IOException; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArraySet; + +@Slf4j +@Service +public class RedisMessageSubscriber { + + private final Set sessions = new CopyOnWriteArraySet<>(); + + public void addSession(WebSocketSession session) { + sessions.add(session); + } + + public void removeSession(WebSocketSession session) { + sessions.remove(session); + } + + public void onMessage(String message) { + for (WebSocketSession session : sessions) { + try { + session.sendMessage(new TextMessage(message)); + } catch (IOException e) { + log.debug("", e); + } + } + } +} \ No newline at end of file