Skip to content

Commit

Permalink
Provide an easier way to include additional error information into gR…
Browse files Browse the repository at this point in the history
…PC response (line#4986)

Motivation:

It would be nice if Armeria gRPC users could easily develop their own richer error-handling model

Modifications:

- Add `GoogleGrpcStatusFunction` 

Result:

- Closes line#4614
- Users can easily implement gRPC richer error model. This model enables servers to return and clients to consume additional error details expressed as one or more protobuf messages
  • Loading branch information
ta7uw authored Jul 31, 2023
1 parent 4edff55 commit d9076b5
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.common.grpc;

import static com.linecorp.armeria.internal.common.grpc.MetadataUtil.GRPC_STATUS_DETAILS_BIN_KEY;
import static java.util.Objects.requireNonNull;

import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.util.Exceptions;

import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.StatusRuntimeException;

/**
* A {@link GrpcStatusFunction} that provides a way to include details of a status into a {@link Metadata}.
* You can implement a mapping function to convert {@link Throwable} into a {@link com.google.rpc.Status}
* which is stored in the `grpc-status-details-bin` key in the {@link Metadata}.
* If a given {@link Throwable} is an instance of either {@link StatusRuntimeException} or
* {@link StatusException}, the {@link Status} retrieved from the exception is
* returned with higher priority.
*/
@UnstableApi
public interface GoogleGrpcStatusFunction extends GrpcStatusFunction {

@Nullable
@Override
default Status apply(RequestContext ctx, Throwable throwable, Metadata metadata) {
final Throwable cause = Exceptions.peel(requireNonNull(throwable, "throwable"));
if (cause instanceof StatusRuntimeException) {
return ((StatusRuntimeException) cause).getStatus();
}
if (cause instanceof StatusException) {
return ((StatusException) cause).getStatus();
}
final com.google.rpc.Status statusProto = applyStatusProto(ctx, cause, metadata);
if (statusProto == null) {
return null;
}
final Status status = Status.fromCodeValue(statusProto.getCode())
.withDescription(statusProto.getMessage());
metadata.discardAll(GRPC_STATUS_DETAILS_BIN_KEY);
metadata.put(GRPC_STATUS_DETAILS_BIN_KEY, statusProto);
return status;
}

/**
* Maps the specified {@link Throwable} to a {@link com.google.rpc.Status},
* and mutates the specified {@link Metadata}.
* The `grpc-status-details-bin` key is ignored since it will be overwritten
* by {@link GoogleGrpcStatusFunction#apply(RequestContext, Throwable, Metadata)}.
* If {@code null} is returned, the built-in mapping rule is used by default.
*/
com.google.rpc.@Nullable Status applyStatusProto(RequestContext ctx, Throwable throwable,
Metadata metadata);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.common.base.CharMatcher;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.BaseEncoding;
import com.google.rpc.Status;

import com.linecorp.armeria.common.HttpHeaderNames;
import com.linecorp.armeria.common.HttpHeaders;
Expand All @@ -36,13 +37,22 @@

import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.Metadata.Key;
import io.grpc.protobuf.ProtoUtils;
import io.netty.util.AsciiString;

/**
* Utilities for working with {@link Metadata}.
*/
public final class MetadataUtil {

/**
* A key for {@link Status} whose name is {@code "grpc-status-details-bin"}.
*/
public static final Key<Status> GRPC_STATUS_DETAILS_BIN_KEY = Key.of(
GrpcHeaderNames.GRPC_STATUS_DETAILS_BIN.toString(),
ProtoUtils.metadataMarshaller(Status.getDefaultInstance()));

private static final Logger logger = LoggerFactory.getLogger(MetadataUtil.class);

private static final CharMatcher COMMA_MATCHER = CharMatcher.is(',');
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright 2023 LINE Corporation
*
* LINE Corporation licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.linecorp.armeria.common.grpc;

import static com.linecorp.armeria.internal.common.grpc.MetadataUtil.GRPC_STATUS_DETAILS_BIN_KEY;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.rpc.Code;

import com.linecorp.armeria.client.grpc.GrpcClients;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.auth.AuthToken;
import com.linecorp.armeria.grpc.testing.Error.AuthError;
import com.linecorp.armeria.grpc.testing.Error.InternalError;
import com.linecorp.armeria.grpc.testing.Messages.SimpleRequest;
import com.linecorp.armeria.grpc.testing.Messages.SimpleResponse;
import com.linecorp.armeria.grpc.testing.TestServiceGrpc.TestServiceBlockingStub;
import com.linecorp.armeria.grpc.testing.TestServiceGrpc.TestServiceImplBase;
import com.linecorp.armeria.protobuf.EmptyProtos.Empty;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.grpc.GrpcService;
import com.linecorp.armeria.server.logging.LoggingService;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.protobuf.StatusProto;
import io.grpc.stub.StreamObserver;

class GoogleGrpcStatusFunctionTest {

@RegisterExtension
static ServerExtension server = new ServerExtension() {
@Override
protected void configure(ServerBuilder sb) {
sb.service(GrpcService.builder()
.intercept(new AuthInterceptor())
.exceptionMapping(new ExceptionHandler())
.addService(new TestService())
.build())
.decorator(LoggingService.newDecorator());
}
};

@Test
void applyInternalError() {
final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri())
.auth(AuthToken.ofOAuth2("token-1234"))
.build(TestServiceBlockingStub.class);
final SimpleRequest request = SimpleRequest.newBuilder()
.setFillUsername(true)
.build();
assertThatThrownBy(() -> client.unaryCall(request))
.isInstanceOfSatisfying(StatusRuntimeException.class, e -> {
assertThat(e.getStatus().getCode()).isEqualTo(Status.INTERNAL.getCode());
final com.google.rpc.Status status = e.getTrailers().get(GRPC_STATUS_DETAILS_BIN_KEY);
assertThat(status).isNotNull();
assertThat(status.getCode()).isEqualTo(Code.INTERNAL.getNumber());
assertThat(status.getDetailsCount()).isEqualTo(1);
final InternalError internalError;
try {
internalError = status.getDetails(0).unpack(InternalError.class);
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(ex);
}
assertThat(internalError.getCode()).isEqualTo(123);
assertThat(internalError.getMessage()).isEqualTo("Unexpected error");
});
}

@Test
void applyAuthError() {
final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri())
.auth(AuthToken.ofOAuth2("token-12345"))
.build(TestServiceBlockingStub.class);
final SimpleRequest request = SimpleRequest.newBuilder()
.setFillUsername(true)
.build();
assertThatThrownBy(() -> client.unaryCall(request))
.isInstanceOfSatisfying(StatusRuntimeException.class, e -> {
assertThat(e.getStatus().getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode());
final com.google.rpc.Status status = e.getTrailers().get(GRPC_STATUS_DETAILS_BIN_KEY);
assertThat(status).isNotNull();
assertThat(status.getCode()).isEqualTo(Code.UNAUTHENTICATED.getNumber());
assertThat(status.getDetailsCount()).isEqualTo(1);
final AuthError authError;
try {
authError = status.getDetails(0).unpack(AuthError.class);
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(ex);
}
assertThat(authError.getCode()).isEqualTo(334);
assertThat(authError.getMessage()).isEqualTo("Invalid token");
});
}

@Test
void earlyReturnStatusRuntimeException() {
final TestServiceBlockingStub client = GrpcClients.builder(server.httpUri())
.auth(AuthToken.ofOAuth2("token-1234"))
.build(TestServiceBlockingStub.class);
assertThatThrownBy(() -> client.emptyCall(Empty.getDefaultInstance()))
.isInstanceOfSatisfying(StatusRuntimeException.class, e -> {
assertThat(e.getStatus().getCode()).isEqualTo(Status.INTERNAL.getCode());
final com.google.rpc.Status status = e.getTrailers().get(GRPC_STATUS_DETAILS_BIN_KEY);
assertThat(status).isNotNull();
assertThat(status.getCode()).isEqualTo(Code.INTERNAL.getNumber());
assertThat(status.getMessage()).isEqualTo("Database failure");
assertThat(status.getDetailsCount()).isEqualTo(1);
final InternalError internalError;
try {
internalError = status.getDetails(0).unpack(InternalError.class);
} catch (InvalidProtocolBufferException ex) {
throw new RuntimeException(ex);
}
assertThat(internalError.getCode()).isEqualTo(321);
assertThat(internalError.getMessage()).isEqualTo("Primary DB failure");
});
}

private static final class TestService extends TestServiceImplBase {
@Override
public void unaryCall(SimpleRequest request, StreamObserver<SimpleResponse> responseObserver) {
if (request.getFillUsername()) {
throw new InternalServerException("Unexpected error", 123);
}
responseObserver.onNext(SimpleResponse.newBuilder().setUsername("Armeria").build());
responseObserver.onCompleted();
}

@Override
public void emptyCall(Empty empty, StreamObserver<Empty> responseObserver) {
final InternalError internalError = InternalError.newBuilder()
.setCode(321)
.setMessage("Primary DB failure")
.build();
final com.google.rpc.Status status = com.google.rpc.Status.newBuilder()
.setCode(Code.INTERNAL.getNumber())
.setMessage("Database failure")
.addDetails(Any.pack(internalError))
.build();
throw StatusProto.toStatusRuntimeException(status);
}
}

private static final class ExceptionHandler implements GoogleGrpcStatusFunction {

@Override
public com.google.rpc.Status applyStatusProto(RequestContext ctx, Throwable throwable,
Metadata metadata) {
if (throwable instanceof AuthenticationException) {
final AuthenticationException authenticationException = (AuthenticationException) throwable;
final AuthError authError = AuthError.newBuilder()
.setCode(authenticationException.getCode())
.setMessage(authenticationException.getMessage())
.build();
return com.google.rpc.Status.newBuilder()
.setCode(Code.UNAUTHENTICATED.getNumber())
.addDetails(Any.pack(authError))
.build();
}
if (throwable instanceof InternalServerException) {
final InternalServerException internalServerException = (InternalServerException) throwable;
final InternalError internalError = InternalError
.newBuilder()
.setCode(internalServerException.getCode())
.setMessage(internalServerException.getMessage())
.build();
return com.google.rpc.Status.newBuilder()
.setCode(Code.INTERNAL.getNumber())
.addDetails(Any.pack(internalError))
.build();
}
return null;
}
}

private static final class AuthInterceptor implements ServerInterceptor {

@Override
public <I, O> Listener<I> interceptCall(ServerCall<I, O> call, Metadata headers,
ServerCallHandler<I, O> next) {
final ServiceRequestContext ctx = ServiceRequestContext.current();
if (!ctx.request().headers().contains("Authorization", "Bearer token-1234")) {
throw new AuthenticationException("Invalid token", 334);
}
return next.startCall(call, headers);
}
}

private static final class InternalServerException extends RuntimeException {

private final int code;

InternalServerException(String message, int code) {
super(message);
this.code = code;
}

int getCode() {
return code;
}
}

private static class AuthenticationException extends RuntimeException {

private final int code;

AuthenticationException(String message, int code) {
super(message);
this.code = code;
}

int getCode() {
return code;
}
}
}
15 changes: 15 additions & 0 deletions grpc/src/test/proto/com/linecorp/armeria/grpc/testing/error.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
syntax = "proto3";

package armeria.grpc.testing;

option java_package = "com.linecorp.armeria.grpc.testing";

message InternalError {
int32 code = 1;
string message = 2;
}

message AuthError {
int32 code = 1;
string message = 2;
}

0 comments on commit d9076b5

Please sign in to comment.