diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ab19c0aa..21085930 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -8,6 +8,12 @@ kotlinx-bcv = "0.14.0" ktor = "2.3.11" +netty = "4.1.110.Final" +netty-quic = "0.0.63.Final" + +# for netty TLS tests +bouncycastle = "1.78.1" + turbine = "1.1.0" rsocket-java = "1.1.3" @@ -39,6 +45,12 @@ ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor" } ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor" } ktor-server-jetty = { module = "io.ktor:ktor-server-jetty", version.ref = "ktor" } +netty-handler = { module = "io.netty:netty-handler", version.ref = "netty" } +netty-codec-http = { module = "io.netty:netty-codec-http", version.ref = "netty" } +netty-codec-quic = { module = "io.netty.incubator:netty-incubator-codec-native-quic", version.ref = "netty-quic" } + +bouncycastle = { module = "org.bouncycastle:bcpkix-jdk18on", version.ref = "bouncycastle" } + turbine = { module = "app.cash.turbine:turbine", version.ref = "turbine" } rsocket-java-core = { module = 'io.rsocket:rsocket-core', version.ref = "rsocket-java" } diff --git a/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api b/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api new file mode 100644 index 00000000..66fea981 --- /dev/null +++ b/rsocket-transports/netty-internal/api/rsocket-transport-netty-internal.api @@ -0,0 +1,8 @@ +public final class io/rsocket/kotlin/transport/netty/internal/CoroutinesKt { + public static final fun awaitChannel (Lio/netty/channel/ChannelFuture;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun awaitFuture (Lio/netty/util/concurrent/Future;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun callOnCancellation (Lkotlinx/coroutines/CoroutineScope;Lkotlin/jvm/functions/Function1;)V + public static final fun toByteBuf (Lio/ktor/utils/io/core/ByteReadPacket;)Lio/netty/buffer/ByteBuf; + public static final fun toByteReadPacket (Lio/netty/buffer/ByteBuf;)Lio/ktor/utils/io/core/ByteReadPacket; +} + diff --git a/rsocket-transports/netty-internal/build.gradle.kts b/rsocket-transports/netty-internal/build.gradle.kts new file mode 100644 index 00000000..42d9382e --- /dev/null +++ b/rsocket-transports/netty-internal/build.gradle.kts @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.multiplatform-library") +} + +description = "rsocket-kotlin Netty transport utils" + +kotlin { + jvmTarget() + + sourceSets { + jvmMain.dependencies { + implementation(projects.rsocketInternalIo) + api(projects.rsocketCore) + api(libs.netty.handler) + } + } +} diff --git a/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt b/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt new file mode 100644 index 00000000..25be5e32 --- /dev/null +++ b/rsocket-transports/netty-internal/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/internal/coroutines.kt @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.internal + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.util.concurrent.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@Suppress("UNCHECKED_CAST") +public suspend inline fun Future.awaitFuture(): T = suspendCancellableCoroutine { cont -> + addListener { + when { + it.isSuccess -> cont.resume(it.now as T) + else -> cont.resumeWithException(it.cause()) + } + } + cont.invokeOnCancellation { + cancel(true) + } +} + +public suspend fun ChannelFuture.awaitChannel(): Channel { + awaitFuture() + return channel() +} + +// it should be used only for cleanup and so should not really block, only suspend +public inline fun CoroutineScope.callOnCancellation(crossinline block: suspend () -> Unit) { + launch(Dispatchers.Unconfined) { + try { + awaitCancellation() + } catch (cause: Throwable) { + withContext(NonCancellable) { + try { + block() + } catch (suppressed: Throwable) { + cause.addSuppressed(suppressed) + } + } + throw cause + } + } +} + +// TODO: what to use: this or ByteReadPacket(msg.nioBuffer()) +public fun ByteBuf.toByteReadPacket(): ByteReadPacket = buildPacket { writeFully(nioBuffer()) } +public fun ByteReadPacket.toByteBuf(): ByteBuf = Unpooled.wrappedBuffer(readByteBuffer()) diff --git a/rsocket-transports/netty-quic/api/rsocket-transport-netty-quic.api b/rsocket-transports/netty-quic/api/rsocket-transport-netty-quic.api new file mode 100644 index 00000000..bc42f673 --- /dev/null +++ b/rsocket-transports/netty-quic/api/rsocket-transport-netty-quic.api @@ -0,0 +1,43 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport$Factory; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/net/InetSocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun codec (Lkotlin/jvm/functions/Function1;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun quicBootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/quic/NettyQuicServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Ljava/net/InetSocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport$Factory; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/net/InetSocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport;Ljava/net/InetSocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun codec (Lkotlin/jvm/functions/Function1;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + diff --git a/rsocket-transports/netty-quic/build.gradle.kts b/rsocket-transports/netty-quic/build.gradle.kts new file mode 100644 index 00000000..5696ae9d --- /dev/null +++ b/rsocket-transports/netty-quic/build.gradle.kts @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.multiplatform-library") +} + +description = "rsocket-kotlin Netty QUIC client/server transport implementation" + +kotlin { + jvmTarget() + + sourceSets { + jvmMain.dependencies { + implementation(projects.rsocketTransportNettyInternal) + implementation(projects.rsocketInternalIo) + api(projects.rsocketCore) + api(libs.netty.handler) + api(libs.netty.codec.quic) + } + jvmTest.dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + implementation(libs.netty.codec.quic.map { + val javaOsName = System.getProperty("os.name") + val javaOsArch = System.getProperty("os.arch") + val suffix = when { + javaOsName.contains("mac", ignoreCase = true) -> "osx" + javaOsName.contains("linux", ignoreCase = true) -> "linux" + javaOsName.contains("windows", ignoreCase = true) -> "windows" + else -> error("Unknown os.name: $javaOsName") + } + "-" + when (javaOsArch) { + "x86_64", "amd64" -> "x86_64" + "arm64", "aarch64" -> "aarch_64" + else -> error("Unknown os.arch: $javaOsArch") + } + "$it:$suffix" + }) + //implementation("ch.qos.logback:logback-classic:1.2.11") + } + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt new file mode 100644 index 00000000..f6c2cef3 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicClientTransport.kt @@ -0,0 +1,159 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyQuicClientTransport : RSocketTransport { + public fun target(remoteAddress: InetSocketAddress): RSocketClientTarget + public fun target(host: String, port: Int): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::NettyQuicClientTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyQuicClientTransportBuilder : RSocketTransportBuilder { + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun codec(block: QuicClientCodecBuilder.() -> Unit) + public fun ssl(block: QuicSslContextBuilder.() -> Unit) + public fun quicBootstrap(block: QuicChannelBootstrap.() -> Unit) +} + +private class NettyQuicClientTransportBuilderImpl : NettyQuicClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var codec: (QuicClientCodecBuilder.() -> Unit)? = null + private var ssl: (QuicSslContextBuilder.() -> Unit)? = null + private var quicBootstrap: (QuicChannelBootstrap.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun codec(block: QuicClientCodecBuilder.() -> Unit) { + codec = block + } + + override fun ssl(block: QuicSslContextBuilder.() -> Unit) { + ssl = block + } + + override fun quicBootstrap(block: QuicChannelBootstrap.() -> Unit) { + quicBootstrap = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyQuicClientTransport { + val codecHandler = QuicClientCodecBuilder().apply { + // by default, we allow Int.MAX_VALUE of active stream + initialMaxData(Int.MAX_VALUE.toLong()) + initialMaxStreamDataBidirectionalLocal(Int.MAX_VALUE.toLong()) + initialMaxStreamDataBidirectionalRemote(Int.MAX_VALUE.toLong()) + initialMaxStreamsBidirectional(Int.MAX_VALUE.toLong()) + codec?.invoke(this) + ssl?.let { + sslContext(QuicSslContextBuilder.forClient().apply(it).build()) + } + }.build() + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + localAddress(0) + handler(codecHandler) + channelFactory(channelFactory ?: ReflectiveChannelFactory(NioDatagramChannel::class.java)) + group(eventLoopGroup ?: NioEventLoopGroup()) + } + + return NettyQuicClientTransportImpl( + coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), + bootstrap = bootstrap, + quicBootstrap = quicBootstrap, + manageBootstrap = manageEventLoopGroup + ) + } +} + +private class NettyQuicClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: Bootstrap, + private val quicBootstrap: (QuicChannelBootstrap.() -> Unit)?, + manageBootstrap: Boolean, +) : NettyQuicClientTransport { + init { + if (manageBootstrap) callOnCancellation { + bootstrap.config().group().shutdownGracefully().awaitFuture() + } + } + + override fun target(remoteAddress: InetSocketAddress): NettyQuicClientTargetImpl = NettyQuicClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + bootstrap = bootstrap, + quicBootstrap = quicBootstrap, + remoteAddress = remoteAddress + ) + + override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port)) +} + +@OptIn(RSocketTransportApi::class) +private class NettyQuicClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: Bootstrap, + private val quicBootstrap: (QuicChannelBootstrap.() -> Unit)?, + private val remoteAddress: SocketAddress, +) : RSocketClientTarget { + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + QuicChannel.newBootstrap(bootstrap.bind().awaitChannel()).also { quicBootstrap?.invoke(it) } + .handler( + NettyQuicConnectionInitializer(handler, coroutineContext, isClient = true) + ).remoteAddress(remoteAddress).connect().awaitFuture() + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionHandler.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionHandler.kt new file mode 100644 index 00000000..b13ed194 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionHandler.kt @@ -0,0 +1,134 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel +import java.util.concurrent.atomic.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyQuicConnectionHandler( + private val channel: QuicChannel, + private val handler: RSocketConnectionHandler, + scope: CoroutineScope, + private val isClient: Boolean, +) : ChannelInboundHandlerAdapter() { + private val inbound = Channel(Channel.UNLIMITED) + + private val connectionJob = Job(scope.coroutineContext.job) + private val streamsContext = scope.coroutineContext + SupervisorJob(connectionJob) + + private val handlerJob = scope.launch(connectionJob, start = CoroutineStart.LAZY) { + try { + handler.handleConnection(NettyQuicConnection(channel, inbound, streamsContext, isClient)) + } finally { + inbound.cancel() + withContext(NonCancellable) { + streamsContext.job.cancelAndJoin() + channel.close().awaitFuture() + } + } + } + + override fun channelActive(ctx: ChannelHandlerContext) { + handlerJob.start() + connectionJob.complete() + ctx.pipeline().addLast("rsocket-inbound", NettyQuicConnectionInboundHandler(inbound, streamsContext, isClient)) + + ctx.fireChannelActive() + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + handlerJob.cancel("Channel is not active") + + ctx.fireChannelInactive() + } + + @Suppress("OVERRIDE_DEPRECATION") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + handlerJob.cancel("exceptionCaught", cause) + } +} + +// TODO: implement support for isAutoRead=false to support `inbound` backpressure +@RSocketTransportApi +private class NettyQuicConnectionInboundHandler( + private val inbound: SendChannel, + private val streamsContext: CoroutineContext, + private val isClient: Boolean, +) : ChannelInboundHandlerAdapter() { + // Note: QUIC streams could be received unordered, so f.e we could receive first stream with id 4 and then with id 0 + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + msg as QuicStreamChannel + val state = NettyQuicStreamState(null) + if (inbound.trySend(state.wrapStream(msg)).isSuccess) { + msg.pipeline().addLast(NettyQuicStreamInitializer(streamsContext, state, isClient)) + } + ctx.fireChannelRead(msg) + } + + override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { + if (evt is ChannelInputShutdownEvent) { + inbound.close() + } + super.userEventTriggered(ctx, evt) + } +} + +@RSocketTransportApi +private class NettyQuicConnection( + private val channel: QuicChannel, + private val inbound: ReceiveChannel, + private val streamsContext: CoroutineContext, + private val isClient: Boolean, +) : RSocketMultiplexedConnection { + private val startMarker = Job() + + // we need to `hack` only first stream created for client - stream where frames with streamId=0 will be sent + private val first = AtomicBoolean(isClient) + override suspend fun createStream(): RSocketMultiplexedConnection.Stream { + val startMarker = if (first.getAndSet(false)) { + startMarker + } else { + startMarker.join() + null + } + val state = NettyQuicStreamState(startMarker) + val stream = try { + channel.createStream( + QuicStreamType.BIDIRECTIONAL, + NettyQuicStreamInitializer(streamsContext, state, isClient) + ).awaitFuture() + } catch (cause: Throwable) { + state.closeMarker.complete() + throw cause + } + + return state.wrapStream(stream) + } + + override suspend fun acceptStream(): RSocketMultiplexedConnection.Stream? { + return inbound.receiveCatching().getOrNull() + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionInitializer.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionInitializer.kt new file mode 100644 index 00000000..5caf07f6 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicConnectionInitializer.kt @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.channel.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyQuicConnectionInitializer( + private val handler: RSocketConnectionHandler, + override val coroutineContext: CoroutineContext, + private val isClient: Boolean, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(channel: QuicChannel) { + with(channel.pipeline()) { + //addLast(LoggingHandler(if (isClient) "CLIENT" else "SERVER")) + addLast("rsocket", NettyQuicConnectionHandler(channel, handler, this@NettyQuicConnectionInitializer, isClient)) + } + } +} \ No newline at end of file diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt new file mode 100644 index 00000000..4c01e135 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicServerTransport.kt @@ -0,0 +1,180 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyQuicServerInstance : RSocketServerInstance { + public val localAddress: InetSocketAddress +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyQuicServerTransport : RSocketTransport { + public fun target(localAddress: InetSocketAddress? = null): RSocketServerTarget + public fun target(host: String = "127.0.0.1", port: Int = 0): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::NettyQuicServerTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyQuicServerTransportBuilder : RSocketTransportBuilder { + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun codec(block: QuicServerCodecBuilder.() -> Unit) + public fun ssl(block: QuicSslContextBuilder.() -> Unit) +} + +private class NettyQuicServerTransportBuilderImpl : NettyQuicServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var codec: (QuicServerCodecBuilder.() -> Unit)? = null + private var ssl: (QuicSslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun codec(block: QuicServerCodecBuilder.() -> Unit) { + codec = block + } + + override fun ssl(block: QuicSslContextBuilder.() -> Unit) { + ssl = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyQuicServerTransport { + val codecBuilder = QuicServerCodecBuilder().apply { + // by default, we allow Int.MAX_VALUE of active stream + initialMaxData(Int.MAX_VALUE.toLong()) + initialMaxStreamDataBidirectionalLocal(Int.MAX_VALUE.toLong()) + initialMaxStreamDataBidirectionalRemote(Int.MAX_VALUE.toLong()) + initialMaxStreamsBidirectional(Int.MAX_VALUE.toLong()) + codec?.invoke(this) + ssl?.let { + val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) + sslContext(QuicSslContextBuilder.forServer(keyManagerFactory, null).apply(it).build()) + } + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + channelFactory(channelFactory ?: ReflectiveChannelFactory(NioDatagramChannel::class.java)) + group(eventLoopGroup ?: NioEventLoopGroup()) + } + + return NettyQuicServerTransportImpl( + coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), + bootstrap = bootstrap, + codecBuilder = codecBuilder, + manageBootstrap = manageEventLoopGroup + ) + } +} + +private class NettyQuicServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: Bootstrap, + private val codecBuilder: QuicServerCodecBuilder, + manageBootstrap: Boolean, +) : NettyQuicServerTransport { + init { + if (manageBootstrap) callOnCancellation { + bootstrap.config().group().shutdownGracefully().awaitFuture() + } + } + + override fun target(localAddress: InetSocketAddress?): NettyQuicServerTargetImpl = NettyQuicServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + bootstrap = bootstrap, + codecBuilder = codecBuilder, + localAddress = localAddress ?: InetSocketAddress(0) + ) + + override fun target(host: String, port: Int): RSocketServerTarget = + target(InetSocketAddress(host, port)) +} + +@OptIn(RSocketTransportApi::class) +private class NettyQuicServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: Bootstrap, + private val codecBuilder: QuicServerCodecBuilder, + private val localAddress: SocketAddress, +) : RSocketServerTarget { + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): NettyQuicServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val instanceContext = coroutineContext.childContext() + val channel = try { + bootstrap.clone().handler( + codecBuilder.clone().handler( + NettyQuicConnectionInitializer(handler, instanceContext.supervisorContext(), isClient = false) + ).build() + ).bind(localAddress).awaitChannel() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + return NettyQuicServerInstanceImpl( + coroutineContext = instanceContext, + localAddress = (channel as DatagramChannel).localAddress() as InetSocketAddress + ) + } +} + +private class NettyQuicServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: InetSocketAddress, +) : NettyQuicServerInstance diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamHandler.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamHandler.kt new file mode 100644 index 00000000..e624c0b0 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamHandler.kt @@ -0,0 +1,162 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel + +// TODO: first stream is a hack to initiate first stream because of buffering +// quic streams could be received unordered by server, so f.e we could receive first stream with id 4 and then with id 0 +// for this, we disable buffering for first client stream, so that first frame will be sent first +// this will affect performance for this stream, so we need to do something else here. +@RSocketTransportApi +internal class NettyQuicStreamState(val startMarker: CompletableJob?) { + val closeMarker: CompletableJob = Job() + val outbound = channelForCloseable(Channel.BUFFERED) + val inbound = channelForCloseable(Channel.UNLIMITED) + + fun wrapStream(stream: QuicStreamChannel): RSocketMultiplexedConnection.Stream = + NettyQuicStream(stream, outbound, inbound, closeMarker) +} + +@RSocketTransportApi +internal class NettyQuicStreamHandler( + private val channel: QuicStreamChannel, + scope: CoroutineScope, + private val state: NettyQuicStreamState, + private val isClient: Boolean, +) : ChannelInboundHandlerAdapter() { + private val handlerJob = scope.launch(start = CoroutineStart.LAZY) { + val outbound = state.outbound + + val writerJob = launch(start = CoroutineStart.UNDISPATCHED) { + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + // TODO: could be optimized to avoid allocation of not-needed promises + + var lastWriteFuture = channel.write(outbound.receiveCatching().getOrNull()?.toByteBuf() ?: break) + while (true) lastWriteFuture = channel.write(outbound.tryReceive().getOrNull()?.toByteBuf() ?: break) + //println("FLUSH: $isClient: ${channel.streamId()}") + channel.flush() + // await writing to respect transport backpressure + lastWriteFuture.awaitFuture() + state.startMarker?.complete() + } + } finally { + withContext(NonCancellable) { + channel.shutdownOutput().awaitFuture() + } + } + }.onCompletion { outbound.cancel() } + + try { + state.closeMarker.join() + } finally { + outbound.close() // will cause `writerJob` completion + // no more reading + state.inbound.cancel() + withContext(NonCancellable) { + writerJob.join() + // TODO: what is the correct way to properly shutdown stream? + channel.shutdownInput().awaitFuture() + channel.close().awaitFuture() + } + } + } + + override fun channelActive(ctx: ChannelHandlerContext) { + handlerJob.start() + ctx.pipeline().addLast("rsocket-inbound", NettyQuicStreamInboundHandler(state.inbound)) + + ctx.fireChannelActive() + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + handlerJob.cancel("Channel is not active") + + ctx.fireChannelInactive() + } + + @Suppress("OVERRIDE_DEPRECATION") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + handlerJob.cancel("exceptionCaught", cause) + } +} + +private class NettyQuicStreamInboundHandler( + private val inbound: SendChannel, +) : ChannelInboundHandlerAdapter() { + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + msg as ByteBuf + try { + val frame = msg.toByteReadPacket() + if (inbound.trySend(frame).isFailure) { + frame.close() + } + } finally { + msg.release() + } + } + + override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { + if (evt is ChannelInputShutdownEvent) { + inbound.close() + } + super.userEventTriggered(ctx, evt) + } +} + +@RSocketTransportApi +private class NettyQuicStream( + // for priority + private val stream: QuicStreamChannel, + private val outbound: SendChannel, + private val inbound: ReceiveChannel, + private val closeMarker: CompletableJob, +) : RSocketMultiplexedConnection.Stream { + + @OptIn(DelicateCoroutinesApi::class) + override val isClosedForSend: Boolean get() = outbound.isClosedForSend + + override fun setSendPriority(priority: Int) { + stream.updatePriority(QuicStreamPriority(priority, false)) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + outbound.send(frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } + + override fun close() { + closeMarker.complete() + } +} diff --git a/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamInitializer.kt b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamInitializer.kt new file mode 100644 index 00000000..68c2f7a6 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicStreamInitializer.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.channel.* +import io.netty.handler.codec.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyQuicStreamInitializer( + override val coroutineContext: CoroutineContext, + private val state: NettyQuicStreamState, + private val isClient: Boolean, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(channel: QuicStreamChannel): Unit = with(channel.pipeline()) { + addLast( + "rsocket-length-encoder", + LengthFieldPrepender( + /* lengthFieldLength = */ 3 + ) + ) + addLast( + "rsocket-length-decoder", + LengthFieldBasedFrameDecoder( + /* maxFrameLength = */ Int.MAX_VALUE, + /* lengthFieldOffset = */ 0, + /* lengthFieldLength = */ 3, + /* lengthAdjustment = */ 0, + /* initialBytesToStrip = */ 3 + ) + ) + addLast("rsocket", NettyQuicStreamHandler(channel, this@NettyQuicStreamInitializer, state, isClient)) + } +} diff --git a/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt b/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt new file mode 100644 index 00000000..18774a16 --- /dev/null +++ b/rsocket-transports/netty-quic/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/quic/NettyQuicTransportTest.kt @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.kotlin.transport.netty.quic + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.netty.incubator.codec.quic.* +import io.rsocket.kotlin.transport.tests.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +private val protos = arrayOf("hq-29") + +class NettyQuicTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyQuicServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.privateKey(), null, certificates.certificate()) + applicationProtocols(*protos) + } + codec { + tokenHandler(InsecureQuicTokenHandler.INSTANCE) + } + }.target("127.0.0.1") + ) + client = connectClient( + NettyQuicClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + applicationProtocols(*protos) + } + }.target(server.localAddress) + ) + } +} diff --git a/rsocket-transports/netty-tcp/api/rsocket-transport-netty-tcp.api b/rsocket-transports/netty-tcp/api/rsocket-transport-netty-tcp.api new file mode 100644 index 00000000..46cf76c7 --- /dev/null +++ b/rsocket-transports/netty-tcp/api/rsocket-transport-netty-tcp.api @@ -0,0 +1,41 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport$Factory; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/net/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Ljava/net/SocketAddress; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory; + public abstract fun target (Ljava/lang/String;I)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/net/SocketAddress;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport;Ljava/lang/String;IILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport;Ljava/net/SocketAddress;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V +} + diff --git a/rsocket-transports/netty-tcp/build.gradle.kts b/rsocket-transports/netty-tcp/build.gradle.kts new file mode 100644 index 00000000..6d85750d --- /dev/null +++ b/rsocket-transports/netty-tcp/build.gradle.kts @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.multiplatform-library") +} + +description = "rsocket-kotlin Netty TCP client/server transport implementation" + +kotlin { + jvmTarget() + + sourceSets { + jvmMain.dependencies { + implementation(projects.rsocketTransportNettyInternal) + implementation(projects.rsocketInternalIo) + api(projects.rsocketCore) + api(libs.netty.handler) + } + jvmTest.dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + } + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt new file mode 100644 index 00000000..9675af6b --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpClientTransport.kt @@ -0,0 +1,142 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyTcpClientTransport : RSocketTransport { + public fun target(remoteAddress: SocketAddress): RSocketClientTarget + public fun target(host: String, port: Int): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::NettyTcpClientTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyTcpClientTransportBuilder : RSocketTransportBuilder { + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) +} + +private class NettyTcpClientTransportBuilderImpl : NettyTcpClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = true + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyTcpClientTransport { + val sslContext = ssl?.let { + SslContextBuilder.forClient().apply(it).build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + channelFactory(channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java)) + group(eventLoopGroup ?: NioEventLoopGroup()) + } + + return NettyTcpClientTransportImpl( + coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), + sslContext = sslContext, + bootstrap = bootstrap, + manageBootstrap = manageEventLoopGroup + ) + } +} + +private class NettyTcpClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + manageBootstrap: Boolean, +) : NettyTcpClientTransport { + init { + if (manageBootstrap) callOnCancellation { + bootstrap.config().group().shutdownGracefully().awaitFuture() + } + } + + override fun target(remoteAddress: SocketAddress): NettyTcpClientTargetImpl = NettyTcpClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + bootstrap = bootstrap, + sslContext = sslContext, + remoteAddress = remoteAddress + ) + + override fun target(host: String, port: Int): RSocketClientTarget = target(InetSocketAddress(host, port)) +} + +@OptIn(RSocketTransportApi::class) +private class NettyTcpClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: Bootstrap, + private val sslContext: SslContext?, + private val remoteAddress: SocketAddress, +) : RSocketClientTarget { + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + bootstrap.clone().handler( + NettyTcpConnectionInitializer( + sslContext = sslContext, + remoteAddress = remoteAddress as? InetSocketAddress, + handler = handler, + coroutineContext = coroutineContext + ) + ).connect(remoteAddress).awaitFuture() + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionHandler.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionHandler.kt new file mode 100644 index 00000000..fa557ab9 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionHandler.kt @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.ktor.utils.io.core.* +import io.netty.buffer.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel +import io.netty.channel.socket.DuplexChannel as NettyDuplexChannel + +@RSocketTransportApi +internal class NettyTcpConnectionHandler( + private val channel: NettyDuplexChannel, + private val handler: RSocketConnectionHandler, + scope: CoroutineScope, +) : ChannelInboundHandlerAdapter() { + private val inbound = channelForCloseable(Channel.UNLIMITED) + + private val handlerJob = scope.launch(start = CoroutineStart.LAZY) { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + + val writerJob = launch { + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + // TODO: could be optimized to avoid allocation of not-needed promises + var lastWriteFuture = channel.write(outboundQueue.dequeueFrame()?.toByteBuf() ?: break) + while (true) lastWriteFuture = channel.write(outboundQueue.tryDequeueFrame()?.toByteBuf() ?: break) + channel.flush() + // await writing to respect transport backpressure + lastWriteFuture.awaitFuture() + } + } finally { + withContext(NonCancellable) { + channel.shutdownOutput().awaitFuture() + } + } + }.onCompletion { outboundQueue.cancel() } + + try { + handler.handleConnection(NettyTcpConnection(outboundQueue, inbound)) + } finally { + outboundQueue.close() // will cause `writerJob` completion + // no more reading + inbound.cancel() + withContext(NonCancellable) { + writerJob.join() + channel.close().awaitFuture() + } + } + } + + override fun channelActive(ctx: ChannelHandlerContext) { + handlerJob.start() + ctx.pipeline().addLast("rsocket-inbound", NettyTcpConnectionInboundHandler(inbound)) + + ctx.fireChannelActive() + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + handlerJob.cancel("Channel is not active") + + ctx.fireChannelInactive() + } + + @Suppress("OVERRIDE_DEPRECATION") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + handlerJob.cancel("exceptionCaught", cause) + } +} + +// TODO: implement support for isAutoRead=false to support `inbound` backpressure +private class NettyTcpConnectionInboundHandler( + private val inbound: SendChannel, +) : ChannelInboundHandlerAdapter() { + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + msg as ByteBuf + try { + val frame = msg.toByteReadPacket() + if (inbound.trySend(frame).isFailure) { + frame.close() + error("inbound is closed") + } + } finally { + msg.release() + } + } + + override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { + if (evt is ChannelInputShutdownEvent) { + inbound.close() + } + super.userEventTriggered(ctx, evt) + } +} + +@RSocketTransportApi +private class NettyTcpConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionInitializer.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionInitializer.kt new file mode 100644 index 00000000..2d040061 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpConnectionInitializer.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class NettyTcpConnectionInitializer( + private val sslContext: SslContext?, + private val remoteAddress: InetSocketAddress?, + private val handler: RSocketConnectionHandler, + override val coroutineContext: CoroutineContext, +) : ChannelInitializer(), CoroutineScope { + override fun initChannel(channel: DuplexChannel): Unit = with(channel.pipeline()) { + if (sslContext != null) { + addLast( + "ssl", + when { + remoteAddress != null -> sslContext.newHandler(channel.alloc(), remoteAddress.hostName, remoteAddress.port) + else -> sslContext.newHandler(channel.alloc()) + } + ) + } + addLast( + "rsocket-length-encoder", + LengthFieldPrepender( + /* lengthFieldLength = */ 3 + ) + ) + addLast( + "rsocket-length-decoder", + LengthFieldBasedFrameDecoder( + /* maxFrameLength = */ kotlin.Int.MAX_VALUE, + /* lengthFieldOffset = */ 0, + /* lengthFieldLength = */ 3, + /* lengthAdjustment = */ 0, + /* initialBytesToStrip = */ 3 + ) + ) + addLast("rsocket", NettyTcpConnectionHandler(channel, handler, this@NettyTcpConnectionInitializer)) + } +} diff --git a/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt new file mode 100644 index 00000000..9566f451 --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpServerTransport.kt @@ -0,0 +1,178 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.nio.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyTcpServerInstance : RSocketServerInstance { + public val localAddress: SocketAddress +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyTcpServerTransport : RSocketTransport { + public fun target(localAddress: SocketAddress? = null): RSocketServerTarget + public fun target(host: String = "0.0.0.0", port: Int = 0): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::NettyTcpServerTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyTcpServerTransportBuilder : RSocketTransportBuilder { + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) +} + +private class NettyTcpServerTransportBuilderImpl : NettyTcpServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = true + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyTcpServerTransport { + val sslContext = ssl?.let { + SslContextBuilder.forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())).apply(it).build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + channelFactory(channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java)) + group(parentEventLoopGroup ?: NioEventLoopGroup(), childEventLoopGroup ?: NioEventLoopGroup()) + } + + return NettyTcpServerTransportImpl( + coroutineContext = context.supervisorContext() + bootstrap.config().childGroup().asCoroutineDispatcher(), + bootstrap = bootstrap, + sslContext = sslContext, + manageBootstrap = manageEventLoopGroup + ) + } +} + +private class NettyTcpServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + manageBootstrap: Boolean, +) : NettyTcpServerTransport { + init { + if (manageBootstrap) callOnCancellation { + bootstrap.config().childGroup().shutdownGracefully().awaitFuture() + bootstrap.config().group().shutdownGracefully().awaitFuture() + } + } + + override fun target(localAddress: SocketAddress?): NettyTcpServerTargetImpl = NettyTcpServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + bootstrap = bootstrap, + sslContext = sslContext, + localAddress = localAddress ?: InetSocketAddress(0), + ) + + override fun target(host: String, port: Int): RSocketServerTarget = + target(InetSocketAddress(host, port)) +} + +@OptIn(RSocketTransportApi::class) +private class NettyTcpServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val localAddress: SocketAddress, +) : RSocketServerTarget { + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): NettyTcpServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val instanceContext = coroutineContext.childContext() + val channel = try { + bootstrap.clone().childHandler( + NettyTcpConnectionInitializer( + sslContext = sslContext, + remoteAddress = null, + handler = handler, + coroutineContext = instanceContext.supervisorContext() + ) + ).bind(localAddress).awaitChannel() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + // TODO: handle server closure + return NettyTcpServerInstanceImpl( + coroutineContext = instanceContext, + localAddress = (channel as ServerChannel).localAddress() + ) + } +} + +private class NettyTcpServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: SocketAddress, +) : NettyTcpServerInstance diff --git a/rsocket-transports/netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt b/rsocket-transports/netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt new file mode 100644 index 00000000..f1b43e9c --- /dev/null +++ b/rsocket-transports/netty-tcp/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/tcp/NettyTcpTransportTest.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.tcp + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.transport.tests.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +class NettyTcpTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyTcpServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target() + ) + client = connectClient( + NettyTcpClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target(server.localAddress) + ) + } +} + +class NettyTcpSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyTcpServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + }.target() + ) + client = connectClient( + NettyTcpClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + }.target(server.localAddress) + ) + } +} diff --git a/rsocket-transports/netty-websocket/api/rsocket-transport-netty-websocket.api b/rsocket-transports/netty-websocket/api/rsocket-transport-netty-websocket.api new file mode 100644 index 00000000..b42b0cd5 --- /dev/null +++ b/rsocket-transports/netty-websocket/api/rsocket-transport-netty-websocket.api @@ -0,0 +1,49 @@ +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory; + public abstract fun target (Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/lang/String;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Ljava/net/URI;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public abstract fun target (Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport;Ljava/lang/String;Ljava/lang/Integer;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport;Ljava/lang/String;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport;Ljava/net/URI;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V + public abstract fun webSocketProtocolConfig (Lkotlin/jvm/functions/Function1;)V +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getLocalAddress ()Ljava/net/InetSocketAddress; + public abstract fun getWebSocketProtocolConfig ()Lio/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig; +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory; + public abstract fun target (Ljava/lang/String;ILjava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/net/InetSocketAddress;Ljava/lang/String;Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport;Ljava/lang/String;ILjava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; + public static synthetic fun target$default (Lio/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport;Ljava/net/InetSocketAddress;Ljava/lang/String;Ljava/lang/String;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun bootstrap (Lkotlin/jvm/functions/Function1;)V + public abstract fun channel (Lkotlin/reflect/KClass;)V + public abstract fun channelFactory (Lio/netty/channel/ChannelFactory;)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun eventLoopGroup (Lio/netty/channel/EventLoopGroup;Z)V + public abstract fun ssl (Lkotlin/jvm/functions/Function1;)V + public abstract fun webSocketProtocolConfig (Lkotlin/jvm/functions/Function1;)V +} + diff --git a/rsocket-transports/netty-websocket/build.gradle.kts b/rsocket-transports/netty-websocket/build.gradle.kts new file mode 100644 index 00000000..c72d6c15 --- /dev/null +++ b/rsocket-transports/netty-websocket/build.gradle.kts @@ -0,0 +1,43 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import rsocketbuild.* + +plugins { + id("rsocketbuild.multiplatform-library") +} + +description = "rsocket-kotlin Netty WebSocket client/server transport implementation" + +kotlin { + jvmTarget() + + sourceSets { + jvmMain.dependencies { + implementation(projects.rsocketTransportNettyInternal) + implementation(projects.rsocketInternalIo) + api(projects.rsocketCore) + api(libs.netty.handler) + api(libs.netty.codec.http) + } + jvmTest.dependencies { + implementation(projects.rsocketTransportTests) + implementation(libs.bouncycastle) + // TODO: add JVM logging consistently + // implementation("ch.qos.logback:logback-classic:1.2.11") + } + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt new file mode 100644 index 00000000..8befdcfd --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketClientTransport.kt @@ -0,0 +1,212 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* +import kotlin.reflect.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyWebSocketClientTransport : RSocketTransport { + public fun target(configure: WebSocketClientProtocolConfig.Builder.() -> Unit): RSocketClientTarget + public fun target(uri: URI, configure: WebSocketClientProtocolConfig.Builder.() -> Unit = {}): RSocketClientTarget + public fun target(urlString: String, configure: WebSocketClientProtocolConfig.Builder.() -> Unit = {}): RSocketClientTarget + + public fun target( + host: String? = null, + port: Int? = null, + path: String? = null, + configure: WebSocketClientProtocolConfig.Builder.() -> Unit = {}, + ): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::NettyWebSocketClientTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyWebSocketClientTransportBuilder : RSocketTransportBuilder { + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: Bootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSocketProtocolConfig(block: WebSocketClientProtocolConfig.Builder.() -> Unit) +} + +private class NettyWebSocketClientTransportBuilderImpl : NettyWebSocketClientTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var eventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (Bootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSocketProtocolConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.eventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: Bootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSocketProtocolConfig(block: WebSocketClientProtocolConfig.Builder.() -> Unit) { + webSocketProtocolConfig = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyWebSocketClientTransport { + val sslContext = ssl?.let { + SslContextBuilder.forClient().apply(it).build() + } + + val bootstrap = Bootstrap().apply { + bootstrap?.invoke(this) + channelFactory(channelFactory ?: ReflectiveChannelFactory(NioSocketChannel::class.java)) + group(eventLoopGroup ?: NioEventLoopGroup()) + } + + return NettyWebSocketClientTransportImpl( + coroutineContext = context.supervisorContext() + bootstrap.config().group().asCoroutineDispatcher(), + sslContext = sslContext, + bootstrap = bootstrap, + webSocketProtocolConfig = webSocketProtocolConfig, + manageBootstrap = manageEventLoopGroup + ) + } +} + +private class NettyWebSocketClientTransportImpl( + override val coroutineContext: CoroutineContext, + private val sslContext: SslContext?, + private val bootstrap: Bootstrap, + private val webSocketProtocolConfig: (WebSocketClientProtocolConfig.Builder.() -> Unit)?, + manageBootstrap: Boolean, +) : NettyWebSocketClientTransport { + init { + if (manageBootstrap) callOnCancellation { + bootstrap.config().group().shutdownGracefully().awaitFuture() + } + } + + override fun target(configure: WebSocketClientProtocolConfig.Builder.() -> Unit): RSocketClientTarget { + val webSocketProtocolConfig = WebSocketClientProtocolConfig.newBuilder().apply { + // transport config first + webSocketProtocolConfig?.invoke(this) + // target config + configure.invoke(this) + }.build() + return NettyWebSocketClientTransportTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + bootstrap = bootstrap, + sslContext = sslContext, + webSocketProtocolConfig = webSocketProtocolConfig, + remoteAddress = InetSocketAddress( + /* hostname = */ webSocketProtocolConfig.webSocketUri().host, + /* port = */ webSocketProtocolConfig.webSocketUri().port + ) + ) + } + + override fun target(uri: URI, configure: WebSocketClientProtocolConfig.Builder.() -> Unit): RSocketClientTarget = target { + webSocketUri(uri) + } + + override fun target(urlString: String, configure: WebSocketClientProtocolConfig.Builder.() -> Unit): RSocketClientTarget = target { + webSocketUri(urlString) + } + + override fun target( + host: String?, + port: Int?, + path: String?, + configure: WebSocketClientProtocolConfig.Builder.() -> Unit, + ): RSocketClientTarget = target { + webSocketUri( + URI( + /* scheme = */ "ws", + /* userInfo = */ null, + /* host = */ host ?: "localhost", + /* port = */ port ?: -1, + /* path = */ if (path?.startsWith("/") == false) "/$path" else path, + /* query = */ null, + /* fragment = */ null + ) + ) + } +} + +@OptIn(RSocketTransportApi::class) +private class NettyWebSocketClientTransportTargetImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: Bootstrap, + private val sslContext: SslContext?, + private val webSocketProtocolConfig: WebSocketClientProtocolConfig, + private val remoteAddress: InetSocketAddress, +) : RSocketClientTarget { + + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + bootstrap.clone().handler( + NettyWebSocketClientConnectionInitializer( + sslContext = sslContext, + webSocketProtocolConfig = webSocketProtocolConfig, + remoteAddress = remoteAddress, + handler = handler, + coroutineContext = coroutineContext, + ) + ).connect(remoteAddress).awaitFuture() + } +} + +@RSocketTransportApi +private class NettyWebSocketClientConnectionInitializer( + sslContext: SslContext?, + private val webSocketProtocolConfig: WebSocketClientProtocolConfig, + remoteAddress: InetSocketAddress?, + handler: RSocketConnectionHandler, + coroutineContext: CoroutineContext, +) : NettyWebSocketConnectionInitializer(sslContext, remoteAddress, handler, coroutineContext) { + override fun createHttpHandler(): ChannelHandler = HttpClientCodec() + override fun createWebSocketHandler(): ChannelHandler = WebSocketClientProtocolHandler(webSocketProtocolConfig) +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnectionHandler.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnectionHandler.kt new file mode 100644 index 00000000..13919215 --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnectionHandler.kt @@ -0,0 +1,140 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.ktor.utils.io.core.* +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.websocketx.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.channels.Channel + +@RSocketTransportApi +internal class NettyWebSocketConnectionHandler( + private val channel: DuplexChannel, + private val handler: RSocketConnectionHandler, + scope: CoroutineScope, +) : ChannelInboundHandlerAdapter() { + private val inbound = channelForCloseable(Channel.UNLIMITED) + + private val handlerJob = scope.launch(start = CoroutineStart.LAZY) { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + + val writerJob = launch { + try { + while (true) { + // we write all available frames here, and only after it flush + // in this case, if there are several buffered frames we can send them in one go + // avoiding unnecessary flushes + // TODO: could be optimized to avoid allocation of not-needed promises + var lastWriteFuture = channel.write(BinaryWebSocketFrame(outboundQueue.dequeueFrame()?.toByteBuf() ?: break)) + while (true) lastWriteFuture = + channel.write(BinaryWebSocketFrame(outboundQueue.tryDequeueFrame()?.toByteBuf() ?: break)) + channel.flush() + // await writing to respect transport backpressure + lastWriteFuture.awaitFuture() + } + } finally { + withContext(NonCancellable) { + channel.shutdownOutput().awaitFuture() + } + } + }.onCompletion { outboundQueue.cancel() } + + try { + handler.handleConnection(NettyWebSocketConnection(outboundQueue, inbound)) + } finally { + outboundQueue.close() // will cause `writerJob` completion + // no more reading + inbound.cancel() + withContext(NonCancellable) { + writerJob.join() + channel.close().awaitFuture() + } + } + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + handlerJob.cancel("Channel is not active") + + ctx.fireChannelInactive() + } + + @Suppress("OVERRIDE_DEPRECATION") + override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable?) { + handlerJob.cancel("exceptionCaught", cause) + } + + // TODO: handle error, timeout? + override fun userEventTriggered(ctx: ChannelHandlerContext, evt: Any) { + if ( + evt is WebSocketServerProtocolHandler.HandshakeComplete || + evt == WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE + ) { + handlerJob.start() + ctx.pipeline().addLast("rsocket-inbound", NettyWebSocketConnectionInboundHandler(inbound)) + } + + ctx.fireUserEventTriggered(evt) + } +} + +// TODO: implement support for isAutoRead=false to support `inbound` backpressure +private class NettyWebSocketConnectionInboundHandler( + private val inbound: SendChannel, +) : ChannelInboundHandlerAdapter() { + + override fun channelRead(ctx: ChannelHandlerContext, msg: Any) { + msg as WebSocketFrame + try { + val frame = msg.content().toByteReadPacket() + if (inbound.trySend(frame).isFailure) { + frame.close() + error("inbound is closed") + } + } finally { + msg.release() + } + } + + override fun userEventTriggered(ctx: ChannelHandlerContext?, evt: Any?) { + if (evt is ChannelInputShutdownEvent) { + inbound.close() + } + super.userEventTriggered(ctx, evt) + } +} + +@RSocketTransportApi +private class NettyWebSocketConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val inbound: ReceiveChannel, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnectionInitializer.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnectionInitializer.kt new file mode 100644 index 00000000..676df169 --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketConnectionInitializer.kt @@ -0,0 +1,56 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.* +import io.netty.channel.socket.* +import io.netty.handler.codec.http.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import java.net.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal abstract class NettyWebSocketConnectionInitializer( + private val sslContext: SslContext?, + private val remoteAddress: InetSocketAddress?, + private val handler: RSocketConnectionHandler, + final override val coroutineContext: CoroutineContext, +) : ChannelInitializer(), CoroutineScope { + protected abstract fun createHttpHandler(): ChannelHandler + protected abstract fun createWebSocketHandler(): ChannelHandler + + final override fun initChannel(channel: DuplexChannel): Unit = with(channel.pipeline()) { + //addLast(LoggingHandler(if (remoteAddress == null) "server" else "client")) + if (sslContext != null) { + addLast( + "ssl", + when { + remoteAddress != null -> sslContext.newHandler(channel.alloc(), remoteAddress.hostName, remoteAddress.port) + else -> sslContext.newHandler(channel.alloc()) + } + ) + } + // TODO: should those handlers be configurable? + // what is the the good defaults here and for HttpObjectAggregator + addLast("http", createHttpHandler()) + addLast(HttpObjectAggregator(65536)) + addLast("websocket", createWebSocketHandler()) + addLast("rsocket", NettyWebSocketConnectionHandler(channel, handler, this@NettyWebSocketConnectionInitializer)) + } +} diff --git a/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt new file mode 100644 index 00000000..bb16103f --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmMain/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketServerTransport.kt @@ -0,0 +1,227 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.bootstrap.* +import io.netty.channel.* +import io.netty.channel.ChannelFactory +import io.netty.channel.nio.* +import io.netty.channel.socket.nio.* +import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.* +import io.netty.handler.ssl.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.netty.internal.* +import kotlinx.coroutines.* +import java.net.* +import javax.net.ssl.* +import kotlin.coroutines.* +import kotlin.reflect.* + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyWebSocketServerInstance : RSocketServerInstance { + public val localAddress: InetSocketAddress + public val webSocketProtocolConfig: WebSocketServerProtocolConfig +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyWebSocketServerTransport : RSocketTransport { + + public fun target( + localAddress: InetSocketAddress? = null, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public fun target( + host: String = "0.0.0.0", + port: Int = 0, + path: String = "", + protocol: String? = null, + ): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::NettyWebSocketServerTransportBuilderImpl) +} + +@OptIn(RSocketTransportApi::class) +public sealed interface NettyWebSocketServerTransportBuilder : RSocketTransportBuilder { + public fun channel(cls: KClass) + public fun channelFactory(factory: ChannelFactory) + public fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) + public fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) + + public fun bootstrap(block: ServerBootstrap.() -> Unit) + public fun ssl(block: SslContextBuilder.() -> Unit) + public fun webSocketProtocolConfig(block: WebSocketServerProtocolConfig.Builder.() -> Unit) +} + +private class NettyWebSocketServerTransportBuilderImpl : NettyWebSocketServerTransportBuilder { + private var channelFactory: ChannelFactory? = null + private var parentEventLoopGroup: EventLoopGroup? = null + private var childEventLoopGroup: EventLoopGroup? = null + private var manageEventLoopGroup: Boolean = false + private var bootstrap: (ServerBootstrap.() -> Unit)? = null + private var ssl: (SslContextBuilder.() -> Unit)? = null + private var webSocketProtocolConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)? = null + + override fun channel(cls: KClass) { + this.channelFactory = ReflectiveChannelFactory(cls.java) + } + + override fun channelFactory(factory: ChannelFactory) { + this.channelFactory = factory + } + + override fun eventLoopGroup(parentGroup: EventLoopGroup, childGroup: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = parentGroup + this.childEventLoopGroup = childGroup + this.manageEventLoopGroup = manage + } + + override fun eventLoopGroup(group: EventLoopGroup, manage: Boolean) { + this.parentEventLoopGroup = group + this.childEventLoopGroup = group + this.manageEventLoopGroup = manage + } + + override fun bootstrap(block: ServerBootstrap.() -> Unit) { + bootstrap = block + } + + override fun ssl(block: SslContextBuilder.() -> Unit) { + ssl = block + } + + override fun webSocketProtocolConfig(block: WebSocketServerProtocolConfig.Builder.() -> Unit) { + webSocketProtocolConfig = block + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): NettyWebSocketServerTransport { + val sslContext = ssl?.let { + SslContextBuilder + .forServer(KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())) + .apply(it) + .build() + } + + val bootstrap = ServerBootstrap().apply { + bootstrap?.invoke(this) + channelFactory(channelFactory ?: ReflectiveChannelFactory(NioServerSocketChannel::class.java)) + group(parentEventLoopGroup ?: NioEventLoopGroup(), childEventLoopGroup ?: NioEventLoopGroup()) + } + + return NettyWebSocketServerTransportImpl( + coroutineContext = context.supervisorContext() + bootstrap.config().childGroup().asCoroutineDispatcher(), + bootstrap = bootstrap, + sslContext = sslContext, + webSocketProtocolConfig = webSocketProtocolConfig, + manageBootstrap = manageEventLoopGroup + ) + } +} + +private class NettyWebSocketServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketProtocolConfig: (WebSocketServerProtocolConfig.Builder.() -> Unit)?, + manageBootstrap: Boolean, +) : NettyWebSocketServerTransport { + init { + if (manageBootstrap) callOnCancellation { + bootstrap.config().childGroup().shutdownGracefully().awaitFuture() + bootstrap.config().group().shutdownGracefully().awaitFuture() + } + } + + override fun target( + localAddress: InetSocketAddress?, + path: String, + protocol: String?, + ): RSocketServerTarget = NettyWebSocketServerTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + bootstrap = bootstrap, + sslContext = sslContext, + webSocketProtocolConfig = WebSocketServerProtocolConfig.newBuilder().apply { + webSocketProtocolConfig?.invoke(this) + websocketPath(if (!path.startsWith("/")) "/$path" else path) + subprotocols(protocol) + }.build(), + localAddress = localAddress ?: InetSocketAddress(0) + ) + + override fun target(host: String, port: Int, path: String, protocol: String?): RSocketServerTarget = + target(InetSocketAddress(host, port), path, protocol) +} + +@OptIn(RSocketTransportApi::class) +private class NettyWebSocketServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val bootstrap: ServerBootstrap, + private val sslContext: SslContext?, + private val webSocketProtocolConfig: WebSocketServerProtocolConfig, + private val localAddress: SocketAddress?, +) : RSocketServerTarget { + + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): NettyWebSocketServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + val instanceContext = coroutineContext.childContext() + val channel = try { + bootstrap.clone().childHandler( + NettyWebSocketServerConnectionInitializer( + sslContext = sslContext, + webSocketProtocolConfig = webSocketProtocolConfig, + handler = handler, + coroutineContext = instanceContext.supervisorContext() + ) + ).bind(localAddress).awaitChannel() + } catch (cause: Throwable) { + instanceContext.job.cancel("Failed to bind", cause) + throw cause + } + + // TODO: handle server closure + return NettyWebSocketServerInstanceImpl( + coroutineContext = instanceContext, + localAddress = (channel as ServerChannel).localAddress() as InetSocketAddress, + webSocketProtocolConfig = webSocketProtocolConfig + ) + } +} + +@RSocketTransportApi +private class NettyWebSocketServerConnectionInitializer( + sslContext: SslContext?, + private val webSocketProtocolConfig: WebSocketServerProtocolConfig, + handler: RSocketConnectionHandler, + coroutineContext: CoroutineContext, +) : NettyWebSocketConnectionInitializer(sslContext, null, handler, coroutineContext) { + override fun createHttpHandler(): ChannelHandler = HttpServerCodec() + override fun createWebSocketHandler(): ChannelHandler = WebSocketServerProtocolHandler(webSocketProtocolConfig) +} + +private class NettyWebSocketServerInstanceImpl( + override val coroutineContext: CoroutineContext, + override val localAddress: InetSocketAddress, + override val webSocketProtocolConfig: WebSocketServerProtocolConfig, +) : NettyWebSocketServerInstance diff --git a/rsocket-transports/netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt b/rsocket-transports/netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt new file mode 100644 index 00000000..6e9f4354 --- /dev/null +++ b/rsocket-transports/netty-websocket/src/jvmTest/kotlin/io/rsocket/kotlin/transport/netty/websocket/NettyWebSocketTransportTest.kt @@ -0,0 +1,66 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.netty.websocket + +import io.netty.channel.nio.* +import io.netty.handler.ssl.util.* +import io.rsocket.kotlin.transport.tests.* +import kotlin.concurrent.* + +private val eventLoop = NioEventLoopGroup().also { + Runtime.getRuntime().addShutdownHook(thread(start = false) { + it.shutdownGracefully().await(1000) + }) +} +private val certificates = SelfSignedCertificate() + +// TODO: add tests for paths +class NettyWebSocketTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target() + ) + client = connectClient( + NettyWebSocketClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + }.target(port = server.localAddress.port) + ) + } +} + +class NettyWebSocketSslTransportTest : TransportTest() { + override suspend fun before() { + val server = startServer( + NettyWebSocketServerTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + keyManager(certificates.certificate(), certificates.privateKey()) + } + }.target() + ) + client = connectClient( + NettyWebSocketClientTransport(testContext) { + eventLoopGroup(eventLoop, manage = false) + ssl { + trustManager(InsecureTrustManagerFactory.INSTANCE) + } + }.target(port = server.localAddress.port) + ) + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 5a431421..7699bca3 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -41,6 +41,11 @@ projects("rsocket-kotlin") { module("ktor-websocket-internal") module("ktor-websocket-client") module("ktor-websocket-server") + + module("netty-internal") + module("netty-tcp") + module("netty-websocket") + module("netty-quic") } //deep ktor integration module