Skip to content

Commit

Permalink
Fix bug related with context propagation (CoroutineServerInterceptor)
Browse files Browse the repository at this point in the history
  • Loading branch information
be-hase committed May 23, 2023
1 parent 92e7017 commit e4aad71
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 24 deletions.
1 change: 1 addition & 0 deletions grpc-kotlin/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ dependencies {
implementation(project(":grpc"))

implementation(libs.grpc.kotlin)
implementation(libs.kotlin.reflect)
implementation(libs.kotlin.coroutines.jdk8)

testImplementation(libs.kotlin.coroutines.test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
package com.linecorp.armeria.server.grpc.kotlin

import com.linecorp.armeria.common.annotation.UnstableApi
import com.linecorp.armeria.internal.common.kotlin.ArmeriaRequestCoroutineContext
import com.linecorp.armeria.internal.server.grpc.AbstractServerCall
import com.linecorp.armeria.server.grpc.AsyncServerInterceptor
import io.grpc.Context
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.ServerInterceptor
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.asCoroutineDispatcher
import io.grpc.kotlin.CoroutineContextServerInterceptor
import io.grpc.kotlin.GrpcContextElement
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.future.future
import java.util.concurrent.CompletableFuture
import kotlin.coroutines.CoroutineContext
import kotlin.reflect.full.companionObject
import kotlin.reflect.full.companionObjectInstance
import kotlin.reflect.full.memberProperties

/**
* A [ServerInterceptor] that is able to suspend the interceptor without blocking the
Expand All @@ -54,20 +57,17 @@ import java.util.concurrent.CompletableFuture
@UnstableApi
interface CoroutineServerInterceptor : AsyncServerInterceptor {

@OptIn(DelicateCoroutinesApi::class)
override fun <I : Any, O : Any> asyncInterceptCall(
call: ServerCall<I, O>,
headers: Metadata,
next: ServerCallHandler<I, O>
): CompletableFuture<ServerCall.Listener<I>> {
check(call is AbstractServerCall) {
throw IllegalArgumentException(
"Cannot use ${AsyncServerInterceptor::class.java.name} with a non-Armeria gRPC server"
)
}
val executor = call.blockingExecutor() ?: call.eventLoop()

return GlobalScope.future(executor.asCoroutineDispatcher() + ArmeriaRequestCoroutineContext(call.ctx())) {
// COROUTINE_CONTEXT_KEY.get():
// It is necessary to propagate the CoroutineContext set by the previous CoroutineContextServerInterceptor.
// (The ArmeriaRequestCoroutineContext is also propagated by CoroutineContextServerInterceptor)
// GrpcContextElement.current():
// In gRPC-kotlin, the Coroutine Context is propagated using the gRPC Context.
return CoroutineScope(COROUTINE_CONTEXT_KEY.get() + GrpcContextElement.current()).future {
suspendedInterceptCall(call, headers, next)
}
}
Expand All @@ -87,4 +87,14 @@ interface CoroutineServerInterceptor : AsyncServerInterceptor {
headers: Metadata,
next: ServerCallHandler<ReqT, RespT>
): ServerCall.Listener<ReqT>

companion object {
@Suppress("UNCHECKED_CAST")
internal val COROUTINE_CONTEXT_KEY: Context.Key<CoroutineContext> =
CoroutineContextServerInterceptor::class.let { kclass ->
val companionObject = checkNotNull(kclass.companionObject)
val property = companionObject.memberProperties.single { it.name == "COROUTINE_CONTEXT_KEY" }
checkNotNull(property.getter.call(kclass.companionObjectInstance)) as Context.Key<CoroutineContext>
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.linecorp.armeria.server.grpc.kotlin

import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.google.protobuf.ByteString
import com.linecorp.armeria.client.grpc.GrpcClients
import com.linecorp.armeria.common.RequestContext
Expand All @@ -35,12 +36,19 @@ import com.linecorp.armeria.server.ServiceRequestContext
import com.linecorp.armeria.server.auth.Authorizer
import com.linecorp.armeria.server.grpc.GrpcService
import com.linecorp.armeria.testing.junit5.server.ServerExtension
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
import io.grpc.Status
import io.grpc.StatusException
import io.grpc.kotlin.CoroutineContextServerInterceptor
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.asContextElement
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
Expand All @@ -51,13 +59,16 @@ import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.future.await
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.jupiter.api.extension.RegisterExtension
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ValueSource
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.coroutines.CoroutineContext

internal class CoroutineServerInterceptorTest {

Expand Down Expand Up @@ -205,19 +216,23 @@ internal class CoroutineServerInterceptorTest {
@RegisterExtension
val server: ServerExtension = object : ServerExtension() {
override fun configure(sb: ServerBuilder) {
val statusFunction = GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata ->
if (throwable is AnticipatedException && throwable.message == "Invalid access") {
return@GrpcStatusFunction Status.UNAUTHENTICATED
val statusFunction =
GrpcStatusFunction { _: RequestContext, throwable: Throwable, _: Metadata ->
if (throwable is AnticipatedException && throwable.message == "Invalid access") {
return@GrpcStatusFunction Status.UNAUTHENTICATED
}
// Fallback to the default.
null
}
// Fallback to the default.
null
}
val threadLocalInterceptor = ThreadLocalInterceptor()
val authInterceptor = AuthInterceptor()
val coroutineNameInterceptor = CoroutineNameInterceptor()
sb.serviceUnder(
"/non-blocking",
GrpcService.builder()
.exceptionMapping(statusFunction)
.intercept(authInterceptor)
// applying order is coroutineNameInterceptor -> authInterceptor -> threadLocalInterceptor
.intercept(threadLocalInterceptor, authInterceptor, coroutineNameInterceptor)
.addService(TestService())
.build()
)
Expand All @@ -226,7 +241,8 @@ internal class CoroutineServerInterceptorTest {
GrpcService.builder()
.addService(TestService())
.exceptionMapping(statusFunction)
.intercept(authInterceptor)
// applying order is coroutineNameInterceptor -> authInterceptor -> threadLocalInterceptor
.intercept(threadLocalInterceptor, authInterceptor, coroutineNameInterceptor)
.useBlockingTaskExecutor(true)
.build()
)
Expand All @@ -236,6 +252,10 @@ internal class CoroutineServerInterceptorTest {
private const val username = "Armeria"
private const val token = "token-1234"

private val executorDispatcher = Executors.newSingleThreadExecutor(
ThreadFactoryBuilder().setNameFormat("my-executor").build()
).asCoroutineDispatcher()

private class AuthInterceptor : CoroutineServerInterceptor {
private val authorizer = Authorizer { ctx: ServiceRequestContext, _: Metadata ->
val future = CompletableFuture<Boolean>()
Expand All @@ -254,28 +274,78 @@ internal class CoroutineServerInterceptorTest {
headers: Metadata,
next: ServerCallHandler<ReqT, RespT>
): ServerCall.Listener<ReqT> {
assertContextPropagation()

delay(100)
assertContextPropagation() // OK even if resume from suspend.

withContext(executorDispatcher) {
// OK even if the dispatcher is switched
assertContextPropagation()
assertThat(Thread.currentThread().name).contains("my-executor")
}

val result = authorizer.authorize(ServiceRequestContext.current(), headers).await()

if (result) {
return next.startCall(call, headers)
val ctx = Context.current().withValue(AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY, "OK")
return Contexts.interceptCall(ctx, call, headers, next)
} else {
throw AnticipatedException("Invalid access")
}
}

private suspend fun assertContextPropagation() {
assertThat(ServiceRequestContext.currentOrNull()).isNotNull()
assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name")
}

companion object {
val AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY: Context.Key<String> =
Context.key("authorization-result")
}
}

private class CoroutineNameInterceptor : CoroutineContextServerInterceptor() {
override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext {
return CoroutineName("my-coroutine-name")
}
}

private class ThreadLocalInterceptor : CoroutineContextServerInterceptor() {
override fun coroutineContext(call: ServerCall<*, *>, headers: Metadata): CoroutineContext {
return THREAD_LOCAL.asContextElement(value = "thread-local-value")
}

companion object {
val THREAD_LOCAL = ThreadLocal<String>()
}
}

private class TestService : TestServiceGrpcKt.TestServiceCoroutineImplBase() {
override suspend fun unaryCall(request: SimpleRequest): SimpleResponse {
assertContextPropagation()

delay(100)
assertContextPropagation() // OK even if resume from suspend.

withContext(executorDispatcher) {
// OK even if the dispatcher is switched
assertContextPropagation()
assertThat(Thread.currentThread().name).contains("my-executor")
}

if (request.fillUsername) {
return SimpleResponse.newBuilder().setUsername(username).build()
}

return SimpleResponse.getDefaultInstance()
}

override fun streamingOutputCall(request: StreamingOutputCallRequest): Flow<StreamingOutputCallResponse> {
return flow {
for (i in 1..5) {
delay(500)
assertContextPropagation()
emit(buildReply(username))
}
}
Expand All @@ -284,16 +354,27 @@ internal class CoroutineServerInterceptorTest {
override suspend fun streamingInputCall(requests: Flow<StreamingInputCallRequest>): StreamingInputCallResponse {
val names = requests.map { it.payload.body.toString() }.toList()

assertContextPropagation()

return buildReply(names)
}

override fun fullDuplexCall(requests: Flow<StreamingOutputCallRequest>): Flow<StreamingOutputCallResponse> {
return flow {
requests.collect {
delay(500)
assertContextPropagation()
emit(buildReply(username))
}
}
}

private suspend fun assertContextPropagation() {
assertThat(ServiceRequestContext.currentOrNull()).isNotNull()
assertThat(currentCoroutineContext()[CoroutineName]?.name).isEqualTo("my-coroutine-name")
assertThat(ThreadLocalInterceptor.THREAD_LOCAL.get()).isEqualTo("thread-local-value")
assertThat(AuthInterceptor.AUTHORIZATION_RESULT_GRPC_CONTEXT_KEY.get()).isEqualTo("OK")
}
}

private fun buildReply(message: String): StreamingOutputCallResponse =
Expand Down

0 comments on commit e4aad71

Please sign in to comment.