diff --git a/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcStatusFunction.java b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcStatusFunction.java new file mode 100644 index 00000000000..db838180957 --- /dev/null +++ b/grpc/src/main/java/com/linecorp/armeria/common/grpc/GoogleGrpcStatusFunction.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.grpc; + +import static com.linecorp.armeria.internal.common.grpc.MetadataUtil.GRPC_STATUS_DETAILS_BIN_KEY; +import static java.util.Objects.requireNonNull; + +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.annotation.Nullable; +import com.linecorp.armeria.common.annotation.UnstableApi; +import com.linecorp.armeria.common.util.Exceptions; + +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.StatusRuntimeException; + +/** + * A {@link GrpcStatusFunction} that provides a way to include details of a status into a {@link Metadata}. + * You can implement a mapping function to convert {@link Throwable} into a {@link com.google.rpc.Status} + * which is stored in the `grpc-status-details-bin` key in the {@link Metadata}. + * If a given {@link Throwable} is an instance of either {@link StatusRuntimeException} or + * {@link StatusException}, the {@link Status} retrieved from the exception is + * returned with higher priority. + */ +@UnstableApi +public interface GoogleGrpcStatusFunction extends GrpcStatusFunction { + + @Nullable + @Override + default Status apply(RequestContext ctx, Throwable throwable, Metadata metadata) { + final Throwable cause = Exceptions.peel(requireNonNull(throwable, "throwable")); + if (cause instanceof StatusRuntimeException) { + return ((StatusRuntimeException) cause).getStatus(); + } + if (cause instanceof StatusException) { + return ((StatusException) cause).getStatus(); + } + final com.google.rpc.Status statusProto = applyStatusProto(ctx, cause, metadata); + if (statusProto == null) { + return null; + } + final Status status = Status.fromCodeValue(statusProto.getCode()) + .withDescription(statusProto.getMessage()); + metadata.discardAll(GRPC_STATUS_DETAILS_BIN_KEY); + metadata.put(GRPC_STATUS_DETAILS_BIN_KEY, statusProto); + return status; + } + + /** + * Maps the specified {@link Throwable} to a {@link com.google.rpc.Status}, + * and mutates the specified {@link Metadata}. + * The `grpc-status-details-bin` key is ignored since it will be overwritten + * by {@link GoogleGrpcStatusFunction#apply(RequestContext, Throwable, Metadata)}. + * If {@code null} is returned, the built-in mapping rule is used by default. + */ + com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable, + Metadata metadata); +} diff --git a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/MetadataUtil.java b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/MetadataUtil.java index 7e18f366de5..e4b75be0e38 100644 --- a/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/MetadataUtil.java +++ b/grpc/src/main/java/com/linecorp/armeria/internal/common/grpc/MetadataUtil.java @@ -28,6 +28,7 @@ import com.google.common.base.CharMatcher; import com.google.common.collect.ImmutableSet; import com.google.common.io.BaseEncoding; +import com.google.rpc.Status; import com.linecorp.armeria.common.HttpHeaderNames; import com.linecorp.armeria.common.HttpHeaders; @@ -36,6 +37,8 @@ import io.grpc.InternalMetadata; import io.grpc.Metadata; +import io.grpc.Metadata.Key; +import io.grpc.protobuf.ProtoUtils; import io.netty.util.AsciiString; /** @@ -43,6 +46,13 @@ */ public final class MetadataUtil { + /** + * A key for {@link Status} whose name is {@code "grpc-status-details-bin"}. + */ + public static final Key GRPC_STATUS_DETAILS_BIN_KEY = Key.of( + GrpcHeaderNames.GRPC_STATUS_DETAILS_BIN.toString(), + ProtoUtils.metadataMarshaller(Status.getDefaultInstance())); + private static final Logger logger = LoggerFactory.getLogger(MetadataUtil.class); private static final CharMatcher COMMA_MATCHER = CharMatcher.is(','); diff --git a/grpc/src/test/java/com/linecorp/armeria/common/grpc/GoogleGrpcStatusFunctionTest.java b/grpc/src/test/java/com/linecorp/armeria/common/grpc/GoogleGrpcStatusFunctionTest.java new file mode 100644 index 00000000000..e607ef6bbd3 --- /dev/null +++ b/grpc/src/test/java/com/linecorp/armeria/common/grpc/GoogleGrpcStatusFunctionTest.java @@ -0,0 +1,244 @@ +/* + * Copyright 2023 LINE Corporation + * + * LINE Corporation licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.linecorp.armeria.common.grpc; + +import static com.linecorp.armeria.internal.common.grpc.MetadataUtil.GRPC_STATUS_DETAILS_BIN_KEY; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.rpc.Code; + +import com.linecorp.armeria.client.grpc.GrpcClients; +import com.linecorp.armeria.common.RequestContext; +import com.linecorp.armeria.common.auth.AuthToken; +import com.linecorp.armeria.grpc.testing.Error.AuthError; +import com.linecorp.armeria.grpc.testing.Error.InternalError; +import com.linecorp.armeria.grpc.testing.Messages.SimpleRequest; +import com.linecorp.armeria.grpc.testing.Messages.SimpleResponse; +import com.linecorp.armeria.grpc.testing.TestServiceGrpc.TestServiceBlockingStub; +import com.linecorp.armeria.grpc.testing.TestServiceGrpc.TestServiceImplBase; +import com.linecorp.armeria.protobuf.EmptyProtos.Empty; +import com.linecorp.armeria.server.ServerBuilder; +import com.linecorp.armeria.server.ServiceRequestContext; +import com.linecorp.armeria.server.grpc.GrpcService; +import com.linecorp.armeria.server.logging.LoggingService; +import com.linecorp.armeria.testing.junit5.server.ServerExtension; + +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.StatusProto; +import io.grpc.stub.StreamObserver; + +class GoogleGrpcStatusFunctionTest { + + @RegisterExtension + static ServerExtension server = new ServerExtension() { + @Override + protected void configure(ServerBuilder sb) { + sb.service(GrpcService.builder() + .intercept(new AuthInterceptor()) + .exceptionMapping(new ExceptionHandler()) + .addService(new TestService()) + .build()) + .decorator(LoggingService.newDecorator()); + } + }; + + @Test + void applyInternalError() { + final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri()) + .auth(AuthToken.ofOAuth2("token-1234")) + .build(TestServiceBlockingStub.class); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .build(); + assertThatThrownBy(() -> client.unaryCall(request)) + .isInstanceOfSatisfying(StatusRuntimeException.class, e -> { + assertThat(e.getStatus().getCode()).isEqualTo(Status.INTERNAL.getCode()); + final com.google.rpc.Status status = e.getTrailers().get(GRPC_STATUS_DETAILS_BIN_KEY); + assertThat(status).isNotNull(); + assertThat(status.getCode()).isEqualTo(Code.INTERNAL.getNumber()); + assertThat(status.getDetailsCount()).isEqualTo(1); + final InternalError internalError; + try { + internalError = status.getDetails(0).unpack(InternalError.class); + } catch (InvalidProtocolBufferException ex) { + throw new RuntimeException(ex); + } + assertThat(internalError.getCode()).isEqualTo(123); + assertThat(internalError.getMessage()).isEqualTo("Unexpected error"); + }); + } + + @Test + void applyAuthError() { + final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri()) + .auth(AuthToken.ofOAuth2("token-12345")) + .build(TestServiceBlockingStub.class); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .build(); + assertThatThrownBy(() -> client.unaryCall(request)) + .isInstanceOfSatisfying(StatusRuntimeException.class, e -> { + assertThat(e.getStatus().getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + final com.google.rpc.Status status = e.getTrailers().get(GRPC_STATUS_DETAILS_BIN_KEY); + assertThat(status).isNotNull(); + assertThat(status.getCode()).isEqualTo(Code.UNAUTHENTICATED.getNumber()); + assertThat(status.getDetailsCount()).isEqualTo(1); + final AuthError authError; + try { + authError = status.getDetails(0).unpack(AuthError.class); + } catch (InvalidProtocolBufferException ex) { + throw new RuntimeException(ex); + } + assertThat(authError.getCode()).isEqualTo(334); + assertThat(authError.getMessage()).isEqualTo("Invalid token"); + }); + } + + @Test + void earlyReturnStatusRuntimeException() { + final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri()) + .auth(AuthToken.ofOAuth2("token-1234")) + .build(TestServiceBlockingStub.class); + assertThatThrownBy(() -> client.emptyCall(Empty.getDefaultInstance())) + .isInstanceOfSatisfying(StatusRuntimeException.class, e -> { + assertThat(e.getStatus().getCode()).isEqualTo(Status.INTERNAL.getCode()); + final com.google.rpc.Status status = e.getTrailers().get(GRPC_STATUS_DETAILS_BIN_KEY); + assertThat(status).isNotNull(); + assertThat(status.getCode()).isEqualTo(Code.INTERNAL.getNumber()); + assertThat(status.getMessage()).isEqualTo("Database failure"); + assertThat(status.getDetailsCount()).isEqualTo(1); + final InternalError internalError; + try { + internalError = status.getDetails(0).unpack(InternalError.class); + } catch (InvalidProtocolBufferException ex) { + throw new RuntimeException(ex); + } + assertThat(internalError.getCode()).isEqualTo(321); + assertThat(internalError.getMessage()).isEqualTo("Primary DB failure"); + }); + } + + private static final class TestService extends TestServiceImplBase { + @Override + public void unaryCall(SimpleRequest request, StreamObserver responseObserver) { + if (request.getFillUsername()) { + throw new InternalServerException("Unexpected error", 123); + } + responseObserver.onNext(SimpleResponse.newBuilder().setUsername("Armeria").build()); + responseObserver.onCompleted(); + } + + @Override + public void emptyCall(Empty empty, StreamObserver responseObserver) { + final InternalError internalError = InternalError.newBuilder() + .setCode(321) + .setMessage("Primary DB failure") + .build(); + final com.google.rpc.Status status = com.google.rpc.Status.newBuilder() + .setCode(Code.INTERNAL.getNumber()) + .setMessage("Database failure") + .addDetails(Any.pack(internalError)) + .build(); + throw StatusProto.toStatusRuntimeException(status); + } + } + + private static final class ExceptionHandler implements GoogleGrpcStatusFunction { + + @Override + public com.google.rpc.Status applyStatusProto(RequestContext ctx, Throwable throwable, + Metadata metadata) { + if (throwable instanceof AuthenticationException) { + final AuthenticationException authenticationException = (AuthenticationException) throwable; + final AuthError authError = AuthError.newBuilder() + .setCode(authenticationException.getCode()) + .setMessage(authenticationException.getMessage()) + .build(); + return com.google.rpc.Status.newBuilder() + .setCode(Code.UNAUTHENTICATED.getNumber()) + .addDetails(Any.pack(authError)) + .build(); + } + if (throwable instanceof InternalServerException) { + final InternalServerException internalServerException = (InternalServerException) throwable; + final InternalError internalError = InternalError + .newBuilder() + .setCode(internalServerException.getCode()) + .setMessage(internalServerException.getMessage()) + .build(); + return com.google.rpc.Status.newBuilder() + .setCode(Code.INTERNAL.getNumber()) + .addDetails(Any.pack(internalError)) + .build(); + } + return null; + } + } + + private static final class AuthInterceptor implements ServerInterceptor { + + @Override + public Listener interceptCall(ServerCall call, Metadata headers, + ServerCallHandler next) { + final ServiceRequestContext ctx = ServiceRequestContext.current(); + if (!ctx.request().headers().contains("Authorization", "Bearer token-1234")) { + throw new AuthenticationException("Invalid token", 334); + } + return next.startCall(call, headers); + } + } + + private static final class InternalServerException extends RuntimeException { + + private final int code; + + InternalServerException(String message, int code) { + super(message); + this.code = code; + } + + int getCode() { + return code; + } + } + + private static class AuthenticationException extends RuntimeException { + + private final int code; + + AuthenticationException(String message, int code) { + super(message); + this.code = code; + } + + int getCode() { + return code; + } + } +} diff --git a/grpc/src/test/proto/com/linecorp/armeria/grpc/testing/error.proto b/grpc/src/test/proto/com/linecorp/armeria/grpc/testing/error.proto new file mode 100644 index 00000000000..f50f239333d --- /dev/null +++ b/grpc/src/test/proto/com/linecorp/armeria/grpc/testing/error.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +package armeria.grpc.testing; + +option java_package = "com.linecorp.armeria.grpc.testing"; + +message InternalError { + int32 code = 1; + string message = 2; +} + +message AuthError { + int32 code = 1; + string message = 2; +}