diff --git a/grpc/core/src/main/java/io/helidon/grpc/core/InterceptorWeights.java b/grpc/core/src/main/java/io/helidon/grpc/core/InterceptorWeights.java new file mode 100644 index 00000000000..2975789be85 --- /dev/null +++ b/grpc/core/src/main/java/io/helidon/grpc/core/InterceptorWeights.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.grpc.core; + +/** + * gRPC interceptor weight classes. Higher weight means higher priority. + */ +public class InterceptorWeights { + + /** + * Context weight. + *

+ * Interceptors with this weight typically only perform tasks + * such as adding state to the call {@link io.grpc.Context}. + */ + public static final int CONTEXT = 5000; + + /** + * Tracing weight. + *

+ * Tracing and metrics interceptors are typically applied after any context + * interceptors so that they can trace and gather metrics on the whole call + * stack of remaining interceptors. + */ + public static final int TRACING = CONTEXT + 1; + + /** + * Security authentication weight. + */ + public static final int AUTHENTICATION = 2000; + + /** + * Security authorization weight. + */ + public static final int AUTHORIZATION = 2000; + + /** + * User-level weight. + *

+ * This value is also used as a default weight for application-supplied interceptors. + */ + public static final int USER = 1000; + + /** + * Cannot create instances. + */ + private InterceptorWeights() { + } +} diff --git a/grpc/core/src/main/java/io/helidon/grpc/core/WeightedBag.java b/grpc/core/src/main/java/io/helidon/grpc/core/WeightedBag.java index ae3f38c281b..67c66742d85 100644 --- a/grpc/core/src/main/java/io/helidon/grpc/core/WeightedBag.java +++ b/grpc/core/src/main/java/io/helidon/grpc/core/WeightedBag.java @@ -82,6 +82,15 @@ public static WeightedBag withDefaultWeight(double weight) { new ArrayList<>(), weight); } + /** + * Check if bag is empty. + * + * @return outcome of test + */ + public boolean isEmpty() { + return contents.isEmpty() && noWeightedList.isEmpty(); + } + /** * Obtain a copy of this {@link WeightedBag}. * diff --git a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientMethodDescriptor.java b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientMethodDescriptor.java index 1688a60aa43..b28170216cc 100644 --- a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientMethodDescriptor.java +++ b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcClientMethodDescriptor.java @@ -16,12 +16,16 @@ package io.helidon.webclient.grpc; +import java.util.Arrays; import java.util.Objects; +import io.helidon.grpc.core.InterceptorWeights; import io.helidon.grpc.core.MarshallerSupplier; import io.helidon.grpc.core.MethodHandler; +import io.helidon.grpc.core.WeightedBag; import io.grpc.CallCredentials; +import io.grpc.ClientInterceptor; import io.grpc.MethodDescriptor; /** @@ -48,6 +52,11 @@ public final class GrpcClientMethodDescriptor { */ private final MethodDescriptor descriptor; + /** + * The list of client interceptors for this method. + */ + private WeightedBag interceptors; + /** * The {@link io.grpc.CallCredentials} for this method. */ @@ -60,10 +69,12 @@ public final class GrpcClientMethodDescriptor { private GrpcClientMethodDescriptor(String name, MethodDescriptor descriptor, + WeightedBag interceptors, CallCredentials callCredentials, MethodHandler methodHandler) { this.name = name; this.descriptor = descriptor; + this.interceptors = interceptors; this.callCredentials = callCredentials; this.methodHandler = methodHandler; } @@ -146,6 +157,15 @@ public static Builder bidirectional(String serviceName, String name) { return builder(serviceName, name, MethodDescriptor.MethodType.BIDI_STREAMING); } + /** + * Obtain the {@link ClientInterceptor}s to use for this method. + * + * @return the {@link ClientInterceptor}s to use for this method + */ + WeightedBag interceptors() { + return interceptors.readOnly(); + } + /** * Return the {@link io.grpc.CallCredentials} set on this service. * @@ -234,6 +254,25 @@ public interface Rules { */ Rules responseType(Class type); + /** + * Register one or more {@link ClientInterceptor interceptors} for the method. + * + * @param interceptors the interceptor(s) to register + * @return this {@link Rules} instance for fluent call chaining + */ + Rules intercept(ClientInterceptor... interceptors); + + /** + * Register one or more {@link ClientInterceptor interceptors} for the method. + *

+ * The added interceptors will be applied using the specified priority. + * + * @param weight the weight to assign to the interceptors + * @param interceptors one or more {@link ClientInterceptor}s to register + * @return this {@link Rules} to allow fluent method chaining + */ + Rules intercept(double weight, ClientInterceptor... interceptors); + /** * Register the {@link MarshallerSupplier} for the method. *

@@ -276,6 +315,7 @@ public static class Builder private final MethodDescriptor.Builder descriptor; private Class requestType; private Class responseType; + private final WeightedBag interceptors = WeightedBag.withDefaultWeight(InterceptorWeights.USER); private MarshallerSupplier defaultMarshallerSupplier = MarshallerSupplier.defaultInstance(); private MarshallerSupplier marshallerSupplier; private CallCredentials callCredentials; @@ -305,6 +345,18 @@ public Builder responseType(Class type) { return this; } + @Override + public Builder intercept(ClientInterceptor... interceptors) { + this.interceptors.addAll(Arrays.asList(interceptors)); + return this; + } + + @Override + public Builder intercept(double weight, ClientInterceptor... interceptors) { + this.interceptors.addAll(Arrays.asList(interceptors), weight); + return this; + } + @Override public Builder marshallerSupplier(MarshallerSupplier supplier) { this.marshallerSupplier = supplier; @@ -364,6 +416,7 @@ public GrpcClientMethodDescriptor build() { return new GrpcClientMethodDescriptor(name, descriptor.build(), + interceptors, callCredentials, methodHandler); } diff --git a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java index 3b31b7d0200..e9faafaa899 100644 --- a/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java +++ b/webclient/grpc/src/main/java/io/helidon/webclient/grpc/GrpcServiceClientImpl.java @@ -19,26 +19,49 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +import io.helidon.grpc.core.WeightedBag; import io.grpc.CallOptions; +import io.grpc.Channel; import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; import io.grpc.MethodDescriptor; import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; class GrpcServiceClientImpl implements GrpcServiceClient { - private final GrpcServiceDescriptor descriptor; + private final GrpcServiceDescriptor serviceDescriptor; + private final Channel serviceChannel; private final GrpcClientImpl grpcClient; + private final Map methodCache = new ConcurrentHashMap<>(); GrpcServiceClientImpl(GrpcServiceDescriptor descriptor, GrpcClientImpl grpcClient) { - this.descriptor = descriptor; + this.serviceDescriptor = descriptor; this.grpcClient = grpcClient; + + if (descriptor.interceptors().isEmpty()) { + serviceChannel = grpcClient.channel(); + } else { + // sort interceptors using a weighted bag + WeightedBag interceptors = WeightedBag.create(); + for (ClientInterceptor interceptor : descriptor.interceptors()) { + interceptors.add(interceptor); + } + + // wrap channel to call interceptors -- reversed for composition + List orderedInterceptors = interceptors.stream().toList().reversed(); + serviceChannel = ClientInterceptors.intercept(grpcClient.channel(), orderedInterceptors); + } } @Override public String serviceName() { - return descriptor.serviceName(); + return serviceDescriptor.serviceName(); } @Override @@ -152,13 +175,26 @@ public StreamObserver bidi(String methodName, StreamObserver< } private ClientCall ensureMethod(String methodName, MethodDescriptor.MethodType methodType) { - GrpcClientMethodDescriptor method = descriptor.method(methodName); - if (!method.type().equals(methodType)) { - throw new IllegalArgumentException("Method " + methodName + " is of type " + method.type() + GrpcClientMethodDescriptor methodDescriptor = serviceDescriptor.method(methodName); + if (!methodDescriptor.type().equals(methodType)) { + throw new IllegalArgumentException("Method " + methodName + " is of type " + methodDescriptor.type() + ", yet " + methodType + " was requested."); } - return methodType == MethodDescriptor.MethodType.UNARY - ? new GrpcUnaryClientCall<>(grpcClient, method.descriptor(), CallOptions.DEFAULT) - : new GrpcClientCall<>(grpcClient, method.descriptor(), CallOptions.DEFAULT); + + // use channel that contains all service and method interceptors + if (methodDescriptor.interceptors().isEmpty()) { + return serviceChannel.newCall(methodDescriptor.descriptor(), CallOptions.DEFAULT); + } else { + Channel methodChannel = methodCache.computeIfAbsent(methodName, k -> { + WeightedBag interceptors = WeightedBag.create(); + for (ClientInterceptor interceptor : serviceDescriptor.interceptors()) { + interceptors.add(interceptor); + } + interceptors.merge(methodDescriptor.interceptors()); + List orderedInterceptors = interceptors.stream().toList().reversed(); + return ClientInterceptors.intercept(grpcClient.channel(), orderedInterceptors); + }); + return methodChannel.newCall(methodDescriptor.descriptor(), CallOptions.DEFAULT); + } } } diff --git a/webclient/tests/grpc/src/test/java/io/helidon/webclient/grpc/tests/GrpcInterceptorTest.java b/webclient/tests/grpc/src/test/java/io/helidon/webclient/grpc/tests/GrpcInterceptorTest.java new file mode 100644 index 00000000000..692b405ce11 --- /dev/null +++ b/webclient/tests/grpc/src/test/java/io/helidon/webclient/grpc/tests/GrpcInterceptorTest.java @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2024 Oracle and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.helidon.webclient.grpc.tests; + +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; +import io.helidon.common.Weight; +import io.helidon.common.configurable.Resource; +import io.helidon.common.tls.Tls; +import io.helidon.webclient.grpc.GrpcClient; +import io.helidon.webclient.grpc.GrpcClientMethodDescriptor; +import io.helidon.webclient.grpc.GrpcServiceDescriptor; +import io.helidon.webserver.WebServer; +import io.helidon.webserver.testing.junit5.ServerTest; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; + +/** + * Tests client interceptors using low-level API. + */ +@ServerTest +class GrpcInterceptorTest extends GrpcBaseTest { + + private final GrpcClient grpcClient; + private final GrpcServiceDescriptor serviceDescriptor; + private final List> calledInterceptors = new CopyOnWriteArrayList<>(); + + private GrpcInterceptorTest(WebServer server) { + Tls clientTls = Tls.builder() + .trust(trust -> trust + .keystore(store -> store + .passphrase("password") + .trustStore(true) + .keystore(Resource.create("client.p12")))) + .build(); + this.grpcClient = GrpcClient.builder() + .tls(clientTls) + .baseUri("https://localhost:" + server.port()) + .build(); + this.serviceDescriptor = GrpcServiceDescriptor.builder() + .serviceName("StringService") + .putMethod("Upper", + GrpcClientMethodDescriptor.unary("StringService", "Upper") + .requestType(Strings.StringMessage.class) + .responseType(Strings.StringMessage.class) + .intercept(new Weight50Interceptor()) + .intercept(new Weight500Interceptor()) + .build()) + .addInterceptor(new Weight100Interceptor()) + .addInterceptor(new Weight1000Interceptor()) + .addInterceptor(new Weight10Interceptor()) + .build(); + } + + @Test + void testUnaryUpper() { + Strings.StringMessage res = grpcClient.serviceClient(serviceDescriptor) + .unary("Upper", newStringMessage("hello")); + assertThat(res.getText(), is("HELLO")); + assertThat(calledInterceptors, contains(Weight1000Interceptor.class, + Weight500Interceptor.class, + Weight100Interceptor.class, + Weight50Interceptor.class, + Weight10Interceptor.class)); + } + + class BaseInterceptor implements ClientInterceptor { + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, + Channel next) { + calledInterceptors.add(getClass()); + return next.newCall(method, callOptions); + } + } + + @Weight(10.0) + class Weight10Interceptor extends BaseInterceptor { + } + + @Weight(50.0) + class Weight50Interceptor extends BaseInterceptor { + } + + @Weight(100.0) + class Weight100Interceptor extends BaseInterceptor { + } + + @Weight(500.0) + class Weight500Interceptor extends BaseInterceptor { + } + + @Weight(1000.0) + class Weight1000Interceptor extends BaseInterceptor { + } +}