Skip to content

Commit

Permalink
Merge pull request #345 from apeun-gidaechi/featrue/#333
Browse files Browse the repository at this point in the history
Feature :: 소켓 애러 핸들링 #333
  • Loading branch information
yeseong0412 authored Nov 23, 2024
2 parents 7e39838 + 75ef605 commit 4e2c2c9
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package com.seugi.api.domain.chat.presentation.websocket.config

import com.seugi.api.domain.chat.application.service.chat.room.ChatRoomService
import com.seugi.api.domain.chat.presentation.websocket.handler.StompErrorHandler
import com.seugi.api.domain.chat.presentation.websocket.util.SecurityUtils
import com.seugi.api.domain.member.application.exception.MemberErrorCode
import com.seugi.api.global.auth.jwt.JwtUserDetails
import com.seugi.api.global.auth.jwt.JwtUtils
import com.seugi.api.global.exception.CustomException
import org.springframework.beans.factory.annotation.Value
Expand All @@ -15,7 +16,6 @@ import org.springframework.messaging.simp.config.ChannelRegistration
import org.springframework.messaging.simp.config.MessageBrokerRegistry
import org.springframework.messaging.simp.stomp.StompHeaderAccessor
import org.springframework.messaging.support.ChannelInterceptor
import org.springframework.messaging.support.MessageBuilder
import org.springframework.messaging.support.MessageHeaderAccessor
import org.springframework.util.AntPathMatcher
import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker
Expand All @@ -29,11 +29,13 @@ class StompWebSocketConfig(
private val jwtUtils: JwtUtils,
private val chatRoomService: ChatRoomService,
@Value("\${spring.rabbitmq.host}") private val rabbitmqHost: String,
private val stompErrorHandler: StompErrorHandler
) : WebSocketMessageBrokerConfigurer {

override fun registerStompEndpoints(registry: StompEndpointRegistry) {
registry.addEndpoint("/stomp/chat")
.setAllowedOrigins("*")
registry.setErrorHandler(stompErrorHandler)
}

override fun configureMessageBroker(registry: MessageBrokerRegistry) {
Expand All @@ -42,6 +44,7 @@ class StompWebSocketConfig(
registry.enableStompBrokerRelay("/queue", "/topic", "/exchange", "/amq/queue")
.setRelayHost(rabbitmqHost)
.setVirtualHost("/")
registry.setUserDestinationPrefix("/user")
}

override fun configureClientInboundChannel(registration: ChannelRegistration) {
Expand All @@ -50,56 +53,46 @@ class StompWebSocketConfig(
val accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor::class.java)!!

when (accessor.messageType) {
SimpMessageType.CONNECT -> handleConnect(message, accessor)
SimpMessageType.CONNECT -> handleConnect(accessor)
SimpMessageType.SUBSCRIBE -> handleSubscribe(accessor)
SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT -> handleUnsubscribeOrDisconnect()
SimpMessageType.UNSUBSCRIBE, SimpMessageType.DISCONNECT -> handleUnsubscribeOrDisconnect(accessor)
else -> {}
}
return message
}
})
}

private fun handleConnect(message: Message<*>, accessor: StompHeaderAccessor) {
private fun handleConnect(accessor: StompHeaderAccessor) {
val authToken = accessor.getNativeHeader("Authorization")?.firstOrNull()
if (authToken != null && authToken.startsWith("Bearer ")) {
val auth = jwtUtils.getAuthentication(authToken)
val userDetails = auth.principal as? JwtUserDetails
val userId: String? = userDetails?.id?.value?.toString()

if (userId != null) {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
simpAttributes.setAttribute("user-id", userId)
MessageBuilder.createMessage(message.payload, accessor.messageHeaders)
} else {
throw CustomException(MemberErrorCode.MEMBER_NOT_FOUND)
}
accessor.user = auth
} else {
throw CustomException(MemberErrorCode.MEMBER_NOT_FOUND)
}
}

private fun handleSubscribe(accessor: StompHeaderAccessor) {
accessor.destination?.let {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
simpAttributes.setAttribute("sub", it.substringAfterLast("."))
val userId = simpAttributes.getAttribute("user-id") as String
chatRoomService.sub(
userId = userId.toLong(),
roomId = it.substringAfterLast(".")
)
if (it.contains(".")) {
chatRoomService.sub(
userId = SecurityUtils.getUserId(accessor.user),
roomId = it.substringAfterLast(".")
)
}
}
}

private fun handleUnsubscribeOrDisconnect() {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
val userId = simpAttributes.getAttribute("user-id") as String?
val roomId = simpAttributes.getAttribute("sub") as String?
userId?.let {
roomId?.let {
chatRoomService.unSub(
userId = userId.toLong(),
roomId = it
)
}
private fun handleUnsubscribeOrDisconnect(accessor: StompHeaderAccessor) {
accessor.destination?.let {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
chatRoomService.unSub(
userId = SecurityUtils.getUserId(accessor.user),
roomId = simpAttributes.getAttribute("sub").toString()
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,20 @@ package com.seugi.api.domain.chat.presentation.websocket.controller

import com.seugi.api.domain.chat.application.service.message.MessageService
import com.seugi.api.domain.chat.presentation.websocket.dto.ChatMessageDto
import com.seugi.api.domain.chat.presentation.websocket.util.SecurityUtils
import org.springframework.messaging.handler.annotation.MessageMapping
import org.springframework.messaging.simp.SimpAttributesContextHolder
import org.springframework.stereotype.Controller
import java.security.Principal


@Controller
class StompRabbitMQController(
private val messageService: MessageService
private val messageService: MessageService,
) {

@MessageMapping("chat.message")
fun send(chat: ChatMessageDto) {
val simpAttributes = SimpAttributesContextHolder.currentAttributes()
val userId = simpAttributes.getAttribute("user-id") as String?
messageService.sendAndSaveMessage(chat, userId!!.toLong())
fun send(chat: ChatMessageDto, principal: Principal) {
messageService.sendAndSaveMessage(chat, SecurityUtils.getUserId(principal))
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package com.seugi.api.domain.chat.presentation.websocket.handler

import com.fasterxml.jackson.core.JsonProcessingException
import com.fasterxml.jackson.databind.ObjectMapper
import com.seugi.api.global.response.ErrorResponse
import io.jsonwebtoken.ExpiredJwtException
import io.jsonwebtoken.MalformedJwtException
import io.jsonwebtoken.UnsupportedJwtException
import org.springframework.context.annotation.Configuration
import org.springframework.messaging.Message
import org.springframework.messaging.MessageDeliveryException
import org.springframework.messaging.simp.stomp.StompCommand
import org.springframework.messaging.simp.stomp.StompHeaderAccessor
import org.springframework.messaging.support.MessageBuilder
import org.springframework.security.access.AccessDeniedException
import org.springframework.web.socket.messaging.StompSubProtocolErrorHandler
import java.nio.charset.StandardCharsets
import java.security.SignatureException

@Configuration
class StompErrorHandler(private val objectMapper: ObjectMapper) : StompSubProtocolErrorHandler() {

override fun handleClientMessageProcessingError(clientMessage: Message<ByteArray>?, ex: Throwable): Message<ByteArray>? {

return when (ex) {
is MessageDeliveryException -> {
when (val cause = ex.cause) {
is AccessDeniedException -> {
sendErrorMessage(ErrorResponse(status = 4403, message = "Access denied"))
}
else -> {
if (isJwtException(cause)) {
sendErrorMessage(ErrorResponse(status = 4403, message = cause?.message ?: "JWT Exception"))
} else {
sendErrorMessage(ErrorResponse(status = 4403, message = cause?.stackTraceToString() ?: "Unhandled exception"))
}
}
}
}
else -> {
sendErrorMessage(ErrorResponse(status = 4400, message = ex.message ?: "Unhandled root exception"))
}
}
}

private fun isJwtException(ex: Throwable?): Boolean {
return ex is SignatureException || ex is ExpiredJwtException || ex is MalformedJwtException || ex is UnsupportedJwtException || ex is IllegalArgumentException
}

private fun sendErrorMessage(errorResponse: ErrorResponse): Message<ByteArray> {
val headers = StompHeaderAccessor.create(StompCommand.ERROR).apply {
message = errorResponse.message
}
return try {
val json = objectMapper.writeValueAsString(errorResponse)
MessageBuilder.createMessage(json.toByteArray(StandardCharsets.UTF_8), headers.messageHeaders)
} catch (e: JsonProcessingException) {
MessageBuilder.createMessage(errorResponse.message.toByteArray(StandardCharsets.UTF_8), headers.messageHeaders)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.seugi.api.domain.chat.presentation.websocket.util

import com.seugi.api.global.auth.jwt.JwtUserDetails
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import java.security.Principal

object SecurityUtils {

fun getUserId(principal: Principal?): Long {
return (principal as? UsernamePasswordAuthenticationToken)?.principal.let { it as? JwtUserDetails }?.member?.id?.value
?: -1
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.seugi.api.domain.chat.presentation.websocket.config
package com.seugi.api.global.config

import org.springframework.amqp.core.BindingBuilder
import org.springframework.amqp.core.Queue
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.seugi.api.global.exception

import com.seugi.api.global.response.ErrorResponse
import org.springframework.messaging.handler.annotation.MessageExceptionHandler
import org.springframework.messaging.Message
import org.springframework.messaging.simp.SimpMessagingTemplate
import org.springframework.messaging.simp.stomp.StompHeaderAccessor
import org.springframework.web.bind.annotation.ControllerAdvice
import java.io.IOException
import java.net.SocketException
import java.security.Principal

@ControllerAdvice
class CustomSocketExceptionHandler(
private val template: SimpMessagingTemplate
) {

private val bindingUrl = "/queue/errors"

@MessageExceptionHandler(SocketException::class)
fun handleSocketException(message: Message<*>, principal: Principal, ex: Exception) {
removeSession(message)
template.convertAndSendToUser(
principal.name,
bindingUrl,
ErrorResponse(status = 4500, message = ex.cause?.message ?: "Socket Error")
)
}

@MessageExceptionHandler(RuntimeException::class)
fun handleRuntimeException(principal: Principal, ex: Exception) {
template.convertAndSendToUser(
principal.name,
bindingUrl,
ErrorResponse(status = 4500, message = ex.cause?.stackTraceToString() ?: "Socket Error")
)
}

@Throws(IOException::class)
private fun removeSession(message: Message<*>) {
val stompHeaderAccessor = StompHeaderAccessor.wrap(message)
val sessionId = stompHeaderAccessor.sessionId
stompHeaderAccessor.sessionAttributes?.remove(sessionId)
}


}
12 changes: 6 additions & 6 deletions src/main/kotlin/com/seugi/api/global/response/BaseResponse.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import org.springframework.http.HttpStatus
@JsonInclude(JsonInclude.Include.NON_NULL)
data class BaseResponse<T>(

val status: Int = HttpStatus.OK.value(),
val success: Boolean = true,
val state: String? = "OK",
val message: String,
val data: T? = null
override val status: Int = HttpStatus.OK.value(),
override val success: Boolean = true,
override val state: String = "OK",
override val message: String,
val data: T? = null,

) {
) : ResponseInterface {

// errorResponse constructor
constructor(code: CustomErrorCode) : this(
Expand Down
11 changes: 11 additions & 0 deletions src/main/kotlin/com/seugi/api/global/response/ErrorResponse.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.seugi.api.global.response

import com.fasterxml.jackson.annotation.JsonInclude

@JsonInclude(JsonInclude.Include.NON_NULL)
data class ErrorResponse(
override val status: Int = 4500,
override val success: Boolean = false,
override val state: String = "Error",
override val message: String,
) : ResponseInterface
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.seugi.api.global.response

interface ResponseInterface {
val status: Int
val success: Boolean
val state: String
val message: String
}

0 comments on commit 4e2c2c9

Please sign in to comment.