From 04cde56a232a1bff4bc1fb2f50e166b48d48bf21 Mon Sep 17 00:00:00 2001 From: ryosuke-hasebe Date: Tue, 23 May 2023 15:56:45 +0900 Subject: [PATCH] Fix bug related with context propagation (CoroutineServerInterceptor) --- grpc-kotlin/build.gradle.kts | 1 + .../grpc/kotlin/CoroutineServerInterceptor.kt | 39 ++++--- .../kotlin/CoroutineServerInterceptorTest.kt | 101 ++++++++++++++++-- 3 files changed, 117 insertions(+), 24 deletions(-) diff --git a/grpc-kotlin/build.gradle.kts b/grpc-kotlin/build.gradle.kts index bdfd375c5d1c..d43aa525b050 100644 --- a/grpc-kotlin/build.gradle.kts +++ b/grpc-kotlin/build.gradle.kts @@ -3,6 +3,7 @@ dependencies { implementation(project(":grpc")) implementation(libs.grpc.kotlin) + implementation(libs.kotlin.reflect) implementation(libs.kotlin.coroutines.jdk8) testImplementation(libs.kotlin.coroutines.test) diff --git a/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt b/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt index 6952725efa5f..38576f466a84 100644 --- a/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt +++ b/grpc-kotlin/src/main/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptor.kt @@ -17,18 +17,21 @@ package com.linecorp.armeria.server.grpc.kotlin import com.linecorp.armeria.common.annotation.UnstableApi -import com.linecorp.armeria.internal.common.kotlin.ArmeriaRequestCoroutineContext -import com.linecorp.armeria.internal.server.grpc.AbstractServerCall import com.linecorp.armeria.server.grpc.AsyncServerInterceptor +import io.grpc.Context import io.grpc.Metadata import io.grpc.ServerCall import io.grpc.ServerCallHandler import io.grpc.ServerInterceptor -import kotlinx.coroutines.DelicateCoroutinesApi -import kotlinx.coroutines.GlobalScope -import kotlinx.coroutines.asCoroutineDispatcher +import io.grpc.kotlin.CoroutineContextServerInterceptor +import io.grpc.kotlin.GrpcContextElement +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.future.future import java.util.concurrent.CompletableFuture +import kotlin.coroutines.CoroutineContext +import kotlin.reflect.full.companionObject +import kotlin.reflect.full.companionObjectInstance +import kotlin.reflect.full.memberProperties /** * A [ServerInterceptor] that is able to suspend the interceptor without blocking the @@ -54,20 +57,18 @@ import java.util.concurrent.CompletableFuture @UnstableApi interface CoroutineServerInterceptor : AsyncServerInterceptor { - @OptIn(DelicateCoroutinesApi::class) override fun asyncInterceptCall( call: ServerCall, headers: Metadata, next: ServerCallHandler ): CompletableFuture> { - check(call is AbstractServerCall) { - throw IllegalArgumentException( - "Cannot use ${AsyncServerInterceptor::class.java.name} with a non-Armeria gRPC server" - ) - } - val executor = call.blockingExecutor() ?: call.eventLoop() - - return GlobalScope.future(executor.asCoroutineDispatcher() + ArmeriaRequestCoroutineContext(call.ctx())) { + return CoroutineScope( + // It is necessary to propagate the CoroutineContext set by the previous CoroutineContextServerInterceptor. + // (The ArmeriaRequestCoroutineContext is also propagated by CoroutineContextServerInterceptor) + COROUTINE_CONTEXT_KEY.get() + // In gRPC-kotlin, the Coroutine Context is propagated using the gRPC Context. + + GrpcContextElement.current() + ).future { suspendedInterceptCall(call, headers, next) } } @@ -87,4 +88,14 @@ interface CoroutineServerInterceptor : AsyncServerInterceptor { headers: Metadata, next: ServerCallHandler ): ServerCall.Listener + + companion object { + @Suppress("UNCHECKED_CAST") + internal val COROUTINE_CONTEXT_KEY: Context.Key = + CoroutineContextServerInterceptor::class.let { kclass -> + val companionObject = checkNotNull(kclass.companionObject) + val property = companionObject.memberProperties.single { it.name == "COROUTINE_CONTEXT_KEY" } + checkNotNull(property.getter.call(kclass.companionObjectInstance)) as Context.Key + } + } } diff --git a/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt b/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt index 08faa79fe100..a05514ae4403 100644 --- a/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt +++ b/grpc-kotlin/src/test/kotlin/com/linecorp/armeria/server/grpc/kotlin/CoroutineServerInterceptorTest.kt @@ -16,6 +16,7 @@ package com.linecorp.armeria.server.grpc.kotlin +import com.google.common.util.concurrent.ThreadFactoryBuilder import com.google.protobuf.ByteString import com.linecorp.armeria.client.grpc.GrpcClients import com.linecorp.armeria.common.RequestContext @@ -35,12 +36,19 @@ import com.linecorp.armeria.server.ServiceRequestContext import com.linecorp.armeria.server.auth.Authorizer import com.linecorp.armeria.server.grpc.GrpcService import com.linecorp.armeria.testing.junit5.server.ServerExtension +import io.grpc.Context +import io.grpc.Contexts import io.grpc.Metadata import io.grpc.ServerCall import io.grpc.ServerCallHandler import io.grpc.Status import io.grpc.StatusException +import io.grpc.kotlin.CoroutineContextServerInterceptor +import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.asContextElement +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.delay import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow @@ -51,13 +59,16 @@ import kotlinx.coroutines.flow.toList import kotlinx.coroutines.future.await import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.extension.RegisterExtension import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executors import java.util.concurrent.TimeUnit +import kotlin.coroutines.CoroutineContext internal class CoroutineServerInterceptorTest { @@ -205,19 +216,23 @@ internal class CoroutineServerInterceptorTest { @RegisterExtension val server: ServerExtension = object : ServerExtension() { override fun configure(sb: ServerBuilder) { - val statusFunction = GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata -> - if (throwable is AnticipatedException && throwable.message == "Invalid access") { - return@GrpcStatusFunction Status.UNAUTHENTICATED + val statusFunction = + GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata -> + if (throwable is AnticipatedException && throwable.message == "Invalid access") { + return@GrpcStatusFunction Status.UNAUTHENTICATED + } + // Fallback to the default. + null } - // Fallback to the default. - null - } + val threadLocalInterceptor = ThreadLocalInterceptor() val authInterceptor = AuthInterceptor() + val coroutineNameInterceptor = CoroutineNameInterceptor() sb.serviceUnder( "/non-blocking", GrpcService.builder() .exceptionMapping(statusFunction) - .intercept(authInterceptor) + // applying order is coroutineNameInterceptor -> authInterceptor -> threadLocalInterceptor + .intercept(threadLocalInterceptor, authInterceptor, coroutineNameInterceptor) .addService(TestService()) .build() ) @@ -226,7 +241,8 @@ internal class CoroutineServerInterceptorTest { GrpcService.builder() .addService(TestService()) .exceptionMapping(statusFunction) - .intercept(authInterceptor) + // applying order is coroutineNameInterceptor -> authInterceptor -> threadLocalInterceptor + .intercept(threadLocalInterceptor, authInterceptor, coroutineNameInterceptor) .useBlockingTaskExecutor(true) .build() ) @@ -236,6 +252,10 @@ internal class CoroutineServerInterceptorTest { private const val username = "Armeria" private const val token = "token-1234" + private val executorDispatcher = Executors.newSingleThreadExecutor( + ThreadFactoryBuilder().setNameFormat("my-executor").build() + ).asCoroutineDispatcher() + private class AuthInterceptor : CoroutineServerInterceptor { private val authorizer = Authorizer { ctx: ServiceRequestContext, _: Metadata -> val future = CompletableFuture() @@ -254,21 +274,70 @@ internal class CoroutineServerInterceptorTest { headers: Metadata, next: ServerCallHandler ): ServerCall.Listener { + assertContextPropagation() + + delay(100) + assertContextPropagation() // OK even if resume from suspend. + + withContext(executorDispatcher) { + // OK even if the dispatcher is switched + assertContextPropagation() + assertThat(Thread.currentThread().name).contains("my-executor") + } + val result = authorizer.authorize(ServiceRequestContext.current(), headers).await() + if (result) { - return next.startCall(call, headers) + val ctx = Context.current().withValue(AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY, "OK") + return Contexts.interceptCall(ctx, call, headers, next) } else { throw AnticipatedException("Invalid access") } } + + private suspend fun assertContextPropagation() { + assertThat(ServiceRequestContext.currentOrNull()).isNotNull() + assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name") + } + + companion object { + val AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY: Context.Key = + Context.key("authorization-result") + } + } + + private class CoroutineNameInterceptor : CoroutineContextServerInterceptor() { + override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext { + return CoroutineName("my-coroutine-name") + } + } + + private class ThreadLocalInterceptor : CoroutineContextServerInterceptor() { + override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext { + return THREAD_LOCAL.asContextElement(value = "thread-local-value") + } + + companion object { + val THREAD_LOCAL = ThreadLocal() + } } private class TestService : TestServiceGrpcKt.TestServiceCoroutineImplBase() { override suspend fun unaryCall(request: SimpleRequest): SimpleResponse { + assertContextPropagation() + + delay(100) + assertContextPropagation() // OK even if resume from suspend. + + withContext(executorDispatcher) { + // OK even if the dispatcher is switched + assertContextPropagation() + assertThat(Thread.currentThread().name).contains("my-executor") + } + if (request.fillUsername) { return SimpleResponse.newBuilder().setUsername(username).build() } - return SimpleResponse.getDefaultInstance() } @@ -276,6 +345,7 @@ internal class CoroutineServerInterceptorTest { return flow { for (i in 1..5) { delay(500) + assertContextPropagation() emit(buildReply(username)) } } @@ -284,16 +354,27 @@ internal class CoroutineServerInterceptorTest { override suspend fun streamingInputCall(requests: Flow): StreamingInputCallResponse { val names = requests.map { it.payload.body.toString() }.toList() + assertContextPropagation() + return buildReply(names) } override fun fullDuplexCall(requests: Flow): Flow { return flow { requests.collect { + delay(500) + assertContextPropagation() emit(buildReply(username)) } } } + + private suspend fun assertContextPropagation() { + assertThat(ServiceRequestContext.currentOrNull()).isNotNull() + assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name") + assertThat(ThreadLocalInterceptor.THREAD_LOCAL.get()).isEqualTo("thread-local-value") + assertThat(AuthInterceptor.AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY.get()).isEqualTo("OK") + } } private fun buildReply(message: String): StreamingOutputCallResponse =