From 6768f4a81cdf606770c1befc15b9974696c148dd Mon Sep 17 00:00:00 2001 From: Bruce Hamilton Date: Thu, 13 Feb 2025 12:15:23 +0100 Subject: [PATCH] KTOR-8105 Fix for concurrent flush attempts breaking CIO client (again) --- .../src/io/ktor/client/engine/cio/utils.kt | 4 +- .../src/io/ktor/network/sockets/SocketBase.kt | 1 + .../io/ktor/network/tls/TLSClientHandshake.kt | 47 +++++++++++-------- .../ktor/network/tls/TLSClientSessionJvm.kt | 38 +++++++++++---- 4 files changed, 59 insertions(+), 31 deletions(-) diff --git a/ktor-client/ktor-client-cio/common/src/io/ktor/client/engine/cio/utils.kt b/ktor-client/ktor-client-cio/common/src/io/ktor/client/engine/cio/utils.kt index b2b92c93706..42f36f58fb2 100644 --- a/ktor-client/ktor-client-cio/common/src/io/ktor/client/engine/cio/utils.kt +++ b/ktor-client/ktor-client-cio/common/src/io/ktor/client/engine/cio/utils.kt @@ -122,7 +122,7 @@ internal suspend fun writeBody( val chunkedJob: EncoderJob? = if (chunked) encodeChunked(output, callContext) else null val channel = chunkedJob?.channel ?: output - val scope = CoroutineScope(callContext + CoroutineName("Request body writer")) + val scope = CoroutineScope(callContext + CoroutineName("body writer")) scope.launch { try { processOutgoingContent(request, body, channel) @@ -194,7 +194,7 @@ internal suspend fun readResponse( } else -> { - val coroutineScope = CoroutineScope(callContext + CoroutineName("Response")) + val coroutineScope = CoroutineScope(callContext + CoroutineName("body reader")) val httpBodyParser = coroutineScope.writer(autoFlush = true) { parseHttpBody(version, contentLength, transferEncoding, connectionType, input, channel) } diff --git a/ktor-network/common/src/io/ktor/network/sockets/SocketBase.kt b/ktor-network/common/src/io/ktor/network/sockets/SocketBase.kt index 18d26696bbd..cb21d2532d8 100644 --- a/ktor-network/common/src/io/ktor/network/sockets/SocketBase.kt +++ b/ktor-network/common/src/io/ktor/network/sockets/SocketBase.kt @@ -33,6 +33,7 @@ internal abstract class SocketBase( override fun close() { if (!closeFlag.compareAndSet(false, true)) return + // TODO this can be dangerous if there is another thread writing to this readerJob.value?.channel?.close() writerJob.value?.cancel() checkChannels() diff --git a/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientHandshake.kt b/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientHandshake.kt index be2fe7b3e91..05c09b0f128 100644 --- a/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientHandshake.kt +++ b/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientHandshake.kt @@ -27,7 +27,8 @@ internal class TLSClientHandshake( rawInput: ByteReadChannel, rawOutput: ByteWriteChannel, private val config: TLSConfig, - override val coroutineContext: CoroutineContext + override val coroutineContext: CoroutineContext, + private val closeDeferred: CompletableDeferred = CompletableDeferred(), ) : CoroutineScope { private val digest = Digest() private val clientSeed: ByteArray = config.random.generateClientSeed() @@ -101,36 +102,44 @@ internal class TLSClientHandshake( } } + var useCipher = false + @OptIn(ObsoleteCoroutinesApi::class) val output: SendChannel = actor(CoroutineName("cio-tls-encoder")) { - var useCipher = false - - try { - for (rawRecord in channel) { - try { - val record = if (useCipher) cipher.encrypt(rawRecord) else rawRecord - if (rawRecord.type == TLSRecordType.ChangeCipherSpec) useCipher = true - - rawOutput.writeRecord(record) - } catch (cause: Throwable) { - channel.close(cause) - } + for (rawRecord in channel) { + try { + val record = if (useCipher) cipher.encrypt(rawRecord) else rawRecord + if (rawRecord.type == TLSRecordType.ChangeCipherSpec) useCipher = true + + rawOutput.writeRecord(record) + } catch (cause: Throwable) { + channel.close(cause) } - } finally { - rawOutput.writeRecord( - TLSRecord( + } + }.apply { + invokeOnClose { + launch(CoroutineName("cio-tls-close")) { + val closeRecord = TLSRecord( TLSRecordType.Alert, packet = buildPacket { writeByte(TLSAlertLevel.WARNING.code.toByte()) writeByte(TLSAlertType.CloseNotify.code.toByte()) } ) - ) - - rawOutput.flushAndClose() + val record = if (useCipher) cipher.encrypt(closeRecord) else closeRecord + rawOutput.writeRecord(record) + rawOutput.flushAndClose() + closeDeferred.complete(Unit) + } } } + fun close(): Deferred { + input.cancel() + output.close() + return closeDeferred + } + @OptIn(ExperimentalCoroutinesApi::class) private val handshakes: ReceiveChannel = produce(CoroutineName("cio-tls-handshake")) { while (true) { diff --git a/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientSessionJvm.kt b/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientSessionJvm.kt index c4d6bd6c8e1..c1afa9b74b2 100644 --- a/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientSessionJvm.kt +++ b/ktor-network/ktor-network-tls/jvm/src/io/ktor/network/tls/TLSClientSessionJvm.kt @@ -24,15 +24,24 @@ internal actual suspend fun openTLSSession( val handshake = TLSClientHandshake(input, output, config, context) try { handshake.negotiate() - } catch (cause: ClosedSendChannelException) { - throw TLSException("Negotiation failed due to EOS", cause) + } catch (cause: Exception) { + runCatching { + handshake.close().await() + socket.close() + } + if (cause is ClosedSendChannelException) + throw TlsException("Negotiation failed due to EOS", cause) + else throw cause } - return TLSSocket(handshake.input, handshake.output, socket, context) + return TLSSocket( + handshake, + socket, + context + ) } private class TLSSocket( - private val input: ReceiveChannel, - private val output: SendChannel, + private val base: TLSClientHandshake, private val socket: Socket, override val coroutineContext: CoroutineContext ) : CoroutineScope, Socket by socket { @@ -49,7 +58,7 @@ private class TLSSocket( private suspend fun appDataInputLoop(pipe: ByteWriteChannel) { try { - input.consumeEach { record -> + base.input.consumeEach { record -> val packet = record.packet val length = packet.remaining when (record.type) { @@ -57,7 +66,7 @@ private class TLSSocket( pipe.writePacket(record.packet) pipe.flush() } - else -> throw TLSException("Unexpected record ${record.type} ($length bytes)") + else -> throw TlsException("Unexpected record ${record.type} ($length bytes)") } } } catch (_: Throwable) { @@ -76,16 +85,25 @@ private class TLSSocket( if (rc == -1) break buffer.flip() - output.send(TLSRecord(TLSRecordType.ApplicationData, packet = buildPacket { writeFully(buffer) })) + base.output.send(TLSRecord(TLSRecordType.ApplicationData, packet = buildPacket { writeFully(buffer) })) } } catch (_: ClosedSendChannelException) { // The socket was already closed, we should ignore that error. } finally { - output.close() + base.output.close() } } override fun dispose() { - socket.dispose() + close() + } + + /** + * The socket is closed implicitly after the output is closed. + */ + override fun close() { + base.close().invokeOnCompletion { + socket.close() + } } }