diff --git a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/StompWebSocketConfig.kt b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/StompWebSocketConfig.kt index bc25ff7e1..afd3d78d4 100644 --- a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/StompWebSocketConfig.kt +++ b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/StompWebSocketConfig.kt @@ -1,8 +1,9 @@ package com.seugi.api.domain.chat.presentation.websocket.config import com.seugi.api.domain.chat.application.service.chat.room.ChatRoomService +import com.seugi.api.domain.chat.presentation.websocket.handler.StompErrorHandler +import com.seugi.api.domain.chat.presentation.websocket.util.SecurityUtils import com.seugi.api.domain.member.application.exception.MemberErrorCode -import com.seugi.api.global.auth.jwt.JwtUserDetails import com.seugi.api.global.auth.jwt.JwtUtils import com.seugi.api.global.exception.CustomException import org.springframework.beans.factory.annotation.Value @@ -15,7 +16,6 @@ import org.springframework.messaging.simp.config.ChannelRegistration import org.springframework.messaging.simp.config.MessageBrokerRegistry import org.springframework.messaging.simp.stomp.StompHeaderAccessor import org.springframework.messaging.support.ChannelInterceptor -import org.springframework.messaging.support.MessageBuilder import org.springframework.messaging.support.MessageHeaderAccessor import org.springframework.util.AntPathMatcher import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker @@ -29,11 +29,13 @@ class StompWebSocketConfig( private val jwtUtils: JwtUtils, private val chatRoomService: ChatRoomService, @Value("\${spring.rabbitmq.host}") private val rabbitmqHost: String, + private val stompErrorHandler: StompErrorHandler ) : WebSocketMessageBrokerConfigurer { override fun registerStompEndpoints(registry: StompEndpointRegistry) { registry.addEndpoint("/stomp/chat") .setAllowedOrigins("*") + registry.setErrorHandler(stompErrorHandler) } override fun configureMessageBroker(registry: MessageBrokerRegistry) { @@ -42,6 +44,7 @@ class StompWebSocketConfig( registry.enableStompBrokerRelay("/queue", "/topic", "/exchange", "/amq/queue") .setRelayHost(rabbitmqHost) .setVirtualHost("/") + registry.setUserDestinationPrefix("/user") } override fun configureClientInboundChannel(registration: ChannelRegistration) { @@ -50,9 +53,9 @@ class StompWebSocketConfig( val accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor::class.java)!! when (accessor.messageType) { - SimpMessageType.CONNECT -> handleConnect(message, accessor) + SimpMessageType.CONNECT -> handleConnect(accessor) SimpMessageType.SUBSCRIBE -> handleSubscribe(accessor) - SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT -> handleUnsubscribeOrDisconnect() + SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT -> handleUnsubscribeOrDisconnect(accessor) else -> {} } return message @@ -60,20 +63,13 @@ class StompWebSocketConfig( }) } - private fun handleConnect(message: Message<*>, accessor: StompHeaderAccessor) { + private fun handleConnect(accessor: StompHeaderAccessor) { val authToken = accessor.getNativeHeader("Authorization")?.firstOrNull() if (authToken != null && authToken.startsWith("Bearer ")) { val auth = jwtUtils.getAuthentication(authToken) - val userDetails = auth.principal as? JwtUserDetails - val userId: String? = userDetails?.id?.value?.toString() - - if (userId != null) { - val simpAttributes = SimpAttributesContextHolder.currentAttributes() - simpAttributes.setAttribute("user-id", userId) - MessageBuilder.createMessage(message.payload, accessor.messageHeaders) - } else { - throw CustomException(MemberErrorCode.MEMBER_NOT_FOUND) - } + accessor.user = auth + } else { + throw CustomException(MemberErrorCode.MEMBER_NOT_FOUND) } } @@ -81,25 +77,22 @@ class StompWebSocketConfig( accessor.destination?.let { val simpAttributes = SimpAttributesContextHolder.currentAttributes() simpAttributes.setAttribute("sub", it.substringAfterLast(".")) - val userId = simpAttributes.getAttribute("user-id") as String - chatRoomService.sub( - userId = userId.toLong(), - roomId = it.substringAfterLast(".") - ) + if (it.contains(".")) { + chatRoomService.sub( + userId = SecurityUtils.getUserId(accessor.user), + roomId = it.substringAfterLast(".") + ) + } } } - private fun handleUnsubscribeOrDisconnect() { - val simpAttributes = SimpAttributesContextHolder.currentAttributes() - val userId = simpAttributes.getAttribute("user-id") as String? - val roomId = simpAttributes.getAttribute("sub") as String? - userId?.let { - roomId?.let { - chatRoomService.unSub( - userId = userId.toLong(), - roomId = it - ) - } + private fun handleUnsubscribeOrDisconnect(accessor: StompHeaderAccessor) { + accessor.destination?.let { + val simpAttributes = SimpAttributesContextHolder.currentAttributes() + chatRoomService.unSub( + userId = SecurityUtils.getUserId(accessor.user), + roomId = simpAttributes.getAttribute("sub").toString() + ) } } } \ No newline at end of file diff --git a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/controller/StompRabbitMQController.kt b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/controller/StompRabbitMQController.kt index 83a36cabc..9730f004b 100644 --- a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/controller/StompRabbitMQController.kt +++ b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/controller/StompRabbitMQController.kt @@ -2,21 +2,20 @@ package com.seugi.api.domain.chat.presentation.websocket.controller import com.seugi.api.domain.chat.application.service.message.MessageService import com.seugi.api.domain.chat.presentation.websocket.dto.ChatMessageDto +import com.seugi.api.domain.chat.presentation.websocket.util.SecurityUtils import org.springframework.messaging.handler.annotation.MessageMapping -import org.springframework.messaging.simp.SimpAttributesContextHolder import org.springframework.stereotype.Controller +import java.security.Principal @Controller class StompRabbitMQController( - private val messageService: MessageService + private val messageService: MessageService, ) { @MessageMapping("chat.message") - fun send(chat: ChatMessageDto) { - val simpAttributes = SimpAttributesContextHolder.currentAttributes() - val userId = simpAttributes.getAttribute("user-id") as String? - messageService.sendAndSaveMessage(chat, userId!!.toLong()) + fun send(chat: ChatMessageDto, principal: Principal) { + messageService.sendAndSaveMessage(chat, SecurityUtils.getUserId(principal)) } } \ No newline at end of file diff --git a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/handler/StompErrorHandler.kt b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/handler/StompErrorHandler.kt new file mode 100644 index 000000000..b7bc2750d --- /dev/null +++ b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/handler/StompErrorHandler.kt @@ -0,0 +1,61 @@ +package com.seugi.api.domain.chat.presentation.websocket.handler + +import com.fasterxml.jackson.core.JsonProcessingException +import com.fasterxml.jackson.databind.ObjectMapper +import com.seugi.api.global.response.ErrorResponse +import io.jsonwebtoken.ExpiredJwtException +import io.jsonwebtoken.MalformedJwtException +import io.jsonwebtoken.UnsupportedJwtException +import org.springframework.context.annotation.Configuration +import org.springframework.messaging.Message +import org.springframework.messaging.MessageDeliveryException +import org.springframework.messaging.simp.stomp.StompCommand +import org.springframework.messaging.simp.stomp.StompHeaderAccessor +import org.springframework.messaging.support.MessageBuilder +import org.springframework.security.access.AccessDeniedException +import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler +import java.nio.charset.StandardCharsets +import java.security.SignatureException + +@Configuration +class StompErrorHandler(private val objectMapper: ObjectMapper) : StompSubProtocolErrorHandler() { + + override fun handleClientMessageProcessingError(clientMessage: Message?, ex: Throwable): Message? { + + return when (ex) { + is MessageDeliveryException -> { + when (val cause = ex.cause) { + is AccessDeniedException -> { + sendErrorMessage(ErrorResponse(status = 4403, message = "Access denied")) + } + else -> { + if (isJwtException(cause)) { + sendErrorMessage(ErrorResponse(status = 4403, message = cause?.message ?: "JWT Exception")) + } else { + sendErrorMessage(ErrorResponse(status = 4403, message = cause?.stackTraceToString() ?: "Unhandled exception")) + } + } + } + } + else -> { + sendErrorMessage(ErrorResponse(status = 4400, message = ex.message ?: "Unhandled root exception")) + } + } + } + + private fun isJwtException(ex: Throwable?): Boolean { + return ex is SignatureException || ex is ExpiredJwtException || ex is MalformedJwtException || ex is UnsupportedJwtException || ex is IllegalArgumentException + } + + private fun sendErrorMessage(errorResponse: ErrorResponse): Message { + val headers = StompHeaderAccessor.create(StompCommand.ERROR).apply { + message = errorResponse.message + } + return try { + val json = objectMapper.writeValueAsString(errorResponse) + MessageBuilder.createMessage(json.toByteArray(StandardCharsets.UTF_8), headers.messageHeaders) + } catch (e: JsonProcessingException) { + MessageBuilder.createMessage(errorResponse.message.toByteArray(StandardCharsets.UTF_8), headers.messageHeaders) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/util/SecurityUtils.kt b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/util/SecurityUtils.kt new file mode 100644 index 000000000..ed81b2875 --- /dev/null +++ b/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/util/SecurityUtils.kt @@ -0,0 +1,13 @@ +package com.seugi.api.domain.chat.presentation.websocket.util + +import com.seugi.api.global.auth.jwt.JwtUserDetails +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken +import java.security.Principal + +object SecurityUtils { + + fun getUserId(principal: Principal?): Long { + return (principal as? UsernamePasswordAuthenticationToken)?.principal.let { it as? JwtUserDetails }?.member?.id?.value + ?: -1 + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/RabbitMQConfig.kt b/src/main/kotlin/com/seugi/api/global/config/RabbitMQConfig.kt similarity index 96% rename from src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/RabbitMQConfig.kt rename to src/main/kotlin/com/seugi/api/global/config/RabbitMQConfig.kt index c0add3b53..29708a5f3 100644 --- a/src/main/kotlin/com/seugi/api/domain/chat/presentation/websocket/config/RabbitMQConfig.kt +++ b/src/main/kotlin/com/seugi/api/global/config/RabbitMQConfig.kt @@ -1,4 +1,4 @@ -package com.seugi.api.domain.chat.presentation.websocket.config +package com.seugi.api.global.config import org.springframework.amqp.core.BindingBuilder import org.springframework.amqp.core.Queue diff --git a/src/main/kotlin/com/seugi/api/global/exception/CustomSocketExceptionHandler.kt b/src/main/kotlin/com/seugi/api/global/exception/CustomSocketExceptionHandler.kt new file mode 100644 index 000000000..9a3e215e6 --- /dev/null +++ b/src/main/kotlin/com/seugi/api/global/exception/CustomSocketExceptionHandler.kt @@ -0,0 +1,47 @@ +package com.seugi.api.global.exception + +import com.seugi.api.global.response.ErrorResponse +import org.springframework.messaging.handler.annotation.MessageExceptionHandler +import org.springframework.messaging.Message +import org.springframework.messaging.simp.SimpMessagingTemplate +import org.springframework.messaging.simp.stomp.StompHeaderAccessor +import org.springframework.web.bind.annotation.ControllerAdvice +import java.io.IOException +import java.net.SocketException +import java.security.Principal + +@ControllerAdvice +class CustomSocketExceptionHandler( + private val template: SimpMessagingTemplate +) { + + private val bindingUrl = "/queue/errors" + + @MessageExceptionHandler(SocketException::class) + fun handleSocketException(message: Message<*>, principal: Principal, ex: Exception) { + removeSession(message) + template.convertAndSendToUser( + principal.name, + bindingUrl, + ErrorResponse(status = 4500, message = ex.cause?.message ?: "Socket Error") + ) + } + + @MessageExceptionHandler(RuntimeException::class) + fun handleRuntimeException(principal: Principal, ex: Exception) { + template.convertAndSendToUser( + principal.name, + bindingUrl, + ErrorResponse(status = 4500, message = ex.cause?.stackTraceToString() ?: "Socket Error") + ) + } + + @Throws(IOException::class) + private fun removeSession(message: Message<*>) { + val stompHeaderAccessor = StompHeaderAccessor.wrap(message) + val sessionId = stompHeaderAccessor.sessionId + stompHeaderAccessor.sessionAttributes?.remove(sessionId) + } + + +} \ No newline at end of file diff --git a/src/main/kotlin/com/seugi/api/global/response/BaseResponse.kt b/src/main/kotlin/com/seugi/api/global/response/BaseResponse.kt index ffbe3c4d2..d204b3508 100644 --- a/src/main/kotlin/com/seugi/api/global/response/BaseResponse.kt +++ b/src/main/kotlin/com/seugi/api/global/response/BaseResponse.kt @@ -7,13 +7,13 @@ import org.springframework.http.HttpStatus @JsonInclude(JsonInclude.Include.NON_NULL) data class BaseResponse( - val status: Int = HttpStatus.OK.value(), - val success: Boolean = true, - val state: String? = "OK", - val message: String, - val data: T? = null + override val status: Int = HttpStatus.OK.value(), + override val success: Boolean = true, + override val state: String = "OK", + override val message: String, + val data: T? = null, -) { + ) : ResponseInterface { // errorResponse constructor constructor(code: CustomErrorCode) : this( diff --git a/src/main/kotlin/com/seugi/api/global/response/ErrorResponse.kt b/src/main/kotlin/com/seugi/api/global/response/ErrorResponse.kt new file mode 100644 index 000000000..1dc9942fc --- /dev/null +++ b/src/main/kotlin/com/seugi/api/global/response/ErrorResponse.kt @@ -0,0 +1,11 @@ +package com.seugi.api.global.response + +import com.fasterxml.jackson.annotation.JsonInclude + +@JsonInclude(JsonInclude.Include.NON_NULL) +data class ErrorResponse( + override val status: Int = 4500, + override val success: Boolean = false, + override val state: String = "Error", + override val message: String, +) : ResponseInterface diff --git a/src/main/kotlin/com/seugi/api/global/response/ResponseInterface.kt b/src/main/kotlin/com/seugi/api/global/response/ResponseInterface.kt new file mode 100644 index 000000000..1bb82a47f --- /dev/null +++ b/src/main/kotlin/com/seugi/api/global/response/ResponseInterface.kt @@ -0,0 +1,8 @@ +package com.seugi.api.global.response + +interface ResponseInterface { + val status: Int + val success: Boolean + val state: String + val message: String +} \ No newline at end of file