diff --git a/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/ParZip.kt b/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/ParZip.kt new file mode 100644 index 0000000..8295a64 --- /dev/null +++ b/kotlin-result-coroutines/src/commonMain/kotlin/com/github/michaelbull/result/coroutines/ParZip.kt @@ -0,0 +1,103 @@ +package com.github.michaelbull.result.coroutines + +import com.github.michaelbull.result.Err +import com.github.michaelbull.result.Result +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +private typealias Producer = suspend CoroutineScope.() -> Result + +private suspend inline fun parZipInternal( + producers: List>, + crossinline transform: suspend CoroutineScope.(values: List) -> V, +): Result { + contract { + callsInPlace(transform, InvocationKind.AT_MOST_ONCE) + } + return coroutineBinding { + val values = producers + .map { producer -> async { producer().bind() } } + .awaitAll() + transform(values) + } +} + +/** + * Runs [producer1] and [producer2] in parallel, combining their successful results with [transform]. + * If either computation fails with an [Err], the other is cancelled, and the error is returned as [Err]. + */ +public suspend fun parZip( + producer1: Producer, + producer2: Producer, + transform: suspend CoroutineScope.(T1, T2) -> V, +): Result = + parZipInternal(listOf(producer1, producer2)) { + @Suppress("UNCHECKED_CAST") + transform(it[0] as T1, it[1] as T2) + } + +/** + * Runs [producer1], [producer2], and [producer3] in parallel, combining their successful results with [transform]. + * If any computation fails with an [Err], the others are cancelled, and the error is returned as [Err]. + */ +public suspend fun parZip( + producer1: Producer, + producer2: Producer, + producer3: Producer, + transform: suspend CoroutineScope.(T1, T2, T3) -> V, +): Result = + parZipInternal(listOf(producer1, producer2, producer3)) { + @Suppress("UNCHECKED_CAST") + transform( + it[0] as T1, + it[1] as T2, + it[2] as T3 + ) + } + +/** + * Runs [producer1], [producer2], [producer3], and [producer4] in parallel, combining their successful results with [transform]. + * If any computation fails with an [Err], the others are cancelled, and the error is returned as [Err]. + */ +public suspend fun parZip( + producer1: Producer, + producer2: Producer, + producer3: Producer, + producer4: Producer, + transform: suspend CoroutineScope.(T1, T2, T3, T4) -> V, +): Result = + parZipInternal(listOf(producer1, producer2, producer3, producer4)) { + @Suppress("UNCHECKED_CAST") + transform( + it[0] as T1, + it[1] as T2, + it[2] as T3, + it[3] as T4 + ) + } + +/** + * Runs [producer1], [producer2], [producer3], [producer4], and [producer5] in parallel, combining their successful results with [transform]. + * If any computation fails with an [Err], the others are cancelled, and the error is returned as [Err]. + */ +public suspend fun parZip( + producer1: Producer, + producer2: Producer, + producer3: Producer, + producer4: Producer, + producer5: Producer, + transform: suspend CoroutineScope.(T1, T2, T3, T4, T5) -> V, +): Result = + parZipInternal(listOf(producer1, producer2, producer3, producer4, producer5)) { + @Suppress("UNCHECKED_CAST") + transform( + it[0] as T1, + it[1] as T2, + it[2] as T3, + it[3] as T4, + it[4] as T5 + ) + } diff --git a/kotlin-result-coroutines/src/commonTest/kotlin/com.github.michaelbull.result.coroutines/ParZipTest.kt b/kotlin-result-coroutines/src/commonTest/kotlin/com.github.michaelbull.result.coroutines/ParZipTest.kt new file mode 100644 index 0000000..7d8d1e7 --- /dev/null +++ b/kotlin-result-coroutines/src/commonTest/kotlin/com.github.michaelbull.result.coroutines/ParZipTest.kt @@ -0,0 +1,211 @@ +package com.github.michaelbull.result.coroutines + +import com.github.michaelbull.result.Err +import com.github.michaelbull.result.Ok +import com.github.michaelbull.result.Result +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import kotlin.test.Test +import kotlin.test.assertEquals + +private suspend inline fun simulateDelay() = delay(100) + +private suspend fun produceErr(error: String): Result { + simulateDelay() + return Err(error) +} + +private suspend fun produceOk(value: V): Result { + simulateDelay() + return Ok(value) +} + +class ParZipTest { + + data class ZipData3(val a: String, val b: Int, val c: Boolean) + data class ZipData4(val a: String, val b: Int, val c: Boolean, val d: Double) + data class ZipData5(val a: String, val b: Int, val c: Boolean, val d: Double, val e: Char) + + @Test + fun parZip2ReturnsTransformedValueIfBothOk() = runTest { + val modifyGate = CompletableDeferred() + + val result = withContext(Dispatchers.Default) { + parZip( + { + modifyGate.await() + produceOk(value = "producer1") + }, + { + modifyGate.complete(Unit) + produceOk(value = "producer2") + }, + { v1, v2 -> + simulateDelay() + v1 to v2 + } + ) + } + + assertEquals( + expected = Ok("producer1" to "producer2"), + actual = result, + ) + } + + @Test + fun parZip2ReturnsErrIfOneOfTwoErr() = runTest { + val modifyGate = CompletableDeferred() + + val result = withContext(Dispatchers.Default) { + parZip( + { + modifyGate.await() + produceOk(value = "producer1") + }, + { + modifyGate.complete(Unit) + produceErr(error = "failed") + }, + { v1, v2 -> + simulateDelay() + v1 to v2 + } + ) + } + + assertEquals( + expected = Err("failed"), + actual = result, + ) + } + + @Test + fun parZip3ReturnsTransformedValueIfAllOk() = runTest { + val result = withContext(Dispatchers.Default) { + parZip( + { produceOk(value = "producer1") }, + { produceOk(value = 42) }, + { produceOk(value = true) }, + { v1, v2, v3 -> + simulateDelay() + ZipData3(v1, v2, v3) + } + ) + } + + assertEquals( + expected = Ok(ZipData3("producer1", 42, true)), + actual = result, + ) + } + + @Test + fun parZip3ReturnsErrIfOneOfThreeErr() = runTest { + val result = withContext(Dispatchers.Default) { + parZip( + { produceOk(value = "producer1") }, + { produceErr(error = "failed") }, + { produceOk(value = true) }, + { v1, v2, v3 -> + simulateDelay() + ZipData3(v1, v2, v3) + } + ) + } + + assertEquals( + expected = Err("failed"), + actual = result, + ) + } + + @Test + fun parZip4ReturnsTransformedValueIfAllOk() = runTest { + val result = withContext(Dispatchers.Default) { + parZip( + { produceOk(value = "producer1") }, + { produceOk(value = 42) }, + { produceOk(value = true) }, + { produceOk(value = 3.14) }, + { v1, v2, v3, v4 -> + simulateDelay() + ZipData4(v1, v2, v3, v4) + } + ) + } + + assertEquals( + expected = Ok(ZipData4("producer1", 42, true, 3.14)), + actual = result, + ) + } + + @Test + fun parZip4ReturnsErrIfOneOfFourErr() = runTest { + val result = withContext(Dispatchers.Default) { + parZip( + { produceOk(value = "producer1") }, + { produceErr(error = "failed") }, + { produceOk(value = true) }, + { produceOk(value = 3.14) }, + { v1, v2, v3, v4 -> + simulateDelay() + ZipData4(v1, v2, v3, v4) + } + ) + } + + assertEquals( + expected = Err("failed"), + actual = result, + ) + } + + @Test + fun parZip5ReturnsTransformedValueIfAllOk() = runTest { + val result = withContext(Dispatchers.Default) { + parZip( + { produceOk(value = "producer1") }, + { produceOk(value = 42) }, + { produceOk(value = true) }, + { produceOk(value = 3.14) }, + { produceOk(value = 'X') }, + { v1, v2, v3, v4, v5 -> + simulateDelay() + ZipData5(v1, v2, v3, v4, v5) + } + ) + } + + assertEquals( + expected = Ok(ZipData5("producer1", 42, true, 3.14, 'X')), + actual = result, + ) + } + + @Test + fun parZip5ReturnsErrIfOneOfFiveErr() = runTest { + val result = withContext(Dispatchers.Default) { + parZip( + { produceOk(value = "producer1") }, + { produceErr(error = "failed") }, + { produceOk(value = true) }, + { produceOk(value = 3.14) }, + { produceOk(value = 'X') }, + { v1, v2, v3, v4, v5 -> + simulateDelay() + ZipData5(v1, v2, v3, v4, v5) + } + ) + } + + assertEquals( + expected = Err("failed"), + actual = result, + ) + } +}