diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java index 05e8c43da79..fa734762650 100644 --- a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java @@ -16,6 +16,9 @@ package io.grpc.binder; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Futures.immediateFuture; + import android.annotation.SuppressLint; import android.app.admin.DevicePolicyManager; import android.content.Context; @@ -32,6 +35,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.hash.Hashing; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.google.errorprone.annotations.CheckReturnValue; import io.grpc.ExperimentalApi; import io.grpc.Status; @@ -333,6 +339,9 @@ private static boolean checkPackageSignature( * Creates a {@link SecurityPolicy} that allows access if and only if *all* of the specified * {@code securityPolicies} allow access. * + *
If any of the policies is an {@link AsyncSecurityPolicy}, then all policies may be evaluated
+ * concurrently to speed up the success scenario.
+ *
* @param securityPolicies the security policies that all must allow access.
* @throws NullPointerException if any of the inputs are {@code null}.
* @throws IllegalArgumentException if {@code securityPolicies} is empty.
@@ -341,10 +350,17 @@ public static SecurityPolicy allOf(SecurityPolicy... securityPolicies) {
Preconditions.checkNotNull(securityPolicies, "securityPolicies");
Preconditions.checkArgument(securityPolicies.length > 0, "securityPolicies must not be empty");
- return allOfSecurityPolicy(securityPolicies);
+ boolean anyAsync =
+ Arrays
+ .stream(securityPolicies)
+ .anyMatch(policy -> policy instanceof AsyncSecurityPolicy);
+
+ return anyAsync
+ ? allOfSecurityPolicyAsync(securityPolicies)
+ : allOfSecurityPolicySync(securityPolicies);
}
- private static SecurityPolicy allOfSecurityPolicy(SecurityPolicy... securityPolicies) {
+ private static SecurityPolicy allOfSecurityPolicySync(SecurityPolicy... securityPolicies) {
return new SecurityPolicy() {
@Override
public Status checkAuthorization(int uid) {
@@ -360,6 +376,33 @@ public Status checkAuthorization(int uid) {
};
}
+ private static SecurityPolicy allOfSecurityPolicyAsync(SecurityPolicy... securityPolicies) {
+ return new AsyncSecurityPolicy() {
+ @Override
+ public ListenableFuture> futureStatuses = Futures.allAsList(allStatuses);
+
+ return Futures
+ .transform(
+ futureStatuses,statuses ->
+ statuses.stream().filter(status -> !status.isOk()).findFirst().orElse(Status.OK),
+ MoreExecutors.directExecutor());
+ }
+ };
+ }
+
/**
* Creates a {@link SecurityPolicy} that allows access if *any* of the specified {@code
* securityPolicies} allow access.
diff --git a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java
index 84c76a84bf2..7b0388865cc 100644
--- a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java
+++ b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java
@@ -22,6 +22,7 @@
import static android.content.pm.PackageInfo.REQUESTED_PERMISSION_GRANTED;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.util.concurrent.Futures.immediateFuture;
import static org.robolectric.Shadows.shadowOf;
import android.app.admin.DevicePolicyManager;
@@ -35,9 +36,11 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.hash.Hashing;
+import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.Status;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -523,15 +526,38 @@ public void testAllOf_succeedsIfAllSecurityPoliciesAllowed() throws Exception {
@Test
public void testAllOf_failsIfOneSecurityPoliciesNotAllowed() throws Exception {
+ policy =
+ SecurityPolicies.allOf(
+ SecurityPolicies.internalOnly(),
+ SecurityPolicies.permissionDenied("Not allowed SecurityPolicy"));
+
+ assertThat(policy.checkAuthorization(MY_UID).getCode())
+ .isEqualTo(Status.PERMISSION_DENIED.getCode());
+ assertThat(policy.checkAuthorization(MY_UID).getDescription())
+ .contains("Not allowed SecurityPolicy");
+ }
+
+ @Test
+ public void testAllOfAsync_succeedsIfAllSecurityPoliciesAllowed() {
+ policy =
+ SecurityPolicies.allOf(
+ SecurityPolicies.internalOnly(),
+ makeAsyncPolicy(uid -> immediateFuture(Status.OK)));
+
+ assertThat(policy.checkAuthorization(MY_UID).getCode()).isEqualTo(Status.OK.getCode());
+ }
+
+ @Test
+ public void testAllOfAsync_failsIfOneSecurityPoliciesNotAllowed() {
policy =
SecurityPolicies.allOf(
SecurityPolicies.internalOnly(),
- SecurityPolicies.permissionDenied("Not allowed SecurityPolicy"));
+ makeAsyncPolicy(uid -> immediateFuture(Status.OK)),
+ makeAsyncPolicy(uid -> immediateFuture(Status.ABORTED)),
+ makeAsyncPolicy(uid -> immediateFuture(Status.INVALID_ARGUMENT)));
assertThat(policy.checkAuthorization(MY_UID).getCode())
- .isEqualTo(Status.PERMISSION_DENIED.getCode());
- assertThat(policy.checkAuthorization(MY_UID).getDescription())
- .contains("Not allowed SecurityPolicy");
+ .isEqualTo(Status.Code.ABORTED);
}
@Test
@@ -703,4 +729,13 @@ public void testOneOfSignatureSha256Hash_failsIfPackageNameMatchAndOneOfSignatur
private static byte[] getSha256Hash(Signature signature) {
return Hashing.sha256().hashBytes(signature.toByteArray()).asBytes();
}
+
+ private static AsyncSecurityPolicy makeAsyncPolicy(Function