Skip to content

Commit

Permalink
Merge pull request #139 from kitakkun/feature/session_negotiation
Browse files Browse the repository at this point in the history
[Ktor WebSocket] SessionId Negotiation
  • Loading branch information
kitakkun authored Dec 29, 2024
2 parents 0167ddc + 6e7caba commit ba4597c
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import com.kitakkun.backintime.core.runtime.connector.BackInTimeWebSocketConnect
import com.kitakkun.backintime.core.runtime.event.BackInTimeDebuggableInstanceEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeDebugServiceEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeDebuggerEvent
import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancelChildren
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import kotlinx.datetime.Clock
import kotlinx.serialization.SerializationException

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package com.kitakkun.backintime.core.runtime.connector

import com.kitakkun.backintime.core.websocket.client.BackInTimeWebSocketClient
import com.kitakkun.backintime.core.websocket.client.BackInTimeWebSocketClientEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeDebugServiceEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeDebuggerEvent
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.flow

/**
* This class is responsible for sending and receiving events
Expand All @@ -22,12 +19,7 @@ class BackInTimeKtorWebSocketConnector(

override suspend fun connect(): Flow<BackInTimeDebuggerEvent> {
client.openSession()

return flow {
client.clientEventFlow.filterIsInstance<BackInTimeWebSocketClientEvent.ReceiveDebuggerEvent>().collect {
emit(it.debuggerEvent)
}
}
return client.receivedDebuggerEventFlow
}

override suspend fun sendEventToDebugger(event: BackInTimeDebugServiceEvent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.kitakkun.backintime.core.websocket.client

import com.kitakkun.backintime.core.websocket.event.BackInTimeDebugServiceEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeDebuggerEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeSessionNegotiationEvent
import io.ktor.client.HttpClient
import io.ktor.client.engine.HttpClientEngine
import io.ktor.client.engine.cio.CIO
Expand All @@ -19,12 +20,8 @@ import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.serialization.json.Json

sealed interface BackInTimeWebSocketClientEvent {
data class ReceiveDebuggerEvent(val debuggerEvent: BackInTimeDebuggerEvent) : BackInTimeWebSocketClientEvent
data object CloseSuccessfully : BackInTimeWebSocketClientEvent
data class CloseWithError(val error: Throwable) : BackInTimeWebSocketClientEvent
}
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine

/**
* This class is responsible for connecting to the back-in-time debugger server
Expand All @@ -35,10 +32,12 @@ class BackInTimeWebSocketClient(
engine: HttpClientEngine = CIO.create(),
client: HttpClient? = null,
) {
private var sessionId: String? = null

private var session: DefaultClientWebSocketSession? = null

private val mutableClientEventFlow = MutableSharedFlow<BackInTimeWebSocketClientEvent>()
val clientEventFlow = mutableClientEventFlow.asSharedFlow()
private val mutableReceivedDebuggerEventFlow = MutableSharedFlow<BackInTimeDebuggerEvent>()
val receivedDebuggerEventFlow = mutableReceivedDebuggerEventFlow.asSharedFlow()

private val eventDispatchQueueMutex = Mutex()
private val eventDispatchQueue = mutableListOf<BackInTimeDebugServiceEvent>()
Expand All @@ -50,40 +49,50 @@ class BackInTimeWebSocketClient(
}
}

private suspend fun DefaultClientWebSocketSession.handleSession() {
val receiveJob = launch {
while (true) {
val debuggerEvent = receiveDeserialized<BackInTimeDebuggerEvent>()
mutableClientEventFlow.emit(BackInTimeWebSocketClientEvent.ReceiveDebuggerEvent(debuggerEvent))
}
}
private suspend fun DefaultClientWebSocketSession.setupSessionHandling() {
suspendCoroutine {
launch {
// sessionId negotiation
sendSerialized(BackInTimeSessionNegotiationEvent.Request(sessionId))
sessionId = receiveDeserialized<BackInTimeSessionNegotiationEvent.Accept>().sessionId

clientLog("sessionId: $sessionId")

val sendJob = launch {
while (true) {
delay(500)
eventDispatchQueueMutex.withLock {
eventDispatchQueue.forEach { sendSerialized(it) }
eventDispatchQueue.clear()
}
}
}

val sendJob = launch {
while (true) {
delay(500)
eventDispatchQueueMutex.withLock {
eventDispatchQueue.forEach { sendSerialized(it) }
eventDispatchQueue.clear()
val receiveJob = launch {
while (true) {
val debuggerEvent = receiveDeserialized<BackInTimeDebuggerEvent>()
mutableReceivedDebuggerEventFlow.emit(debuggerEvent)
}
}
}
}

closeReason.invokeOnCompletion { error ->
session = null
closeReason.invokeOnCompletion { error ->
if (error == null) {
clientLog("session closed successfully")
} else {
clientLog("session closed with error: $error")
}

receiveJob.cancel()
sendJob.cancel()
session = null

val event = error?.let {
BackInTimeWebSocketClientEvent.CloseWithError(it)
} ?: BackInTimeWebSocketClientEvent.CloseSuccessfully
sendJob.cancel()
receiveJob.cancel()
}

launch {
mutableClientEventFlow.emit(event)
println("client is ready")
it.resume(Unit) // setup completed: meaning that client is ready.
closeReason.await()
}
}

closeReason.await()
}

suspend fun awaitClose() {
Expand All @@ -95,10 +104,8 @@ class BackInTimeWebSocketClient(
host = host,
port = port,
path = "/backintime",
)

session?.launch {
session?.handleSession()
).also {
it.setupSessionHandling()
}
}

Expand All @@ -110,4 +117,8 @@ class BackInTimeWebSocketClient(
session?.close()
session = null
}

private fun clientLog(message: String) {
println("[${this::class.simpleName}] $message")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.kitakkun.backintime.core.websocket.client

import com.kitakkun.backintime.core.websocket.event.BackInTimeDebugServiceEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeDebuggerEvent
import com.kitakkun.backintime.core.websocket.event.BackInTimeSessionNegotiationEvent
import io.ktor.serialization.kotlinx.KotlinxWebsocketSerializationConverter
import io.ktor.server.application.install
import io.ktor.server.engine.connector
Expand All @@ -13,15 +14,15 @@ import io.ktor.server.websocket.receiveDeserialized
import io.ktor.server.websocket.sendSerialized
import io.ktor.server.websocket.webSocket
import io.ktor.websocket.close
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.filterIsInstance
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.Json
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid

// FIXME: This test fails for native targets
class BackInTimeWebSocketClientTest {
Expand All @@ -30,6 +31,7 @@ class BackInTimeWebSocketClientTest {
private const val TEST_PORT = 50026
}

@OptIn(ExperimentalUuidApi::class)
private fun ApplicationTestBuilder.configureServer(
host: String,
port: Int,
Expand All @@ -49,7 +51,32 @@ class BackInTimeWebSocketClientTest {
routing {
webSocket(
path = "/backintime",
handler = serverSession,
handler = {
println("New websocket session established!")
println("waiting sessionId negotiation request from the client...")

val requestedSessionId = receiveDeserialized<BackInTimeSessionNegotiationEvent.Request>().sessionId
println("starting sessionId negotiation...")

if (requestedSessionId == null) {
println("requested sessionId is null. generating new sessionId...")

val sessionId = Uuid.random().toString()
println("generated new sessionId: $sessionId")

sendSerialized(BackInTimeSessionNegotiationEvent.Accept(sessionId))
} else {
sendSerialized(BackInTimeSessionNegotiationEvent.Accept(requestedSessionId))
}

println("sessionId negotiation completed!")
println("start server session...")

serverSession()

println("keeping session active...")
this.closeReason.await()
}
)
}
}
Expand All @@ -64,7 +91,6 @@ class BackInTimeWebSocketClientTest {

assertFailsWith(Throwable::class) {
client.openSession()
client.awaitClose()
}
}

Expand All @@ -73,7 +99,7 @@ class BackInTimeWebSocketClientTest {
configureServer(
host = TEST_HOST,
port = TEST_PORT,
serverSession = { /* Do nothing */ },
serverSession = { close() },
)

val client = BackInTimeWebSocketClient(
Expand All @@ -83,7 +109,7 @@ class BackInTimeWebSocketClientTest {
)

client.openSession()
client.awaitClose()
client.close()
}

@Test
Expand All @@ -94,7 +120,7 @@ class BackInTimeWebSocketClientTest {
host = TEST_HOST,
port = TEST_PORT,
serverSession = {
serverReceivedEvent = receiveDeserialized<BackInTimeDebugServiceEvent>()
serverReceivedEvent = receiveDeserialized()
close()
},
)
Expand All @@ -115,14 +141,12 @@ class BackInTimeWebSocketClientTest {

@Test
fun `test success to receive event`() = testApplication {
var clientReceivedEvent: BackInTimeDebuggerEvent? = null

configureServer(
host = TEST_HOST,
port = TEST_PORT,
serverSession = {
sendSerialized<BackInTimeDebuggerEvent>(BackInTimeDebuggerEvent.Ping)
delay(100) // need this to pass the test
close()
},
serverSession = { sendSerialized<BackInTimeDebuggerEvent>(BackInTimeDebuggerEvent.Ping) },
)

val client = BackInTimeWebSocketClient(
Expand All @@ -133,13 +157,14 @@ class BackInTimeWebSocketClientTest {

runTest {
launch {
assertEquals(
expected = BackInTimeDebuggerEvent.Ping,
actual = client.clientEventFlow.filterIsInstance<BackInTimeWebSocketClientEvent.ReceiveDebuggerEvent>().first().debuggerEvent
)
clientReceivedEvent = client.receivedDebuggerEventFlow.first()
client.close()
}

client.openSession()
client.awaitClose()

assertEquals(expected = BackInTimeDebuggerEvent.Ping, actual = clientReceivedEvent)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.kitakkun.backintime.core.websocket.event

import kotlinx.serialization.Serializable

sealed interface BackInTimeSessionNegotiationEvent {
@Serializable
data class Request(val sessionId: String?) : BackInTimeSessionNegotiationEvent

@Serializable
data class Accept(val sessionId: String) : BackInTimeSessionNegotiationEvent
}
Loading

0 comments on commit ba4597c

Please sign in to comment.