Skip to content

Commit

Permalink
KTOR-8105 Fix for concurrent flush attempts breaking CIO client (again)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhham committed Feb 13, 2025
1 parent bd652e4 commit 6768f4a
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ internal suspend fun writeBody(
val chunkedJob: EncoderJob? = if (chunked) encodeChunked(output, callContext) else null
val channel = chunkedJob?.channel ?: output

val scope = CoroutineScope(callContext + CoroutineName("Request body writer"))
val scope = CoroutineScope(callContext + CoroutineName("body writer"))
scope.launch {
try {
processOutgoingContent(request, body, channel)
Expand Down Expand Up @@ -194,7 +194,7 @@ internal suspend fun readResponse(
}

else -> {
val coroutineScope = CoroutineScope(callContext + CoroutineName("Response"))
val coroutineScope = CoroutineScope(callContext + CoroutineName("body reader"))
val httpBodyParser = coroutineScope.writer(autoFlush = true) {
parseHttpBody(version, contentLength, transferEncoding, connectionType, input, channel)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ internal abstract class SocketBase(
override fun close() {
if (!closeFlag.compareAndSet(false, true)) return

// TODO this can be dangerous if there is another thread writing to this
readerJob.value?.channel?.close()
writerJob.value?.cancel()
checkChannels()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ internal class TLSClientHandshake(
rawInput: ByteReadChannel,
rawOutput: ByteWriteChannel,
private val config: TLSConfig,
override val coroutineContext: CoroutineContext
override val coroutineContext: CoroutineContext,
private val closeDeferred: CompletableDeferred<Unit> = CompletableDeferred<Unit>(),
) : CoroutineScope {
private val digest = Digest()
private val clientSeed: ByteArray = config.random.generateClientSeed()
Expand Down Expand Up @@ -101,36 +102,44 @@ internal class TLSClientHandshake(
}
}

var useCipher = false

@OptIn(ObsoleteCoroutinesApi::class)
val output: SendChannel<TLSRecord> = actor(CoroutineName("cio-tls-encoder")) {
var useCipher = false

try {
for (rawRecord in channel) {
try {
val record = if (useCipher) cipher.encrypt(rawRecord) else rawRecord
if (rawRecord.type == TLSRecordType.ChangeCipherSpec) useCipher = true

rawOutput.writeRecord(record)
} catch (cause: Throwable) {
channel.close(cause)
}
for (rawRecord in channel) {
try {
val record = if (useCipher) cipher.encrypt(rawRecord) else rawRecord
if (rawRecord.type == TLSRecordType.ChangeCipherSpec) useCipher = true

rawOutput.writeRecord(record)
} catch (cause: Throwable) {
channel.close(cause)
}
} finally {
rawOutput.writeRecord(
TLSRecord(
}
}.apply {
invokeOnClose {
launch(CoroutineName("cio-tls-close")) {
val closeRecord = TLSRecord(
TLSRecordType.Alert,
packet = buildPacket {
writeByte(TLSAlertLevel.WARNING.code.toByte())
writeByte(TLSAlertType.CloseNotify.code.toByte())
}
)
)

rawOutput.flushAndClose()
val record = if (useCipher) cipher.encrypt(closeRecord) else closeRecord
rawOutput.writeRecord(record)
rawOutput.flushAndClose()
closeDeferred.complete(Unit)
}
}
}

fun close(): Deferred<Unit> {
input.cancel()
output.close()
return closeDeferred
}

@OptIn(ExperimentalCoroutinesApi::class)
private val handshakes: ReceiveChannel<TLSHandshake> = produce(CoroutineName("cio-tls-handshake")) {
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,24 @@ internal actual suspend fun openTLSSession(
val handshake = TLSClientHandshake(input, output, config, context)
try {
handshake.negotiate()
} catch (cause: ClosedSendChannelException) {
throw TLSException("Negotiation failed due to EOS", cause)
} catch (cause: Exception) {
runCatching {
handshake.close().await()
socket.close()
}
if (cause is ClosedSendChannelException)
throw TlsException("Negotiation failed due to EOS", cause)
else throw cause
}
return TLSSocket(handshake.input, handshake.output, socket, context)
return TLSSocket(
handshake,
socket,
context
)
}

private class TLSSocket(
private val input: ReceiveChannel<TLSRecord>,
private val output: SendChannel<TLSRecord>,
private val base: TLSClientHandshake,
private val socket: Socket,
override val coroutineContext: CoroutineContext
) : CoroutineScope, Socket by socket {
Expand All @@ -49,15 +58,15 @@ private class TLSSocket(

private suspend fun appDataInputLoop(pipe: ByteWriteChannel) {
try {
input.consumeEach { record ->
base.input.consumeEach { record ->
val packet = record.packet
val length = packet.remaining
when (record.type) {
TLSRecordType.ApplicationData -> {
pipe.writePacket(record.packet)
pipe.flush()
}
else -> throw TLSException("Unexpected record ${record.type} ($length bytes)")
else -> throw TlsException("Unexpected record ${record.type} ($length bytes)")
}
}
} catch (_: Throwable) {
Expand All @@ -76,16 +85,25 @@ private class TLSSocket(
if (rc == -1) break

buffer.flip()
output.send(TLSRecord(TLSRecordType.ApplicationData, packet = buildPacket { writeFully(buffer) }))
base.output.send(TLSRecord(TLSRecordType.ApplicationData, packet = buildPacket { writeFully(buffer) }))
}
} catch (_: ClosedSendChannelException) {
// The socket was already closed, we should ignore that error.
} finally {
output.close()
base.output.close()
}
}

override fun dispose() {
socket.dispose()
close()
}

/**
* The socket is closed implicitly after the output is closed.
*/
override fun close() {
base.close().invokeOnCompletion {
socket.close()
}
}
}

0 comments on commit 6768f4a

Please sign in to comment.