Feat: 웹소켓 설정

This commit is contained in:
김용수 2024-09-11 17:15:06 +09:00
parent c4ba971a53
commit ba219fa208
9 changed files with 150 additions and 14 deletions

View File

@ -69,6 +69,9 @@ dependencies {
testImplementation 'org.junit.jupiter:junit-jupiter:5.7.1'
testImplementation 'org.mockito:mockito-core:3.9.0'
testImplementation 'org.mockito:mockito-junit-jupiter:3.9.0'
// WebSocket
implementation 'org.springframework.boot:spring-boot-starter-websocket'
}
tasks.named('test') {

View File

@ -74,6 +74,17 @@ public class ProjectController {
return SuccessResponse.of(project);
}
@Operation(summary = "프로젝트 모델 학습", description = "프로젝트 모델을 학습시킵니다..")
@SwaggerApiSuccess(description = "프로젝트 모델이 성공적으로 학습됩니다.")
@SwaggerApiError({ErrorCode.EMPTY_REQUEST_PARAMETER, ErrorCode.SERVER_ERROR})
@DeleteMapping("/projects/{project_id}/train")
public BaseResponse<Void> trainModel(
@CurrentUser final Integer memberId,
@PathVariable("project_id") final Integer projectId) {
projectService.train(memberId, projectId);
return SuccessResponse.empty();
}
@Operation(summary = "프로젝트 삭제", description = "프로젝트를 삭제합니다.")
@SwaggerApiSuccess(description = "프로젝트를 성공적으로 삭제합니다.")
@SwaggerApiError({ErrorCode.PROJECT_NOT_FOUND, ErrorCode.PARTICIPANT_UNAUTHORIZED, ErrorCode.SERVER_ERROR})
@ -120,4 +131,6 @@ public class ProjectController {
return SuccessResponse.empty();
}
}

View File

@ -114,6 +114,14 @@ public class ProjectService {
participantRepository.delete(participant);
}
public void train(final Integer memberId,final Integer projectId) {
// 멤버 권한 체크
// 레디스 train 테이블에 존재하는지 확인
// AI서버와 소켓 연결
}
private Workspace getWorkspace(final Integer memberId, final Integer workspaceId) {
return workspaceRepository.findByMemberIdAndId(memberId, workspaceId)
.orElseThrow(() -> new CustomException(ErrorCode.WORKSPACE_NOT_FOUND));
@ -152,5 +160,6 @@ public class ProjectService {
throw new CustomException(ErrorCode.PARTICIPANT_BAD_REQUEST);
}
}
}

View File

@ -15,7 +15,7 @@ public class CorsMvcConfig implements WebMvcConfigurer {
public void addCorsMappings(CorsRegistry registry) {
registry.addMapping("/**")
.exposedHeaders("Set-Cookie")
.allowedOrigins(frontend) // application.yml에서 가져온 사용
.allowedOrigins(frontend, "http://localhost:5173") // application.yml에서 가져온 사용
.allowCredentials(true);
}
}

View File

@ -6,6 +6,10 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.ChannelTopic;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
@ -42,4 +46,19 @@ public class RedisConfig {
return redisTemplate;
}
@Bean
public RedisMessageListenerContainer redisContainer(RedisConnectionFactory connectionFactory,
MessageListenerAdapter listenerAdapter) {
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
container.setConnectionFactory(connectionFactory);
container.addMessageListener(listenerAdapter, new ChannelTopic("/ai/train"));
return container;
}
@Bean
public MessageListenerAdapter listenerAdapter(MessageListenerAdapter listenerAdapter) {
return new MessageListenerAdapter(listenerAdapter, "onMessage");
}
}

View File

@ -48,7 +48,7 @@ public class SecurityConfig {
.formLogin((auth) -> auth.disable());
// 세션 설정 비활성화
http.sessionManagement((session)->session
http.sessionManagement((session) -> session
.sessionCreationPolicy(SessionCreationPolicy.STATELESS));
// CORS 설정
@ -57,16 +57,16 @@ public class SecurityConfig {
http
.exceptionHandling(configurer -> configurer
.authenticationEntryPoint(authenticationEntryPoint)
.accessDeniedHandler(authenticationDeniedHandler)
);
.authenticationEntryPoint(authenticationEntryPoint)
.accessDeniedHandler(authenticationDeniedHandler)
);
// 경로별 인가 작업
http
.authorizeHttpRequests(auth->auth
.authorizeHttpRequests(auth -> auth
.requestMatchers("/swagger", "/swagger-ui.html", "/swagger-ui/**", "/api-docs", "/api-docs/**", "/v3/api-docs/**").permitAll()
.requestMatchers("/api/auth/reissue").permitAll()
.anyRequest().authenticated()
.requestMatchers("/api/auth/reissue").permitAll()
.anyRequest().authenticated()
// .anyRequest().permitAll()
);
@ -74,15 +74,12 @@ public class SecurityConfig {
http
.oauth2Login(oauth2 -> oauth2
.authorizationEndpoint(authorization -> authorization.baseUri("/api/login/oauth2/authorization"))
.redirectionEndpoint(redirection -> redirection.baseUri("/api/login/oauth2/code/*"))
.redirectionEndpoint(redirection -> redirection.baseUri("/api/login/oauth2/code/*"))
.userInfoEndpoint(userInfo -> userInfo.userService(customOAuth2UserService))
.successHandler(oAuth2SuccessHandler)
);
// JWT 필터 추가
http
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);
@ -93,8 +90,8 @@ public class SecurityConfig {
public CorsConfigurationSource corsConfigurationSource() {
CorsConfiguration configuration = new CorsConfiguration();
configuration.setAllowCredentials(true);
configuration.setAllowedOrigins(List.of(frontend)); // 프론트엔드 URL 사용
configuration.setAllowedMethods(List.of("GET","POST","PUT","PATCH","DELETE","OPTIONS"));
configuration.setAllowedOrigins(List.of(frontend, "http://localhost:5173")); // 프론트엔드 URL 사용
configuration.setAllowedMethods(List.of("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"));
configuration.setAllowedHeaders(List.of("*"));
configuration.setMaxAge(3600L);

View File

@ -0,0 +1,25 @@
package com.worlabel.global.config;
import com.worlabel.global.handler.CustomWebSocketHandler;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
@RequiredArgsConstructor
public class WebSocketConfig implements WebSocketConfigurer {
private final CustomWebSocketHandler webSocketHandler;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry
.addHandler(webSocketHandler, "/ws")
.setAllowedOrigins("*");
}
}

View File

@ -0,0 +1,35 @@
package com.worlabel.global.handler;
import com.worlabel.global.service.RedisMessageSubscriber;
import lombok.RequiredArgsConstructor;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.lang.NonNull;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
@Component
@RequiredArgsConstructor
public class CustomWebSocketHandler extends TextWebSocketHandler {
private final RedisTemplate<String , Object> redisTemplate;
private final RedisMessageSubscriber redisMessageSubscriber;
@Override
public void afterConnectionEstablished(@NonNull WebSocketSession session) {
redisMessageSubscriber.addSession(session);
}
@Override
public void afterConnectionClosed(@NonNull WebSocketSession session,@NonNull CloseStatus status) throws Exception {
redisMessageSubscriber.removeSession(session);
}
@Override
protected void handleTextMessage(@NonNull WebSocketSession session, TextMessage message) {
// Redis 메시지 발행
redisTemplate.convertAndSend("/ai/train", message.getPayload());
}
}

View File

@ -0,0 +1,35 @@
package com.worlabel.global.service;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArraySet;
@Slf4j
@Service
public class RedisMessageSubscriber {
private final Set<WebSocketSession> sessions = new CopyOnWriteArraySet<>();
public void addSession(WebSocketSession session) {
sessions.add(session);
}
public void removeSession(WebSocketSession session) {
sessions.remove(session);
}
public void onMessage(String message) {
for (WebSocketSession session : sessions) {
try {
session.sendMessage(new TextMessage(message));
} catch (IOException e) {
log.debug("", e);
}
}
}
}