From 83dec04ef56d5216a53f18a358a7c140d4ddedec Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Fri, 16 Jun 2023 16:10:28 -0400 Subject: [PATCH 01/11] Implement OPA --- core/trino-server/src/main/provisio/trino.xml | 6 + plugin/trino-opa/README.md | 206 +++ plugin/trino-opa/pom.xml | 161 ++ .../main/java/io/trino/plugin/opa/ForOpa.java | 30 + .../io/trino/plugin/opa/OpaAccessControl.java | 891 ++++++++++ .../plugin/opa/OpaAccessControlFactory.java | 107 ++ .../plugin/opa/OpaAccessControlModule.java | 60 + .../plugin/opa/OpaAccessControlPlugin.java | 28 + .../plugin/opa/OpaBatchAccessControl.java | 160 ++ .../java/io/trino/plugin/opa/OpaConfig.java | 83 + .../trino/plugin/opa/OpaHighLevelClient.java | 117 ++ .../io/trino/plugin/opa/OpaHttpClient.java | 222 +++ .../trino/plugin/opa/OpaQueryException.java | 68 + .../opa/schema/OpaBatchQueryResult.java | 30 + .../io/trino/plugin/opa/schema/OpaQuery.java | 24 + .../plugin/opa/schema/OpaQueryContext.java | 37 + .../plugin/opa/schema/OpaQueryInput.java | 25 + .../opa/schema/OpaQueryInputAction.java | 109 ++ .../plugin/opa/schema/OpaQueryInputGrant.java | 76 + .../opa/schema/OpaQueryInputResource.java | 146 ++ .../plugin/opa/schema/OpaQueryResult.java | 18 + .../plugin/opa/schema/PropertiesMapper.java | 32 + .../schema/TrinoCatalogSessionProperty.java | 27 + .../plugin/opa/schema/TrinoFunction.java | 45 + .../opa/schema/TrinoGrantPrincipal.java | 43 + .../plugin/opa/schema/TrinoIdentity.java | 45 + .../trino/plugin/opa/schema/TrinoSchema.java | 56 + .../trino/plugin/opa/schema/TrinoTable.java | 69 + .../io/trino/plugin/opa/schema/TrinoUser.java | 45 + .../plugin/opa/FilteringTestHelpers.java | 68 + .../trino/plugin/opa/FunctionalHelpers.java | 74 + .../io/trino/plugin/opa/HttpClientUtils.java | 131 ++ .../OpaAccessControlFilteringUnitTest.java | 362 ++++ .../opa/OpaAccessControlSystemTest.java | 330 ++++ .../plugin/opa/OpaAccessControlUnitTest.java | 1459 +++++++++++++++++ ...paBatchAccessControlFilteringUnitTest.java | 470 ++++++ .../plugin/opa/RequestTestUtilities.java | 76 + .../io/trino/plugin/opa/ResponseTest.java | 133 ++ .../java/io/trino/plugin/opa/TestFactory.java | 85 + .../java/io/trino/plugin/opa/TestHelpers.java | 142 ++ .../io/trino/plugin/opa/TestOpaConfig.java | 56 + pom.xml | 1 + 42 files changed, 6353 insertions(+) create mode 100644 plugin/trino-opa/README.md create mode 100644 plugin/trino-opa/pom.xml create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/ForOpa.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlModule.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlPlugin.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaQueryException.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaBatchQueryResult.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQuery.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInput.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryResult.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoCatalogSessionProperty.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoFunction.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoSchema.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoTable.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoUser.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java diff --git a/core/trino-server/src/main/provisio/trino.xml b/core/trino-server/src/main/provisio/trino.xml index a11e4ade027e5..eb69c2a406970 100644 --- a/core/trino-server/src/main/provisio/trino.xml +++ b/core/trino-server/src/main/provisio/trino.xml @@ -319,4 +319,10 @@ + + + + + + diff --git a/plugin/trino-opa/README.md b/plugin/trino-opa/README.md new file mode 100644 index 0000000000000..25878d80d60b8 --- /dev/null +++ b/plugin/trino-opa/README.md @@ -0,0 +1,206 @@ +# trino-opa + +This plugin enables Trino to use Open Policy Agent (OPA) as an authorization engine. + +For more information on OPA, please refer to the Open Policy Agent [documentation](https://www.openpolicyagent.org/). + +## Configuration + +You will need to configure Trino to use the OPA plugin as its access control engine, then configure the +plugin to contact your OPA endpoint. + +`config.properties` - **enabling the plugin**: + +Make sure to enable the plugin by configuring Trino to pull in the relevant config file for the OPA +authorizer, e.g.: + +```properties +access-control.config-files=/etc/trino/access-control-file-based.properties,/etc/trino/access-control-opa.properties +``` + +`access-control-opa.properties` - **configuring the plugin**: + +Set the access control name to `opa` and specify the policy URI, for example: + +```properties +access-control.name=opa +opa.policy.uri=https://your-opa-endpoint/v1/data/allow +``` + +If you also want to enable the _batch_ mode (see [Batch mode](#batch-mode)), you must additionally set up an +`opa.policy.batched-uri` configuration entry. + +> Batch mode is _not_ a replacement for the "main" URI. The batch mode is _only_ +> used for certain authorization queries where batching is applicable. Even when using +> `opa.policy.batched-uri`, you _must_ still provide an `opa.policy.uri` + +For instance: + +```properties +access-control.name=opa +opa.policy.uri=https://your-opa-endpoint/v1/data/allow +opa.policy.batched-uri=https://your-opa-endpoint/v1/data/batch +``` + +### All configuration entries + +| Configuration name | Required | Default | Description | +|--------------------------|:--------:|:-------:|------------------------------------------------------------------------------------------------------------------------------| +| `opa.policy.uri` | Yes | N/A | Endpoint to query OPA | +| `opa.policy.batched-uri` | No | Unset | Endpoint for batch OPA requests | +| `opa.log-requests` | No | `false` | Determines whether requests (URI, headers and entire body) are logged prior to sending them to OPA | +| `opa.log-responses` | No | `false` | Determines whether OPA responses (URI, status code, headers and entire body) are logged | +| `opa.http-client.*` | No | Unset | Additional HTTP client configurations that get passed down. E.g. `opa.http-client.http-proxy` for configuring the HTTP proxy | + +> When request / response logging is enabled, they will be logged at DEBUG level under the `io.trino.plugin.opa.OpaHttpClient` logger, you will need to update +> your log configuration accordingly. +> +> Be aware that enabling these options will produce very large amounts of logs + + +## OPA queries + +The plugin will contact OPA for each authorization request as defined on the SPI. + +OPA must return a response containing a boolean `allow` field, which will determine whether the operation +is permitted or not. + +The plugin will pass as much context as possible within the OPA request. A simple way of checking +what data is passed in from Trino is to run OPA locally in verbose mode. + +### Query structure + +A query will contain a `context` and an `action` as its top level fields. + +#### Query context: + +This determines _who_ is performing the operations, and reflects the `SystemSecurityContext` class in Trino. + +#### Query action: + +This determines _what_ action is being performed and upon what resources, the top level fields are as follows: + +- `operation` (string): operation being performed +- `resource` (object, nullable): information about the object being operated upon +- `targetResource` (object, nullable): information about the _new object_ being created, if applicable +- `grantee` (object, nullable): grantee of a grant operation +- `grantor` (object, nullable): grantor in a grant operation + +#### Examples + +Accessing a table will result in a query like the one below: + +```json +{ + "context": { + "identity": { + "user": "foo", + "groups": ["some-group"] + } + }, + "action": { + "operation": "SelectFromColumns", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table", + "columns": [ + "column1", + "column2", + "column3" + ] + } + } + } +} +``` + +`targetResource` is used in cases where a new resource, distinct from the one in `resource` is being created. For instance, +when renaming a table. + +```json +{ + "context": { + "identity": { + "user": "foo", + "groups": ["some-group"] + } + }, + "action": { + "operation": "RenameTable", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table" + } + }, + "targetResource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "new_table_name" + } + } + } +} +``` + + +## Batch mode + +A very powerful feature provided by OPA is its ability to respond to authorization queries with +more complex answers than a `true`/`false` boolean value. + +Many features in Trino require _filtering_ to be performed to determine, given a list of resources, +(e.g. tables, queries, views, etc...) which of those a user should be entitled to see/interact with. + +If `opa.policy.batched-uri` is _not_ configured, the plugin will send one request to OPA _per item_ being +filtered, then use the responses from OPA to construct a filtered list containing only those items for which +a `true` response was returned. + +Configuring `opa.policy.batched-uri` will allow the plugin to send a request to that _batch_ endpoint instead, +with a **list** of the resources being filtered under `action.filterResources` (as opposed to `action.resource`). + +> The other fields in the request are identical to the non-batch endpoint. + +An OPA policy supporting batch operations should return a (potentially empty) list containing the _indices_ +of the items for which authorization is granted (if any). Returning a `null` value instead of a list +is equivalent to returning an empty list. + +> We may want to reconsider the choice of using _indices_ in the response as opposed to returning a list +> containing copies of elements from the `filterResources` field in the request for which access should +> be granted. Indices were chosen over copying elements as it made validation in the plugin easier, +> and from the few examples we tried, it also made certain policies a bit simpler. Any feedback is appreciated! + +An interesting side effect of this is that we can add batching support for policies that didn't originally +have it quite easily. Consider the following rego: + +```rego +package foo + +# ... rest of the policy ... +# this assumes the non-batch response field is called "allow" +batch contains i { + some i + raw_resource := input.action.filterResources[i] + allow with input.action.resource as raw_resource +} + +# Corner case: filtering columns is done with a single table item, and many columns inside +# We cannot use our normal logic in other parts of the policy as they are based on sets +# and we need to retain order +batch contains i { + some i + input.action.operation == "FilterColumns" + count(input.action.filterResources) == 1 + raw_resource := input.action.filterResources[0] + count(raw_resource["table"]["columns"]) > 0 + new_resources := [ + object.union(raw_resource, {"table": {"column": column_name}}) + | column_name := raw_resource["table"]["columns"][_] + ] + allow with input.action.resource as new_resources[i] +} +``` diff --git a/plugin/trino-opa/pom.xml b/plugin/trino-opa/pom.xml new file mode 100644 index 0000000000000..3964a1dd30da9 --- /dev/null +++ b/plugin/trino-opa/pom.xml @@ -0,0 +1,161 @@ + + + 4.0.0 + + io.trino + trino-root + 431-SNAPSHOT + ../../pom.xml + + + trino-opa + + trino-plugin + Trino - Open Policy Agent + + + ${project.parent.basedir} + + + + + com.google.guava + guava + + + + com.google.inject + guice + + + + io.airlift + bootstrap + + + + io.airlift + concurrent + + + + io.airlift + configuration + + + + io.airlift + http-client + + + + io.airlift + json + + + + io.airlift + log + + + + jakarta.validation + jakarta.validation-api + + + + com.fasterxml.jackson.core + jackson-annotations + provided + + + + io.airlift + slice + provided + + + + io.opentelemetry + opentelemetry-api + provided + + + io.opentelemetry + opentelemetry-context + provided + + + + io.trino + trino-spi + provided + + + + org.openjdk.jol + jol-core + provided + + + + com.fasterxml.jackson.core + jackson-databind + runtime + + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + runtime + + + + io.trino + trino-blackhole + test + + + + io.trino + trino-main + test + + + + io.trino + trino-testing + test + + + + org.junit.jupiter + junit-jupiter + test + + + + org.junit.jupiter + junit-jupiter-api + test + + + + org.junit.jupiter + junit-jupiter-params + test + + + + org.testcontainers + junit-jupiter + test + + + org.testcontainers + testcontainers + test + + + + + diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/ForOpa.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/ForOpa.java new file mode 100644 index 0000000000000..d580b645a8f56 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/ForOpa.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@BindingAnnotation +@Target({FIELD, PARAMETER, METHOD}) +@Retention(RUNTIME) +public @interface ForOpa { +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java new file mode 100644 index 0000000000000..0d8d9bd881e57 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java @@ -0,0 +1,891 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSetMultimap; +import com.google.common.collect.Multimaps; +import com.google.inject.Inject; +import io.trino.plugin.opa.schema.OpaQueryContext; +import io.trino.plugin.opa.schema.OpaQueryInput; +import io.trino.plugin.opa.schema.OpaQueryInputAction; +import io.trino.plugin.opa.schema.OpaQueryInputGrant; +import io.trino.plugin.opa.schema.OpaQueryInputResource; +import io.trino.plugin.opa.schema.TrinoCatalogSessionProperty; +import io.trino.plugin.opa.schema.TrinoFunction; +import io.trino.plugin.opa.schema.TrinoGrantPrincipal; +import io.trino.plugin.opa.schema.TrinoSchema; +import io.trino.plugin.opa.schema.TrinoTable; +import io.trino.plugin.opa.schema.TrinoUser; +import io.trino.spi.connector.CatalogSchemaName; +import io.trino.spi.connector.CatalogSchemaRoutineName; +import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.security.AccessDeniedException; +import io.trino.spi.security.Identity; +import io.trino.spi.security.Privilege; +import io.trino.spi.security.SystemAccessControl; +import io.trino.spi.security.SystemSecurityContext; +import io.trino.spi.security.TrinoPrincipal; + +import java.security.Principal; +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.plugin.opa.OpaHighLevelClient.buildQueryInputForSimpleResource; +import static io.trino.plugin.opa.schema.PropertiesMapper.convertProperties; +import static io.trino.spi.security.AccessDeniedException.denyCreateCatalog; +import static io.trino.spi.security.AccessDeniedException.denyCreateFunction; +import static io.trino.spi.security.AccessDeniedException.denyCreateRole; +import static io.trino.spi.security.AccessDeniedException.denyCreateSchema; +import static io.trino.spi.security.AccessDeniedException.denyCreateViewWithSelect; +import static io.trino.spi.security.AccessDeniedException.denyDenySchemaPrivilege; +import static io.trino.spi.security.AccessDeniedException.denyDenyTablePrivilege; +import static io.trino.spi.security.AccessDeniedException.denyDropCatalog; +import static io.trino.spi.security.AccessDeniedException.denyDropFunction; +import static io.trino.spi.security.AccessDeniedException.denyDropRole; +import static io.trino.spi.security.AccessDeniedException.denyDropSchema; +import static io.trino.spi.security.AccessDeniedException.denyExecuteProcedure; +import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; +import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; +import static io.trino.spi.security.AccessDeniedException.denyGrantSchemaPrivilege; +import static io.trino.spi.security.AccessDeniedException.denyGrantTablePrivilege; +import static io.trino.spi.security.AccessDeniedException.denyImpersonateUser; +import static io.trino.spi.security.AccessDeniedException.denyRenameMaterializedView; +import static io.trino.spi.security.AccessDeniedException.denyRenameSchema; +import static io.trino.spi.security.AccessDeniedException.denyRenameTable; +import static io.trino.spi.security.AccessDeniedException.denyRenameView; +import static io.trino.spi.security.AccessDeniedException.denyRevokeRoles; +import static io.trino.spi.security.AccessDeniedException.denyRevokeSchemaPrivilege; +import static io.trino.spi.security.AccessDeniedException.denyRevokeTablePrivilege; +import static io.trino.spi.security.AccessDeniedException.denySetCatalogSessionProperty; +import static io.trino.spi.security.AccessDeniedException.denySetSchemaAuthorization; +import static io.trino.spi.security.AccessDeniedException.denySetSystemSessionProperty; +import static io.trino.spi.security.AccessDeniedException.denySetTableAuthorization; +import static io.trino.spi.security.AccessDeniedException.denySetViewAuthorization; +import static io.trino.spi.security.AccessDeniedException.denyShowCreateSchema; +import static io.trino.spi.security.AccessDeniedException.denyShowFunctions; +import static io.trino.spi.security.AccessDeniedException.denyShowTables; + +public sealed class OpaAccessControl + implements SystemAccessControl + permits OpaBatchAccessControl +{ + private final OpaHighLevelClient opaHighLevelClient; + + @Inject + public OpaAccessControl(OpaHighLevelClient opaHighLevelClient) + { + this.opaHighLevelClient = opaHighLevelClient; + } + + @Override + public void checkCanImpersonateUser(Identity identity, String userName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromIdentity(identity), + "ImpersonateUser", + () -> denyImpersonateUser(identity.getUser(), userName), + OpaQueryInputResource.builder().user(new TrinoUser(userName)).build()); + } + + @Override + public void checkCanSetUser(Optional principal, String userName) + {} + + @Override + public void checkCanExecuteQuery(Identity identity) + { + opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "ExecuteQuery", AccessDeniedException::denyExecuteQuery); + } + + @Override + public void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner) + { + opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "ViewQueryOwnedBy", AccessDeniedException::denyViewQuery, OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build()); + } + + @Override + public Collection filterViewQueryOwnedBy(Identity identity, Collection queryOwners) + { + return opaHighLevelClient.parallelFilterFromOpa( + queryOwners, + queryOwner -> buildQueryInputForSimpleResource( + OpaQueryContext.fromIdentity(identity), + "FilterViewQueryOwnedBy", + OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build())); + } + + @Override + public void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner) + { + opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "KillQueryOwnedBy", AccessDeniedException::denyKillQuery, OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build()); + } + + @Override + public void checkCanReadSystemInformation(Identity identity) + { + opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "ReadSystemInformation", AccessDeniedException::denyReadSystemInformationAccess); + } + + @Override + public void checkCanWriteSystemInformation(Identity identity) + { + opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "WriteSystemInformation", AccessDeniedException::denyWriteSystemInformationAccess); + } + + @Override + public void checkCanSetSystemSessionProperty(Identity identity, String propertyName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromIdentity(identity), + "SetSystemSessionProperty", + () -> denySetSystemSessionProperty(propertyName), + OpaQueryInputResource.builder().systemSessionProperty(propertyName).build()); + } + + @Override + public boolean canAccessCatalog(SystemSecurityContext context, String catalogName) + { + return opaHighLevelClient.queryOpaWithSimpleResource( + OpaQueryContext.fromSystemSecurityContext(context), + "AccessCatalog", + OpaQueryInputResource.builder().catalog(catalogName).build()); + } + + @Override + public void checkCanCreateCatalog(SystemSecurityContext context, String catalog) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "CreateCatalog", + () -> denyCreateCatalog(catalog), + OpaQueryInputResource.builder().catalog(catalog).build()); + } + + @Override + public void checkCanDropCatalog(SystemSecurityContext context, String catalog) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "DropCatalog", + () -> denyDropCatalog(catalog), + OpaQueryInputResource.builder().catalog(catalog).build()); + } + + @Override + public Set filterCatalogs(SystemSecurityContext context, Set catalogs) + { + return opaHighLevelClient.parallelFilterFromOpa( + catalogs, + catalog -> buildQueryInputForSimpleResource( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterCatalogs", + OpaQueryInputResource.builder().catalog(catalog).build())); + } + + @Override + public void checkCanCreateSchema(SystemSecurityContext context, CatalogSchemaName schema, Map properties) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "CreateSchema", + () -> denyCreateSchema(schema.toString()), + OpaQueryInputResource.builder().schema(new TrinoSchema(schema).withProperties(convertProperties(properties))).build()); + } + + @Override + public void checkCanDropSchema(SystemSecurityContext context, CatalogSchemaName schema) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "DropSchema", + () -> denyDropSchema(schema.toString()), + OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build()); + } + + @Override + public void checkCanRenameSchema(SystemSecurityContext context, CatalogSchemaName schema, String newSchemaName) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); + OpaQueryInputResource targetResource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema.getCatalogName(), newSchemaName)).build(); + + OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + + if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameSchema", resource, targetResource)) { + denyRenameSchema(schema.toString(), newSchemaName); + } + } + + @Override + public void checkCanSetSchemaAuthorization(SystemSecurityContext context, CatalogSchemaName schema, TrinoPrincipal principal) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); + OpaQueryInputGrant grantee = OpaQueryInputGrant.builder().principal(TrinoGrantPrincipal.fromTrinoPrincipal(principal)).build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("SetSchemaAuthorization") + .resource(resource) + .grantee(grantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denySetSchemaAuthorization(schema.toString(), principal); + } + } + + @Override + public void checkCanShowSchemas(SystemSecurityContext context, String catalogName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowSchemas", + AccessDeniedException::denyShowSchemas, + OpaQueryInputResource.builder().catalog(catalogName).build()); + } + + @Override + public Set filterSchemas(SystemSecurityContext context, String catalogName, Set schemaNames) + { + return opaHighLevelClient.parallelFilterFromOpa( + schemaNames, + schema -> buildQueryInputForSimpleResource( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterSchemas", + OpaQueryInputResource.builder().schema(new TrinoSchema(catalogName, schema)).build())); + } + + @Override + public void checkCanShowCreateSchema(SystemSecurityContext context, CatalogSchemaName schemaName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowCreateSchema", + () -> denyShowCreateSchema(schemaName.toString()), + OpaQueryInputResource.builder().schema(new TrinoSchema(schemaName)).build()); + } + + @Override + public void checkCanShowCreateTable(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "ShowCreateTable", table, AccessDeniedException::denyShowCreateTable); + } + + @Override + public void checkCanCreateTable(SystemSecurityContext context, CatalogSchemaTableName table, Map properties) + { + checkTableAndPropertiesOperation(context, "CreateTable", table, convertProperties(properties), AccessDeniedException::denyCreateTable); + } + + @Override + public void checkCanDropTable(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "DropTable", table, AccessDeniedException::denyDropTable); + } + + @Override + public void checkCanRenameTable(SystemSecurityContext context, CatalogSchemaTableName table, CatalogSchemaTableName newTable) + { + OpaQueryInputResource oldResource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); + OpaQueryInputResource newResource = OpaQueryInputResource.builder().table(new TrinoTable(newTable)).build(); + OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + + if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameTable", oldResource, newResource)) { + denyRenameTable(table.toString(), newTable.toString()); + } + } + + @Override + public void checkCanSetTableProperties(SystemSecurityContext context, CatalogSchemaTableName table, Map> properties) + { + checkTableAndPropertiesOperation(context, "SetTableProperties", table, properties, AccessDeniedException::denySetTableProperties); + } + + @Override + public void checkCanSetTableComment(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "SetTableComment", table, AccessDeniedException::denyCommentTable); + } + + @Override + public void checkCanSetViewComment(SystemSecurityContext context, CatalogSchemaTableName view) + { + checkTableOperation(context, "SetViewComment", view, AccessDeniedException::denyCommentView); + } + + @Override + public void checkCanSetColumnComment(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "SetColumnComment", table, AccessDeniedException::denyCommentColumn); + } + + @Override + public void checkCanShowTables(SystemSecurityContext context, CatalogSchemaName schema) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowTables", + () -> denyShowTables(schema.toString()), + OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build()); + } + + @Override + public Set filterTables(SystemSecurityContext context, String catalogName, Set tableNames) + { + return opaHighLevelClient.parallelFilterFromOpa( + tableNames, + table -> buildQueryInputForSimpleResource( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterTables", + OpaQueryInputResource.builder() + .table(new TrinoTable(catalogName, table.getSchemaName(), table.getTableName())) + .build())); + } + + @Override + public void checkCanShowColumns(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "ShowColumns", table, AccessDeniedException::denyShowColumns); + } + + @Override + public Map> filterColumns(SystemSecurityContext context, String catalogName, Map> tableColumns) + { + ImmutableSet.Builder allColumnsBuilder = ImmutableSet.builder(); + for (Map.Entry> entry : tableColumns.entrySet()) { + SchemaTableName schemaTableName = entry.getKey(); + TrinoTable trinoTable = new TrinoTable(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName()); + for (String columnName : entry.getValue()) { + allColumnsBuilder.add(trinoTable.withColumns(ImmutableSet.of(columnName))); + } + } + Set filteredColumns = opaHighLevelClient.parallelFilterFromOpa( + allColumnsBuilder.build(), + tableColumn -> buildQueryInputForSimpleResource( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterColumns", + OpaQueryInputResource.builder().table(tableColumn).build())); + + ImmutableSetMultimap.Builder results = ImmutableSetMultimap.builder(); + for (TrinoTable tableColumn : filteredColumns) { + results.put(new SchemaTableName(tableColumn.schemaName(), tableColumn.tableName()), getOnlyElement(tableColumn.columns())); + } + return Multimaps.asMap(results.build()); + } + + @Override + public void checkCanAddColumn(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "AddColumn", table, AccessDeniedException::denyAddColumn); + } + + @Override + public void checkCanAlterColumn(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "AlterColumn", table, AccessDeniedException::denyAlterColumn); + } + + @Override + public void checkCanDropColumn(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "DropColumn", table, AccessDeniedException::denyDropColumn); + } + + @Override + public void checkCanSetTableAuthorization(SystemSecurityContext context, CatalogSchemaTableName table, TrinoPrincipal principal) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); + OpaQueryInputGrant grantee = OpaQueryInputGrant.builder().principal(TrinoGrantPrincipal.fromTrinoPrincipal(principal)).build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("SetTableAuthorization") + .resource(resource) + .grantee(grantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denySetTableAuthorization(table.toString(), principal); + } + } + + @Override + public void checkCanRenameColumn(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "RenameColumn", table, AccessDeniedException::denyRenameColumn); + } + + @Override + public void checkCanSelectFromColumns(SystemSecurityContext context, CatalogSchemaTableName table, Set columns) + { + checkTableAndColumnsOperation(context, "SelectFromColumns", table, columns, AccessDeniedException::denySelectColumns); + } + + @Override + public void checkCanInsertIntoTable(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "InsertIntoTable", table, AccessDeniedException::denyInsertTable); + } + + @Override + public void checkCanDeleteFromTable(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "DeleteFromTable", table, AccessDeniedException::denyDeleteTable); + } + + @Override + public void checkCanTruncateTable(SystemSecurityContext context, CatalogSchemaTableName table) + { + checkTableOperation(context, "TruncateTable", table, AccessDeniedException::denyTruncateTable); + } + + @Override + public void checkCanUpdateTableColumns(SystemSecurityContext securityContext, CatalogSchemaTableName table, Set updatedColumnNames) + { + checkTableAndColumnsOperation(securityContext, "UpdateTableColumns", table, updatedColumnNames, AccessDeniedException::denyUpdateTableColumns); + } + + @Override + public void checkCanCreateView(SystemSecurityContext context, CatalogSchemaTableName view) + { + checkTableOperation(context, "CreateView", view, AccessDeniedException::denyCreateView); + } + + @Override + public void checkCanRenameView(SystemSecurityContext context, CatalogSchemaTableName view, CatalogSchemaTableName newView) + { + OpaQueryInputResource oldResource = OpaQueryInputResource.builder().table(new TrinoTable(view)).build(); + OpaQueryInputResource newResource = OpaQueryInputResource.builder().table(new TrinoTable(newView)).build(); + OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + + if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameView", oldResource, newResource)) { + denyRenameView(view.toString(), newView.toString()); + } + } + + @Override + public void checkCanSetViewAuthorization(SystemSecurityContext context, CatalogSchemaTableName view, TrinoPrincipal principal) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(view)).build(); + OpaQueryInputGrant grantee = OpaQueryInputGrant.builder().principal(TrinoGrantPrincipal.fromTrinoPrincipal(principal)).build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("SetViewAuthorization") + .resource(resource) + .grantee(grantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denySetViewAuthorization(view.toString(), principal); + } + } + + @Override + public void checkCanDropView(SystemSecurityContext context, CatalogSchemaTableName view) + { + checkTableOperation(context, "DropView", view, AccessDeniedException::denyDropView); + } + + @Override + public void checkCanCreateViewWithSelectFromColumns(SystemSecurityContext context, CatalogSchemaTableName table, Set columns) + { + checkTableAndColumnsOperation(context, "CreateViewWithSelectFromColumns", table, columns, (tableAsString, columnSet) -> denyCreateViewWithSelect(tableAsString, context.getIdentity())); + } + + @Override + public void checkCanCreateMaterializedView(SystemSecurityContext context, CatalogSchemaTableName materializedView, Map properties) + { + checkTableAndPropertiesOperation(context, "CreateMaterializedView", materializedView, convertProperties(properties), AccessDeniedException::denyCreateMaterializedView); + } + + @Override + public void checkCanRefreshMaterializedView(SystemSecurityContext context, CatalogSchemaTableName materializedView) + { + checkTableOperation(context, "RefreshMaterializedView", materializedView, AccessDeniedException::denyRefreshMaterializedView); + } + + @Override + public void checkCanSetMaterializedViewProperties(SystemSecurityContext context, CatalogSchemaTableName materializedView, Map> properties) + { + checkTableAndPropertiesOperation(context, "SetMaterializedViewProperties", materializedView, properties, AccessDeniedException::denySetMaterializedViewProperties); + } + + @Override + public void checkCanDropMaterializedView(SystemSecurityContext context, CatalogSchemaTableName materializedView) + { + checkTableOperation(context, "DropMaterializedView", materializedView, AccessDeniedException::denyDropMaterializedView); + } + + @Override + public void checkCanRenameMaterializedView(SystemSecurityContext context, CatalogSchemaTableName view, CatalogSchemaTableName newView) + { + OpaQueryInputResource oldResource = OpaQueryInputResource.builder().table(new TrinoTable(view)).build(); + OpaQueryInputResource newResource = OpaQueryInputResource.builder().table(new TrinoTable(newView)).build(); + OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + + if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameMaterializedView", oldResource, newResource)) { + denyRenameMaterializedView(view.toString(), newView.toString()); + } + } + + @Override + public void checkCanSetCatalogSessionProperty(SystemSecurityContext context, String catalogName, String propertyName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "SetCatalogSessionProperty", + () -> denySetCatalogSessionProperty(propertyName), + OpaQueryInputResource.builder().catalogSessionProperty(new TrinoCatalogSessionProperty(catalogName, propertyName)).build()); + } + + @Override + public void checkCanGrantSchemaPrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaName schema, TrinoPrincipal grantee, boolean grantOption) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); + OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() + .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) + .grantOption(grantOption) + .privilege(privilege) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("GrantSchemaPrivilege") + .resource(resource) + .grantee(opaGrantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyGrantSchemaPrivilege(privilege.toString(), schema.toString()); + } + } + + @Override + public void checkCanDenySchemaPrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaName schema, TrinoPrincipal grantee) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); + OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() + .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) + .privilege(privilege) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("DenySchemaPrivilege") + .resource(resource) + .grantee(opaGrantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyDenySchemaPrivilege(privilege.toString(), schema.toString()); + } + } + + @Override + public void checkCanRevokeSchemaPrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaName schema, TrinoPrincipal revokee, boolean grantOption) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); + OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() + .principal(TrinoGrantPrincipal.fromTrinoPrincipal(revokee)) + .grantOption(grantOption) + .privilege(privilege) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("RevokeSchemaPrivilege") + .resource(resource) + .grantee(opaGrantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyRevokeSchemaPrivilege(privilege.toString(), schema.toString()); + } + } + + @Override + public void checkCanGrantTablePrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaTableName table, TrinoPrincipal grantee, boolean grantOption) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); + OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() + .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) + .grantOption(grantOption) + .privilege(privilege) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("GrantTablePrivilege") + .resource(resource) + .grantee(opaGrantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyGrantTablePrivilege(privilege.toString(), table.toString()); + } + } + + @Override + public void checkCanDenyTablePrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaTableName table, TrinoPrincipal grantee) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); + OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() + .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) + .privilege(privilege) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("DenyTablePrivilege") + .resource(resource) + .grantee(opaGrantee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyDenyTablePrivilege(privilege.toString(), table.toString()); + } + } + + @Override + public void checkCanRevokeTablePrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaTableName table, TrinoPrincipal revokee, boolean grantOption) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); + OpaQueryInputGrant opaRevokee = OpaQueryInputGrant.builder() + .principal(TrinoGrantPrincipal.fromTrinoPrincipal(revokee)) + .privilege(privilege) + .grantOption(grantOption) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("RevokeTablePrivilege") + .resource(resource) + .grantee(opaRevokee) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyRevokeTablePrivilege(privilege.toString(), table.toString()); + } + } + + @Override + public void checkCanShowRoles(SystemSecurityContext context) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowRoles", + AccessDeniedException::denyShowRoles); + } + + @Override + public void checkCanCreateRole(SystemSecurityContext context, String role, Optional grantor) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().role(role).build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("CreateRole") + .resource(resource) + .grantor(TrinoGrantPrincipal.fromTrinoPrincipal(grantor)) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyCreateRole(role); + } + } + + @Override + public void checkCanDropRole(SystemSecurityContext context, String role) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "DropRole", + () -> denyDropRole(role), + OpaQueryInputResource.builder().role(role).build()); + } + + @Override + public void checkCanGrantRoles(SystemSecurityContext context, Set roles, Set grantees, boolean adminOption, Optional grantor) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().roles(roles).build(); + OpaQueryInputGrant opaGrantees = OpaQueryInputGrant.builder() + .grantOption(adminOption) + .principals(grantees.stream() + .map(TrinoGrantPrincipal::fromTrinoPrincipal) + .collect(toImmutableSet())) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("GrantRoles") + .resource(resource) + .grantee(opaGrantees) + .grantor(TrinoGrantPrincipal.fromTrinoPrincipal(grantor)) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyGrantRoles(roles, grantees); + } + } + + @Override + public void checkCanRevokeRoles(SystemSecurityContext context, Set roles, Set grantees, boolean adminOption, Optional grantor) + { + OpaQueryInputResource resource = OpaQueryInputResource.builder().roles(roles).build(); + OpaQueryInputGrant opaGrantees = OpaQueryInputGrant.builder() + .grantOption(adminOption) + .principals(grantees.stream() + .map(TrinoGrantPrincipal::fromTrinoPrincipal) + .collect(toImmutableSet())) + .build(); + OpaQueryInputAction action = OpaQueryInputAction.builder() + .operation("RevokeRoles") + .resource(resource) + .grantee(opaGrantees) + .grantor(TrinoGrantPrincipal.fromTrinoPrincipal(grantor)) + .build(); + OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + + if (!opaHighLevelClient.queryOpa(input)) { + denyRevokeRoles(roles, grantees); + } + } + + @Override + public void checkCanShowCurrentRoles(SystemSecurityContext context) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowCurrentRoles", + AccessDeniedException::denyShowCurrentRoles); + } + + @Override + public void checkCanShowRoleGrants(SystemSecurityContext context) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowRoleGrants", + AccessDeniedException::denyShowRoleGrants); + } + + @Override + public void checkCanShowFunctions(SystemSecurityContext context, CatalogSchemaName schema) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + "ShowFunctions", + () -> denyShowFunctions(schema.toString()), + OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build()); + } + + @Override + public Set filterFunctions(SystemSecurityContext context, String catalogName, Set functionNames) + { + return opaHighLevelClient.parallelFilterFromOpa( + functionNames, + function -> buildQueryInputForSimpleResource( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterFunctions", + OpaQueryInputResource.builder() + .function( + new TrinoFunction( + new TrinoSchema(catalogName, function.getSchemaName()), + function.getFunctionName())) + .build())); + } + + @Override + public void checkCanExecuteProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName procedure) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + "ExecuteProcedure", + () -> denyExecuteProcedure(procedure.toString()), + OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(procedure)).build()); + } + + @Override + public boolean canExecuteFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) + { + return opaHighLevelClient.queryOpaWithSimpleResource( + OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + "ExecuteFunction", + OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); + } + + @Override + public boolean canCreateViewWithExecuteFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) + { + return opaHighLevelClient.queryOpaWithSimpleResource( + OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + "CreateViewWithExecuteFunction", + OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); + } + + @Override + public void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaTableName table, String procedure) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + "ExecuteTableProcedure", + () -> denyExecuteTableProcedure(table.toString(), procedure), + OpaQueryInputResource.builder().table(new TrinoTable(table)).function(procedure).build()); + } + + @Override + public void checkCanCreateFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + "CreateFunction", + () -> denyCreateFunction(functionName.toString()), + OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); + } + + @Override + public void checkCanDropFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + "DropFunction", + () -> denyDropFunction(functionName.toString()), + OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); + } + + private void checkTableOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Consumer deny) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + actionName, + () -> deny.accept(table.toString()), + OpaQueryInputResource.builder().table(new TrinoTable(table)).build()); + } + + private void checkTableAndPropertiesOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Map> properties, Consumer deny) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + actionName, + () -> deny.accept(table.toString()), + OpaQueryInputResource.builder().table(new TrinoTable(table).withProperties(properties)).build()); + } + + private void checkTableAndColumnsOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Set columns, BiConsumer> deny) + { + opaHighLevelClient.queryAndEnforce( + OpaQueryContext.fromSystemSecurityContext(context), + actionName, + () -> deny.accept(table.toString(), columns), + OpaQueryInputResource.builder().table(new TrinoTable(table).withColumns(columns)).build()); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java new file mode 100644 index 0000000000000..24cc9c45c269f --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java @@ -0,0 +1,107 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.annotations.VisibleForTesting; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Provider; +import com.google.inject.Scopes; +import io.airlift.bootstrap.Bootstrap; +import io.airlift.concurrent.BoundedExecutor; +import io.airlift.http.client.HttpClient; +import io.airlift.json.JsonModule; +import io.trino.plugin.opa.schema.OpaQuery; +import io.trino.plugin.opa.schema.OpaQueryResult; +import io.trino.spi.security.SystemAccessControl; +import io.trino.spi.security.SystemAccessControlFactory; + +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.Executor; + +import static io.airlift.concurrent.Threads.daemonThreadsNamed; +import static io.airlift.http.client.HttpClientBinder.httpClientBinder; +import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newCachedThreadPool; + +public class OpaAccessControlFactory + implements SystemAccessControlFactory +{ + @Override + public String getName() + { + return "opa"; + } + + @Override + public SystemAccessControl create(Map config) + { + return create(config, Optional.empty()); + } + + @Override + public SystemAccessControl create(Map config, SystemAccessControlContext context) + { + return create(config); + } + + @VisibleForTesting + protected static SystemAccessControl create(Map config, Optional httpClient) + { + requireNonNull(config, "config is null"); + + Bootstrap app = new Bootstrap( + new JsonModule(), + binder -> { + jsonCodecBinder(binder).bindJsonCodec(OpaQuery.class); + jsonCodecBinder(binder).bindJsonCodec(OpaQueryResult.class); + httpClient.ifPresentOrElse( + client -> binder.bind(Key.get(HttpClient.class, ForOpa.class)).toInstance(client), + () -> httpClientBinder(binder).bindHttpClient("opa", ForOpa.class)); + binder.bind(OpaHighLevelClient.class); + binder.bind(Key.get(Executor.class, ForOpa.class)) + .toProvider(ExecutorProvider.class) + .in(Scopes.SINGLETON); + binder.bind(OpaHttpClient.class).in(Scopes.SINGLETON); + }, + new OpaAccessControlModule()); + + Injector injector = app + .doNotInitializeLogging() + .setRequiredConfigurationProperties(config) + .initialize(); + return injector.getInstance(SystemAccessControl.class); + } + + private static class ExecutorProvider + implements Provider + { + private final Executor executor; + + private ExecutorProvider() + { + this.executor = new BoundedExecutor( + newCachedThreadPool(daemonThreadsNamed("opa-access-control-http-%s")), + Runtime.getRuntime().availableProcessors()); + } + + @Override + public Executor get() + { + return executor; + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlModule.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlModule.java new file mode 100644 index 0000000000000..4c0c470351c10 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlModule.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.inject.Binder; +import com.google.inject.Scopes; +import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.trino.plugin.opa.schema.OpaBatchQueryResult; +import io.trino.spi.security.SystemAccessControl; + +import static io.airlift.configuration.ConditionalModule.conditionalModule; +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; + +public class OpaAccessControlModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + configBinder(binder).bindConfig(OpaConfig.class); + install(conditionalModule( + OpaConfig.class, + config -> config.getOpaBatchUri().isPresent(), + new OpaBatchAccessControlModule(), + new OpaSingleAuthorizerModule())); + } + + public static class OpaSingleAuthorizerModule + extends AbstractConfigurationAwareModule + { + @Override + protected void setup(Binder binder) + { + binder.bind(SystemAccessControl.class).to(OpaAccessControl.class).in(Scopes.SINGLETON); + } + } + + public static class OpaBatchAccessControlModule + extends AbstractConfigurationAwareModule + { + @Override + protected void setup(Binder binder) + { + jsonCodecBinder(binder).bindJsonCodec(OpaBatchQueryResult.class); + binder.bind(SystemAccessControl.class).to(OpaBatchAccessControl.class).in(Scopes.SINGLETON); + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlPlugin.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlPlugin.java new file mode 100644 index 0000000000000..0c3140bc43f11 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlPlugin.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableList; +import io.trino.spi.Plugin; +import io.trino.spi.security.SystemAccessControlFactory; + +public class OpaAccessControlPlugin + implements Plugin +{ + @Override + public Iterable getSystemAccessControlFactories() + { + return ImmutableList.of(new OpaAccessControlFactory()); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java new file mode 100644 index 0000000000000..266068cbb8a2d --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.inject.Inject; +import io.airlift.json.JsonCodec; +import io.trino.plugin.opa.schema.OpaBatchQueryResult; +import io.trino.plugin.opa.schema.OpaQueryContext; +import io.trino.plugin.opa.schema.OpaQueryInput; +import io.trino.plugin.opa.schema.OpaQueryInputAction; +import io.trino.plugin.opa.schema.OpaQueryInputResource; +import io.trino.plugin.opa.schema.TrinoFunction; +import io.trino.plugin.opa.schema.TrinoSchema; +import io.trino.plugin.opa.schema.TrinoTable; +import io.trino.plugin.opa.schema.TrinoUser; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemSecurityContext; + +import java.net.URI; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +public final class OpaBatchAccessControl + extends OpaAccessControl +{ + private final JsonCodec batchResultCodec; + private final URI opaBatchedPolicyUri; + private final OpaHttpClient opaHttpClient; + + @Inject + public OpaBatchAccessControl( + OpaHighLevelClient opaHighLevelClient, + JsonCodec batchResultCodec, + OpaHttpClient opaHttpClient, + OpaConfig config) + { + super(opaHighLevelClient); + this.opaBatchedPolicyUri = config.getOpaBatchUri().orElseThrow(); + this.batchResultCodec = batchResultCodec; + this.opaHttpClient = opaHttpClient; + } + + @Override + public Collection filterViewQueryOwnedBy(Identity identity, Collection queryOwners) + { + return batchFilterFromOpa( + OpaQueryContext.fromIdentity(identity), + "FilterViewQueryOwnedBy", + queryOwners, + queryOwner -> OpaQueryInputResource.builder() + .user(new TrinoUser(queryOwner)) + .build()); + } + + @Override + public Set filterCatalogs(SystemSecurityContext context, Set catalogs) + { + return batchFilterFromOpa( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterCatalogs", + catalogs, + catalog -> OpaQueryInputResource.builder() + .catalog(catalog) + .build()); + } + + @Override + public Set filterSchemas(SystemSecurityContext context, String catalogName, Set schemaNames) + { + return batchFilterFromOpa( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterSchemas", + schemaNames, + schema -> OpaQueryInputResource.builder().schema(new TrinoSchema(catalogName, schema)).build()); + } + + @Override + public Set filterTables(SystemSecurityContext context, String catalogName, Set tableNames) + { + return batchFilterFromOpa( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterTables", + tableNames, + table -> OpaQueryInputResource.builder().table(new TrinoTable(catalogName, table.getSchemaName(), table.getTableName())).build()); + } + + @Override + public Map> filterColumns(SystemSecurityContext context, String catalogName, Map> tableColumns) + { + BiFunction, OpaQueryInput> requestBuilder = batchRequestBuilder( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterColumns", + (schemaTableName, columns) -> OpaQueryInputResource.builder() + .table(new TrinoTable(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName()).withColumns(ImmutableSet.copyOf(columns))) + .build()); + return opaHttpClient.parallelBatchFilterFromOpa(tableColumns, requestBuilder, opaBatchedPolicyUri, batchResultCodec); + } + + @Override + public Set filterFunctions(SystemSecurityContext context, String catalogName, Set functionNames) + { + return batchFilterFromOpa( + OpaQueryContext.fromSystemSecurityContext(context), + "FilterFunctions", + functionNames, + function -> OpaQueryInputResource.builder() + .function(new TrinoFunction(new TrinoSchema(catalogName, function.getSchemaName()), function.getFunctionName())) + .build()); + } + + private Function, OpaQueryInput> batchRequestBuilder(OpaQueryContext context, String operation, Function resourceMapper) + { + return items -> new OpaQueryInput( + context, + OpaQueryInputAction.builder() + .operation(operation) + .filterResources(items.stream().map(resourceMapper).collect(toImmutableList())) + .build()); + } + + private BiFunction, OpaQueryInput> batchRequestBuilder(OpaQueryContext context, String operation, BiFunction, OpaQueryInputResource> resourceMapper) + { + return (resourcesKey, resourcesList) -> new OpaQueryInput( + context, + OpaQueryInputAction.builder() + .operation(operation) + .filterResources(ImmutableList.of(resourceMapper.apply(resourcesKey, resourcesList))) + .build()); + } + + private Set batchFilterFromOpa(OpaQueryContext context, String operation, Collection items, Function converter) + { + return opaHttpClient.batchFilterFromOpa( + items, + batchRequestBuilder(context, operation, converter), + opaBatchedPolicyUri, + batchResultCodec); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java new file mode 100644 index 0000000000000..6169cd2748e38 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import io.airlift.configuration.Config; +import io.airlift.configuration.ConfigDescription; +import jakarta.validation.constraints.NotNull; + +import java.net.URI; +import java.util.Optional; + +public class OpaConfig +{ + private URI opaUri; + + private Optional opaBatchUri = Optional.empty(); + private boolean logRequests; + private boolean logResponses; + + @NotNull + public URI getOpaUri() + { + return opaUri; + } + + @Config("opa.policy.uri") + @ConfigDescription("URI for OPA policies") + public OpaConfig setOpaUri(@NotNull URI opaUri) + { + this.opaUri = opaUri; + return this; + } + + public Optional getOpaBatchUri() + { + return opaBatchUri; + } + + @Config("opa.policy.batched-uri") + @ConfigDescription("URI for Batch OPA policies - if not set, a single request will be sent for each entry on filtering methods") + public OpaConfig setOpaBatchUri(URI opaBatchUri) + { + this.opaBatchUri = Optional.ofNullable(opaBatchUri); + return this; + } + + public boolean getLogRequests() + { + return this.logRequests; + } + + @Config("opa.log-requests") + @ConfigDescription("Whether to log requests (URI, entire body and headers) prior to sending them to OPA") + public OpaConfig setLogRequests(boolean logRequests) + { + this.logRequests = logRequests; + return this; + } + + public boolean getLogResponses() + { + return this.logResponses; + } + + @Config("opa.log-responses") + @ConfigDescription("Whether to log responses (URI, entire body, status code and headers) received from OPA") + public OpaConfig setLogResponses(boolean logResponses) + { + this.logResponses = logResponses; + return this; + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java new file mode 100644 index 0000000000000..5270c1b7f02da --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java @@ -0,0 +1,117 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.inject.Inject; +import io.airlift.json.JsonCodec; +import io.trino.plugin.opa.schema.OpaQueryContext; +import io.trino.plugin.opa.schema.OpaQueryInput; +import io.trino.plugin.opa.schema.OpaQueryInputAction; +import io.trino.plugin.opa.schema.OpaQueryInputResource; +import io.trino.plugin.opa.schema.OpaQueryResult; +import io.trino.spi.security.AccessDeniedException; + +import java.net.URI; +import java.util.Collection; +import java.util.Set; +import java.util.function.Function; + +import static java.util.Objects.requireNonNull; + +public class OpaHighLevelClient +{ + private final JsonCodec queryResultCodec; + private final URI opaPolicyUri; + private final OpaHttpClient opaHttpClient; + + @Inject + public OpaHighLevelClient( + JsonCodec queryResultCodec, + OpaHttpClient opaHttpClient, + OpaConfig config) + { + this.queryResultCodec = requireNonNull(queryResultCodec, "queryResultCodec is null"); + this.opaHttpClient = requireNonNull(opaHttpClient, "opaHttpClient is null"); + this.opaPolicyUri = config.getOpaUri(); + } + + public boolean queryOpa(OpaQueryInput input) + { + return opaHttpClient.consumeOpaResponse(opaHttpClient.submitOpaRequest(input, opaPolicyUri, queryResultCodec)).result(); + } + + private boolean queryOpaWithSimpleAction(OpaQueryContext context, String operation) + { + return queryOpa(buildQueryInputForSimpleAction(context, operation)); + } + + public boolean queryOpaWithSimpleResource(OpaQueryContext context, String operation, OpaQueryInputResource resource) + { + return queryOpa(buildQueryInputForSimpleResource(context, operation, resource)); + } + + public boolean queryOpaWithSourceAndTargetResource(OpaQueryContext context, String operation, OpaQueryInputResource resource, OpaQueryInputResource targetResource) + { + return queryOpa( + new OpaQueryInput( + context, + OpaQueryInputAction.builder() + .operation(operation) + .resource(resource) + .targetResource(targetResource) + .build())); + } + + public void queryAndEnforce( + OpaQueryContext context, + String actionName, + Runnable deny, + OpaQueryInputResource resource) + { + if (!queryOpaWithSimpleResource(context, actionName, resource)) { + deny.run(); + // we should never get here because deny should throw + throw new AccessDeniedException("Access denied for action %s and resource %s".formatted(actionName, resource)); + } + } + + public void queryAndEnforce( + OpaQueryContext context, + String actionName, + Runnable deny) + { + if (!queryOpaWithSimpleAction(context, actionName)) { + deny.run(); + // we should never get here because deny should throw + throw new AccessDeniedException("Access denied for action %s".formatted(actionName)); + } + } + + public Set parallelFilterFromOpa( + Collection items, + Function requestBuilder) + { + return opaHttpClient.parallelFilterFromOpa(items, requestBuilder, opaPolicyUri, queryResultCodec); + } + + public static OpaQueryInput buildQueryInputForSimpleAction(OpaQueryContext context, String operation) + { + return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).build()); + } + + public static OpaQueryInput buildQueryInputForSimpleResource(OpaQueryContext context, String operation, OpaQueryInputResource resource) + { + return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).resource(resource).build()); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java new file mode 100644 index 0000000000000..5cdc40f2e1c74 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java @@ -0,0 +1,222 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.FluentFuture; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.inject.Inject; +import io.airlift.http.client.FullJsonResponseHandler; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.HttpStatus; +import io.airlift.http.client.JsonBodyGenerator; +import io.airlift.http.client.Request; +import io.airlift.json.JsonCodec; +import io.airlift.log.Logger; +import io.trino.plugin.opa.schema.OpaBatchQueryResult; +import io.trino.plugin.opa.schema.OpaQuery; +import io.trino.plugin.opa.schema.OpaQueryInput; +import io.trino.plugin.opa.schema.OpaQueryResult; + +import java.net.URI; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.function.BiFunction; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; +import static io.airlift.http.client.JsonBodyGenerator.jsonBodyGenerator; +import static io.airlift.http.client.Request.Builder.preparePost; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static java.util.Objects.requireNonNullElse; + +public class OpaHttpClient +{ + private final HttpClient httpClient; + private final JsonCodec serializer; + private final Executor executor; + private final boolean logRequests; + private final boolean logResponses; + private static final Logger log = Logger.get(OpaHttpClient.class); + + @Inject + public OpaHttpClient( + @ForOpa HttpClient httpClient, + JsonCodec serializer, + @ForOpa Executor executor, + OpaConfig config) + { + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.serializer = requireNonNull(serializer, "serializer is null"); + this.executor = requireNonNull(executor, "executor is null"); + this.logRequests = config.getLogRequests(); + this.logResponses = config.getLogResponses(); + } + + public FluentFuture submitOpaRequest(OpaQueryInput input, URI uri, JsonCodec deserializer) + { + Request request; + JsonBodyGenerator requestBodyGenerator; + try { + requestBodyGenerator = jsonBodyGenerator(serializer, new OpaQuery(input)); + request = preparePost() + .addHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .setUri(uri) + .setBodyGenerator(requestBodyGenerator) + .build(); + } + catch (IllegalArgumentException e) { + log.error(e, "Failed to serialize OPA request body when attempting to send request to URI \"%s\"", uri.toString()); + throw new OpaQueryException.SerializeFailed(e); + } + if (logRequests) { + log.debug( + "Sending OPA request to URI \"%s\" ; request body = %s ; request headers = %s", + uri.toString(), + tryConvertBytesToString(requestBodyGenerator.getBody()), + request.getHeaders()); + } + return FluentFuture.from(httpClient.executeAsync(request, createFullJsonResponseHandler(deserializer))) + .transform(response -> parseOpaResponse(response, uri), executor); + } + + public T consumeOpaResponse(ListenableFuture opaResponseFuture) + { + try { + return opaResponseFuture.get(); + } + catch (ExecutionException e) { + if (e.getCause() instanceof OpaQueryException queryException) { + throw queryException; + } + log.error(e, "Failed to obtain response from OPA due to an unknown error"); + throw new OpaQueryException.QueryFailed(e); + } + catch (InterruptedException e) { + log.error(e, "OPA request was interrupted in flight"); + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + public Set parallelFilterFromOpa(Collection items, Function requestBuilder, URI uri, JsonCodec deserializer) + { + if (items.isEmpty()) { + return ImmutableSet.of(); + } + List>> allFutures = items.stream() + .map(item -> submitOpaRequest(requestBuilder.apply(item), uri, deserializer) + .transform(result -> result.result() ? Optional.of(item) : Optional.empty(), executor)) + .collect(toImmutableList()); + return consumeOpaResponse( + Futures.whenAllComplete(allFutures).call(() -> allFutures.stream() + .map(this::consumeOpaResponse) + .filter(Optional::isPresent) + .map(Optional::get) + .collect(toImmutableSet()), + executor)); + } + + public Set batchFilterFromOpa(Collection items, Function, OpaQueryInput> requestBuilder, URI uri, JsonCodec deserializer) + { + if (items.isEmpty()) { + return ImmutableSet.of(); + } + String dummyMapKey = "filter"; + return parallelBatchFilterFromOpa(ImmutableMap.of(dummyMapKey, items), (mapKey, mapValue) -> requestBuilder.apply(mapValue), uri, deserializer).getOrDefault(dummyMapKey, ImmutableSet.of()); + } + + public Map> parallelBatchFilterFromOpa(Map> items, BiFunction, OpaQueryInput> requestBuilder, URI uri, JsonCodec deserializer) + { + ImmutableMap.Builder>> allFuturesBuilder = ImmutableMap.builder(); + + for (Map.Entry> mapEntry : items.entrySet()) { + if (mapEntry.getValue().isEmpty()) { + continue; + } + List orderedItems = ImmutableList.copyOf(mapEntry.getValue()); + allFuturesBuilder.put( + mapEntry.getKey(), + submitOpaRequest(requestBuilder.apply(mapEntry.getKey(), orderedItems), uri, deserializer) + .transform( + response -> requireNonNullElse(response.result(), ImmutableList.of()).stream() + .map(orderedItems::get) + .collect(toImmutableSet()), + executor)); + } + + ImmutableMap>> allFutures = allFuturesBuilder.buildOrThrow(); + ImmutableMap.Builder> resultBuilder = ImmutableMap.builder(); + List>> consumedFutures = consumeOpaResponse( + Futures.whenAllComplete(allFutures.values()).call( + () -> allFutures.entrySet().stream() + .map(entry -> Map.entry(entry.getKey(), consumeOpaResponse(entry.getValue()))) + .filter(entry -> !entry.getValue().isEmpty()) + .collect(toImmutableList()), + executor)); + return resultBuilder.putAll(consumedFutures).buildKeepingLast(); + } + + private T parseOpaResponse(FullJsonResponseHandler.JsonResponse response, URI uri) + { + int statusCode = response.getStatusCode(); + String uriString = uri.toString(); + if (HttpStatus.familyForStatusCode(statusCode) != HttpStatus.Family.SUCCESSFUL) { + if (statusCode == HttpStatus.NOT_FOUND.code()) { + log.warn("OPA responded with not found error for policy with URI \"%s\"", uriString); + throw new OpaQueryException.PolicyNotFound(uriString); + } + + log.error("Received unknown error from OPA for URI \"%s\" with status code = %d", uriString, statusCode); + throw new OpaQueryException.OpaServerError(uriString, statusCode, response.toString()); + } + if (!response.hasValue()) { + log.error(response.getException(), "OPA response for URI \"%s\" with status code = %d could not be deserialized", uriString, statusCode); + throw new OpaQueryException.DeserializeFailed(response.getException()); + } + if (logResponses) { + log.debug( + "OPA response for URI \"%s\" received: status code = %d ; response payload = %s ; response headers = %s", + uriString, + statusCode, + tryConvertBytesToString(response.getJsonBytes()), + response.getHeaders()); + } + return response.getValue(); + } + + private static String tryConvertBytesToString(byte[] bytes) + { + try { + return new String(bytes, UTF_8); + } + catch (Exception e) { + log.error(e, "Failed to convert JSON bytes to string for logging"); + return ""; + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaQueryException.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaQueryException.java new file mode 100644 index 0000000000000..ef94546e0db30 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaQueryException.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +public abstract class OpaQueryException + extends RuntimeException +{ + public OpaQueryException(String message, Throwable cause) + { + super(message, cause); + } + + public static final class QueryFailed + extends OpaQueryException + { + public QueryFailed(Throwable cause) + { + super("Failed to query OPA backend", cause); + } + } + + public static final class SerializeFailed + extends OpaQueryException + { + public SerializeFailed(Throwable cause) + { + super("Failed to serialize OPA query context", cause); + } + } + + public static final class DeserializeFailed + extends OpaQueryException + { + public DeserializeFailed(Throwable cause) + { + super("Failed to deserialize OPA policy response", cause); + } + } + + public static final class PolicyNotFound + extends OpaQueryException + { + public PolicyNotFound(String policyName) + { + super("OPA policy named %s did not return a value (or does not exist)".formatted(policyName), null); + } + } + + public static final class OpaServerError + extends OpaQueryException + { + public OpaServerError(String policyName, int statusCode, String extra) + { + super("OPA server returned status %d when processing policy %s: %s".formatted(statusCode, policyName, extra), null); + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaBatchQueryResult.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaBatchQueryResult.java new file mode 100644 index 0000000000000..572401b569fb8 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaBatchQueryResult.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import jakarta.validation.constraints.NotNull; + +import java.util.List; + +import static java.util.Objects.requireNonNullElse; + +public record OpaBatchQueryResult(@JsonProperty("decision_id") String decisionId, @NotNull List result) +{ + public OpaBatchQueryResult + { + result = ImmutableList.copyOf(requireNonNullElse(result, ImmutableList.of())); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQuery.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQuery.java new file mode 100644 index 0000000000000..df029c281cd0b --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQuery.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import static java.util.Objects.requireNonNull; + +public record OpaQuery(OpaQueryInput input) +{ + public OpaQuery + { + requireNonNull(input, "input is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java new file mode 100644 index 0000000000000..75af04b2071c1 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemSecurityContext; + +import static java.util.Objects.requireNonNull; + +public record OpaQueryContext(TrinoIdentity identity) +{ + public static OpaQueryContext fromSystemSecurityContext(SystemSecurityContext ctx) + { + return new OpaQueryContext(TrinoIdentity.fromTrinoIdentity(ctx.getIdentity())); + } + + public static OpaQueryContext fromIdentity(Identity identity) + { + return new OpaQueryContext(TrinoIdentity.fromTrinoIdentity(identity)); + } + + public OpaQueryContext + { + requireNonNull(identity, "identity is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInput.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInput.java new file mode 100644 index 0000000000000..e9023826e5c28 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInput.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import static java.util.Objects.requireNonNull; + +public record OpaQueryInput(OpaQueryContext context, OpaQueryInputAction action) +{ + public OpaQueryInput + { + requireNonNull(context, "context is null"); + requireNonNull(action, "action is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java new file mode 100644 index 0000000000000..67f47e5077af4 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java @@ -0,0 +1,109 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.google.common.collect.ImmutableList; +import jakarta.validation.constraints.NotNull; + +import java.util.Collection; +import java.util.List; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record OpaQueryInputAction( + @NotNull String operation, + OpaQueryInputResource resource, + List filterResources, + OpaQueryInputResource targetResource, + OpaQueryInputGrant grantee, + TrinoGrantPrincipal grantor) +{ + public OpaQueryInputAction + { + requireNonNull(operation, "operation is null"); + if (filterResources != null && resource != null) { + throw new IllegalArgumentException("resource and filterResources cannot both be configured"); + } + if (filterResources != null) { + filterResources = ImmutableList.copyOf(filterResources); + } + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private String operation; + private OpaQueryInputResource resource; + private List filterResources; + private OpaQueryInputResource targetResource; + private OpaQueryInputGrant grantee; + private TrinoGrantPrincipal grantor; + + private Builder() {} + + public Builder operation(String operation) + { + this.operation = operation; + return this; + } + + public Builder resource(OpaQueryInputResource resource) + { + this.resource = resource; + return this; + } + + public Builder filterResources(Collection resources) + { + this.filterResources = ImmutableList.copyOf(resources); + return this; + } + + public Builder targetResource(OpaQueryInputResource targetResource) + { + this.targetResource = targetResource; + return this; + } + + public Builder grantee(OpaQueryInputGrant grantee) + { + this.grantee = grantee; + return this; + } + + public Builder grantor(TrinoGrantPrincipal grantor) + { + this.grantor = grantor; + return this; + } + + public OpaQueryInputAction build() + { + return new OpaQueryInputAction( + this.operation, + this.resource, + this.filterResources, + this.targetResource, + this.grantee, + this.grantor); + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java new file mode 100644 index 0000000000000..93650035bbb34 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.google.common.collect.ImmutableSet; +import io.trino.spi.security.Privilege; +import jakarta.validation.constraints.NotNull; + +import java.util.Set; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record OpaQueryInputGrant(@NotNull Set principals, Boolean grantOption, String privilege) +{ + public OpaQueryInputGrant + { + principals = ImmutableSet.copyOf(requireNonNull(principals, "principals is null")); + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private Set principals; + private Boolean grantOption; + private String privilege; + + private Builder() {} + + public Builder principal(TrinoGrantPrincipal principal) + { + this.principals = ImmutableSet.of(principal); + return this; + } + + public Builder principals(Set principals) + { + this.principals = principals; + return this; + } + + public Builder grantOption(boolean grantOption) + { + this.grantOption = grantOption; + return this; + } + + public Builder privilege(Privilege privilege) + { + this.privilege = privilege.name(); + return this; + } + + public OpaQueryInputGrant build() + { + return new OpaQueryInputGrant(this.principals, this.grantOption, this.privilege); + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java new file mode 100644 index 0000000000000..67cdb010edc15 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java @@ -0,0 +1,146 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.google.common.collect.ImmutableSet; +import jakarta.validation.constraints.NotNull; + +import java.util.Set; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record OpaQueryInputResource( + TrinoUser user, + NamedEntity systemSessionProperty, + TrinoCatalogSessionProperty catalogSessionProperty, + TrinoFunction function, + NamedEntity catalog, + TrinoSchema schema, + TrinoTable table, + NamedEntity role, + Set roles) +{ + public OpaQueryInputResource + { + if (roles != null) { + roles = ImmutableSet.copyOf(roles); + } + } + + public record NamedEntity(@NotNull String name) + { + public NamedEntity + { + requireNonNull(name, "name is null"); + } + } + + public static Builder builder() + { + return new Builder(); + } + + public static class Builder + { + private TrinoUser user; + private NamedEntity systemSessionProperty; + private TrinoCatalogSessionProperty catalogSessionProperty; + private NamedEntity catalog; + private TrinoSchema schema; + private TrinoTable table; + private NamedEntity role; + private Set roles; + private TrinoFunction function; + + private Builder() {} + + public Builder user(TrinoUser user) + { + this.user = user; + return this; + } + + public Builder systemSessionProperty(String systemSessionProperty) + { + this.systemSessionProperty = new NamedEntity(systemSessionProperty); + return this; + } + + public Builder catalogSessionProperty(TrinoCatalogSessionProperty catalogSessionProperty) + { + this.catalogSessionProperty = catalogSessionProperty; + return this; + } + + public Builder catalog(String catalog) + { + this.catalog = new NamedEntity(catalog); + return this; + } + + public Builder schema(TrinoSchema schema) + { + this.schema = schema; + return this; + } + + public Builder table(TrinoTable table) + { + this.table = table; + return this; + } + + public Builder role(String role) + { + this.role = new NamedEntity(role); + return this; + } + + public Builder roles(Set roles) + { + this.roles = roles.stream().map(NamedEntity::new).collect(toImmutableSet()); + return this; + } + + public Builder function(TrinoFunction function) + { + this.function = function; + return this; + } + + public Builder function(String functionName) + { + this.function = new TrinoFunction(functionName); + return this; + } + + public OpaQueryInputResource build() + { + return new OpaQueryInputResource( + this.user, + this.systemSessionProperty, + this.catalogSessionProperty, + this.function, + this.catalog, + this.schema, + this.table, + this.role, + this.roles); + } + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryResult.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryResult.java new file mode 100644 index 0000000000000..c49310bd533f8 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryResult.java @@ -0,0 +1,18 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public record OpaQueryResult(@JsonProperty("decision_id") String decisionId, boolean result) {} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java new file mode 100644 index 0000000000000..7d38279e97a92 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; + +public class PropertiesMapper +{ + private PropertiesMapper() + {} + + public static Map> convertProperties(Map properties) + { + return properties.entrySet().stream() + .map(propertiesEntry -> Map.entry(propertiesEntry.getKey(), Optional.ofNullable(propertiesEntry.getValue()))) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoCatalogSessionProperty.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoCatalogSessionProperty.java new file mode 100644 index 0000000000000..59256e4037af9 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoCatalogSessionProperty.java @@ -0,0 +1,27 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import jakarta.validation.constraints.NotNull; + +import static java.util.Objects.requireNonNull; + +public record TrinoCatalogSessionProperty(@NotNull String catalogName, @NotNull String propertyName) +{ + public TrinoCatalogSessionProperty + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(propertyName, "propertyName is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoFunction.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoFunction.java new file mode 100644 index 0000000000000..b25abb0dce4c1 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoFunction.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonUnwrapped; +import io.trino.spi.connector.CatalogSchemaRoutineName; +import jakarta.validation.constraints.NotNull; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record TrinoFunction( + @JsonUnwrapped TrinoSchema catalogSchema, + @NotNull String functionName) +{ + public static TrinoFunction fromTrinoFunction(CatalogSchemaRoutineName catalogSchemaRoutineName) + { + return new TrinoFunction( + new TrinoSchema(catalogSchemaRoutineName.getCatalogName(), catalogSchemaRoutineName.getSchemaName()), + catalogSchemaRoutineName.getRoutineName()); + } + + public TrinoFunction(String functionName) + { + this(null, functionName); + } + + public TrinoFunction + { + requireNonNull(functionName, "functionName is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java new file mode 100644 index 0000000000000..702c2b0311c81 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java @@ -0,0 +1,43 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import io.trino.spi.security.TrinoPrincipal; +import jakarta.validation.constraints.NotNull; + +import java.util.Optional; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record TrinoGrantPrincipal(@NotNull String type, @NotNull String name) +{ + public static TrinoGrantPrincipal fromTrinoPrincipal(TrinoPrincipal principal) + { + return new TrinoGrantPrincipal(principal.getType().name(), principal.getName()); + } + + public static TrinoGrantPrincipal fromTrinoPrincipal(Optional principal) + { + return principal.map(TrinoGrantPrincipal::fromTrinoPrincipal).orElse(null); + } + + public TrinoGrantPrincipal + { + requireNonNull(type, "type is null"); + requireNonNull(name, "name is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java new file mode 100644 index 0000000000000..da104199f7e2c --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.spi.security.Identity; +import jakarta.validation.constraints.NotNull; + +import java.util.Map; +import java.util.Set; + +import static java.util.Objects.requireNonNull; + +public record TrinoIdentity( + @NotNull String user, + @NotNull Set groups, + @NotNull Map extraCredentials) +{ + public static TrinoIdentity fromTrinoIdentity(Identity identity) + { + return new TrinoIdentity( + identity.getUser(), + identity.getGroups(), + identity.getExtraCredentials()); + } + + public TrinoIdentity + { + requireNonNull(user, "user is null"); + groups = ImmutableSet.copyOf(requireNonNull(groups, "groups is null")); + extraCredentials = ImmutableMap.copyOf(requireNonNull(extraCredentials, "extraCredentials is null")); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoSchema.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoSchema.java new file mode 100644 index 0000000000000..334411ca05af6 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoSchema.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.google.common.collect.ImmutableMap; +import io.trino.spi.connector.CatalogSchemaName; +import jakarta.validation.constraints.NotNull; + +import java.util.Map; +import java.util.Optional; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record TrinoSchema( + @NotNull String catalogName, + @NotNull String schemaName, + Map> properties) +{ + public TrinoSchema + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(schemaName, "schemaName is null"); + if (properties != null) { + properties = ImmutableMap.copyOf(properties); + } + } + + public TrinoSchema(CatalogSchemaName schema) + { + this(schema.getCatalogName(), schema.getSchemaName()); + } + + public TrinoSchema(String catalogName, String schemaName) + { + this(catalogName, schemaName, null); + } + + public TrinoSchema withProperties(Map> newProperties) + { + return new TrinoSchema(catalogName, schemaName, newProperties); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoTable.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoTable.java new file mode 100644 index 0000000000000..479d955e2a204 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoTable.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.spi.connector.CatalogSchemaTableName; +import jakarta.validation.constraints.NotNull; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record TrinoTable( + @NotNull String catalogName, + @NotNull String schemaName, + @NotNull String tableName, + Set columns, + Map> properties) +{ + public TrinoTable + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(schemaName, "schemaName is null"); + requireNonNull(tableName, "tableName is null"); + if (columns != null) { + columns = ImmutableSet.copyOf(columns); + } + if (properties != null) { + properties = ImmutableMap.copyOf(properties); + } + } + + public TrinoTable(CatalogSchemaTableName table) + { + this(table.getCatalogName(), table.getSchemaTableName().getSchemaName(), table.getSchemaTableName().getTableName()); + } + + public TrinoTable(String catalogName, String schemaName, String tableName) + { + this(catalogName, schemaName, tableName, null, null); + } + + public TrinoTable withColumns(Set newColumns) + { + return new TrinoTable(catalogName, schemaName, tableName, newColumns, properties); + } + + public TrinoTable withProperties(Map> newProperties) + { + return new TrinoTable(catalogName, schemaName, tableName, columns, newProperties); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoUser.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoUser.java new file mode 100644 index 0000000000000..90829cd427554 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoUser.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonUnwrapped; +import io.trino.spi.security.Identity; + +import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; +import static java.util.Objects.requireNonNull; + +@JsonInclude(NON_NULL) +public record TrinoUser(String user, @JsonUnwrapped TrinoIdentity identity) +{ + public TrinoUser + { + if (identity == null) { + requireNonNull(user, "user is null"); + } + if (user != null && identity != null) { + throw new IllegalArgumentException("user and identity may not both be set"); + } + } + + public TrinoUser(String name) + { + this(name, null); + } + + public TrinoUser(Identity identity) + { + this(null, TrinoIdentity.fromTrinoIdentity(identity)); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java new file mode 100644 index 0000000000000..38ca687c02927 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemSecurityContext; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.params.provider.Arguments; + +import java.util.Collection; +import java.util.function.BiFunction; +import java.util.stream.Stream; + +import static io.trino.plugin.opa.TestHelpers.createIllegalResponseTestCases; + +public class FilteringTestHelpers +{ + private FilteringTestHelpers() {} + + public static Stream emptyInputTestCases() + { + Stream> callables = Stream.of( + (authorizer, context) -> authorizer.filterViewQueryOwnedBy(context.getIdentity(), ImmutableSet.of()), + (authorizer, context) -> authorizer.filterCatalogs(context, ImmutableSet.of()), + (authorizer, context) -> authorizer.filterSchemas(context, "my_catalog", ImmutableSet.of()), + (authorizer, context) -> authorizer.filterTables(context, "my_catalog", ImmutableSet.of()), + (authorizer, context) -> authorizer.filterFunctions(context, "my_catalog", ImmutableSet.of())); + Stream testNames = Stream.of("filterViewQueryOwnedBy", "filterCatalogs", "filterSchemas", "filterTables", "filterFunctions"); + return Streams.zip(testNames, callables, (name, method) -> Arguments.of(Named.of(name, method))); + } + + public static Stream prepopulatedErrorCases() + { + Stream> callables = Stream.of( + (authorizer, context) -> authorizer.filterViewQueryOwnedBy(context.getIdentity(), ImmutableSet.of(Identity.ofUser("foo"))), + (authorizer, context) -> authorizer.filterCatalogs(context, ImmutableSet.of("foo")), + (authorizer, context) -> authorizer.filterSchemas(context, "my_catalog", ImmutableSet.of("foo")), + (authorizer, context) -> authorizer.filterTables(context, "my_catalog", ImmutableSet.of(new SchemaTableName("foo", "bar"))), + (authorizer, context) -> authorizer.filterColumns( + context, + "my_catalog", + ImmutableMap.of( + SchemaTableName.schemaTableName("my_schema", "my_table"), + ImmutableSet.of("some_column"))), + (authorizer, context) -> authorizer.filterFunctions( + context, + "my_catalog", + ImmutableSet.of(new SchemaFunctionName("some_schema", "some_function")))); + Stream testNames = Stream.of("filterViewQueryOwnedBy", "filterCatalogs", "filterSchemas", "filterTables", "filterColumns", "filterFunctions"); + return createIllegalResponseTestCases(Streams.zip(testNames, callables, (name, method) -> Arguments.of(Named.of(name, method)))); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java new file mode 100644 index 0000000000000..9a4cabccd79b8 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +public class FunctionalHelpers +{ + @FunctionalInterface + public static interface Consumer3 + { + void accept(T1 t1, T2 t2, T3 t3); + } + + @FunctionalInterface + public static interface Function3 + { + R apply(T1 t1, T2 t2, T3 t3); + } + + @FunctionalInterface + public static interface Consumer4 + { + void accept(T1 t1, T2 t2, T3 t3, T4 t4); + } + + @FunctionalInterface + public static interface Consumer5 + { + void accept(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5); + } + + @FunctionalInterface + public static interface Consumer6 + { + void accept(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6); + } + + public static class Pair + { + private T first; + private U second; + + public T getFirst() + { + return this.first; + } + + public U getSecond() + { + return this.second; + } + + public Pair(T first, U second) + { + this.first = first; + this.second = second; + } + + public static Pair of(T first, U second) + { + return new Pair(first, second); + } + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java new file mode 100644 index 0000000000000..c8e3aec7b9dc3 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java @@ -0,0 +1,131 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import io.airlift.http.client.HttpStatus; +import io.airlift.http.client.Request; +import io.airlift.http.client.Response; +import io.airlift.http.client.StaticBodyGenerator; +import io.airlift.http.client.testing.TestingHttpClient; +import io.airlift.http.client.testing.TestingResponse; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.LinkedList; +import java.util.List; +import java.util.function.Function; + +import static com.google.common.net.HttpHeaders.CONTENT_TYPE; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static java.util.Objects.requireNonNull; + +public class HttpClientUtils +{ + private HttpClientUtils() {} + + public static class RecordingHttpProcessor + implements TestingHttpClient.Processor + { + private final List requests = new LinkedList<>(); + private Function handler; + private final URI expectedURI; + private final String expectedMethod; + private final String expectedContentType; + + public RecordingHttpProcessor(URI expectedURI, String expectedMethod, String expectedContentType, Function handler) + { + this.expectedMethod = requireNonNull(expectedMethod, "expectedMethod is null"); + this.expectedContentType = requireNonNull(expectedContentType, "expectedContentType is null"); + this.expectedURI = requireNonNull(expectedURI, "expectedURI is null"); + this.handler = requireNonNull(handler, "handler is null"); + } + + @Override + public Response handle(Request request) + { + if (!requireNonNull(request.getMethod()).equalsIgnoreCase(expectedMethod)) { + throw new IllegalArgumentException("Unexpected method: %s".formatted(request.getMethod())); + } + String actualContentType = request.getHeader(CONTENT_TYPE); + if (!requireNonNull(actualContentType).equalsIgnoreCase(expectedContentType)) { + throw new IllegalArgumentException("Unexpected content type header: %s".formatted(actualContentType)); + } + if (!requireNonNull(request.getUri()).equals(expectedURI)) { + throw new IllegalArgumentException("Unexpected URI: %s".formatted(request.getUri().toString())); + } + if (requireNonNull(request.getBodyGenerator()) instanceof StaticBodyGenerator bodyGenerator) { + synchronized (this.requests) { + String requestContents = new String(bodyGenerator.getBody(), StandardCharsets.UTF_8); + requests.add(requestContents); + return handler.apply(requestContents).buildResponse(); + } + } + else { + throw new IllegalArgumentException("Request has an unexpected body generator"); + } + } + + public List getRequests() + { + synchronized (this.requests) { + return ImmutableList.copyOf(this.requests); + } + } + + public void setHandler(Function handler) + { + this.handler = handler; + } + } + + public static class InstrumentedHttpClient + extends TestingHttpClient + { + private final RecordingHttpProcessor httpProcessor; + + public InstrumentedHttpClient(URI expectedURI, String expectedMethod, String expectedContentType, Function handler) + { + this(new RecordingHttpProcessor(expectedURI, expectedMethod, expectedContentType, handler)); + } + + public InstrumentedHttpClient(RecordingHttpProcessor processor) + { + super(processor); + this.httpProcessor = processor; + } + + public void setHandler(Function handler) + { + this.httpProcessor.setHandler(handler); + } + + public List getRequests() + { + return this.httpProcessor.getRequests(); + } + } + + public record MockResponse(String contents, int statusCode) + { + public TestingResponse buildResponse() + { + return new TestingResponse( + HttpStatus.fromStatusCode(this.statusCode), + ImmutableListMultimap.of(CONTENT_TYPE, JSON_UTF_8.toString()), + this.contents.getBytes(StandardCharsets.UTF_8)); + } + }; +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java new file mode 100644 index 0000000000000..85064653d6f2c --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java @@ -0,0 +1,362 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.json.JsonMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemSecurityContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.IOException; +import java.net.URI; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; +import static io.trino.plugin.opa.TestHelpers.NO_ACCESS_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.OK_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class OpaAccessControlFilteringUnitTest +{ + private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private HttpClientUtils.InstrumentedHttpClient mockClient; + private OpaAccessControl authorizer; + private final JsonMapper jsonMapper = new JsonMapper(); + private Identity requestingIdentity; + private SystemSecurityContext requestingSecurityContext; + + @BeforeEach + public void setupAuthorizer() + { + this.mockClient = new HttpClientUtils.InstrumentedHttpClient(OPA_SERVER_URI, "POST", JSON_UTF_8.toString(), request -> OK_RESPONSE); + this.authorizer = (OpaAccessControl) OpaAccessControlFactory.create(ImmutableMap.of("opa.policy.uri", OPA_SERVER_URI.toString()), Optional.of(mockClient)); + this.requestingIdentity = Identity.ofUser("source-user"); + this.requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); + } + + @AfterEach + public void ensureRequestContextCorrect() + throws IOException + { + for (String request : mockClient.getRequests()) { + JsonNode parsedRequest = jsonMapper.readTree(request); + assertEquals(parsedRequest.at("/input/context/identity/user").asText(), requestingIdentity.getUser()); + } + } + + @Test + public void testFilterViewQueryOwnedBy() + { + Identity userOne = Identity.ofUser("user-one"); + Identity userTwo = Identity.ofUser("user-two"); + List requestedIdentities = ImmutableList.of(userOne, userTwo); + this.mockClient.setHandler(buildHandler("/input/action/resource/user/user", "user-one")); + + Collection result = authorizer.filterViewQueryOwnedBy( + requestingIdentity, + requestedIdentities); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of(userOne)); + + List expectedRequests = ImmutableList.builder() + .add(""" + { + "operation": "FilterViewQueryOwnedBy", + "resource": { + "user": { + "user": "user-one", + "groups": [], + "extraCredentials": {} + } + } + } + """) + .add(""" + { + "operation": "FilterViewQueryOwnedBy", + "resource": { + "user": { + "user": "user-two", + "groups": [], + "extraCredentials": {} + } + } + } + """) + .build(); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + } + + @Test + public void testFilterCatalogs() + { + Set requestedCatalogs = ImmutableSet.of("catalog_one", "catalog_two"); + this.mockClient.setHandler(buildHandler("/input/action/resource/catalog/name", "catalog_two")); + + Set result = authorizer.filterCatalogs( + requestingSecurityContext, + requestedCatalogs); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of("catalog_two")); + + List expectedRequests = ImmutableList.builder() + .add(""" + { + "operation": "FilterCatalogs", + "resource": { + "catalog": { + "name": "catalog_one" + } + } + } + """) + .add(""" + { + "operation": "FilterCatalogs", + "resource": { + "catalog": { + "name": "catalog_two" + } + } + } + """) + .build(); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + } + + @Test + public void testFilterSchemas() + { + Set requestedSchemas = ImmutableSet.of("schema_one", "schema_two"); + this.mockClient.setHandler(buildHandler("/input/action/resource/schema/schemaName", "schema_one")); + + Set result = authorizer.filterSchemas( + requestingSecurityContext, + "my_catalog", + requestedSchemas); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of("schema_one")); + + List expectedRequests = requestedSchemas.stream() + .map(""" + { + "operation": "FilterSchemas", + "resource": { + "schema": { + "schemaName": "%s", + "catalogName": "my_catalog" + } + } + } + """::formatted) + .collect(toImmutableList()); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + } + + @Test + public void testFilterTables() + { + Set tables = ImmutableSet.builder() + .add(new SchemaTableName("schema_one", "table_one")) + .add(new SchemaTableName("schema_one", "table_two")) + .add(new SchemaTableName("schema_two", "table_one")) + .add(new SchemaTableName("schema_two", "table_two")) + .build(); + this.mockClient.setHandler(buildHandler("/input/action/resource/table/tableName", "table_one")); + + Set result = authorizer.filterTables(requestingSecurityContext, "my_catalog", tables); + assertEquals(ImmutableSet.copyOf(result), tables.stream().filter(table -> table.getTableName().equals("table_one")).collect(toImmutableSet())); + + List expectedRequests = tables.stream() + .map(table -> """ + { + "operation": "FilterTables", + "resource": { + "table": { + "tableName": "%s", + "schemaName": "%s", + "catalogName": "my_catalog" + } + } + } + """.formatted(table.getTableName(), table.getSchemaName())) + .collect(toImmutableList()); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + } + + @Test + public void testFilterColumns() + { + SchemaTableName tableOne = SchemaTableName.schemaTableName("my_schema", "table_one"); + SchemaTableName tableTwo = SchemaTableName.schemaTableName("my_schema", "table_two"); + SchemaTableName tableThree = SchemaTableName.schemaTableName("my_schema", "table_three"); + Map> requestedColumns = ImmutableMap.>builder() + .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) + .put(tableTwo, ImmutableSet.of("table_two_column_one", "table_two_column_two")) + .put(tableThree, ImmutableSet.of("table_three_column_one", "table_three_column_two")) + .buildOrThrow(); + // Allow both columns from one table, one column from another one and no columns from the last one + Set columnsToAllow = ImmutableSet.builder() + .add("table_one_column_one") + .add("table_one_column_two") + .add("table_two_column_two") + .build(); + + this.mockClient.setHandler(buildHandler("/input/action/resource/table/columns/0", columnsToAllow)); + + Map> result = authorizer.filterColumns(requestingSecurityContext, "my_catalog", requestedColumns); + + List expectedRequests = requestedColumns.entrySet().stream() + .mapMulti( + (requestedColumnsForTable, accepter) -> requestedColumnsForTable.getValue().forEach( + column -> accepter.accept(""" + { + "operation": "FilterColumns", + "resource": { + "table": { + "tableName": "%s", + "schemaName": "my_schema", + "catalogName": "my_catalog", + "columns": ["%s"] + } + } + } + """.formatted(requestedColumnsForTable.getKey().getTableName(), column)))) + .collect(toImmutableList()); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + assertTrue( + Maps.difference( + result, + ImmutableMap.builder() + .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) + .put(tableTwo, ImmutableSet.of("table_two_column_two")) + .buildOrThrow()).areEqual()); + } + + @Test + public void testEmptyFilterColumns() + { + SchemaTableName someTable = SchemaTableName.schemaTableName("my_schema", "my_table"); + Map> requestedColumns = ImmutableMap.of(someTable, ImmutableSet.of()); + + Map> result = authorizer.filterColumns( + requestingSecurityContext, + "my_catalog", + requestedColumns); + + assertTrue(mockClient.getRequests().isEmpty()); + assertTrue(result.isEmpty()); + } + + @Test + public void testFilterFunctions() + { + SchemaFunctionName functionOne = new SchemaFunctionName("my_schema", "function_one"); + SchemaFunctionName functionTwo = new SchemaFunctionName("my_schema", "function_two"); + Set requestedFunctions = ImmutableSet.of(functionOne, functionTwo); + this.mockClient.setHandler(buildHandler("/input/action/resource/function/functionName", "function_two")); + + Set result = authorizer.filterFunctions( + requestingSecurityContext, + "my_catalog", + requestedFunctions); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of(functionTwo)); + + List expectedRequests = requestedFunctions.stream() + .map(function -> """ + { + "operation": "FilterFunctions", + "resource": { + "function": { + "catalogName": "my_catalog", + "schemaName": "%s", + "functionName": "%s" + } + } + }""".formatted(function.getSchemaName(), function.getFunctionName())) + .collect(toImmutableList()); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.FilteringTestHelpers#emptyInputTestCases") + public void testEmptyRequests( + BiFunction callable) + { + Collection result = callable.apply(authorizer, requestingSecurityContext); + assertTrue(result.isEmpty()); + assertTrue(mockClient.getRequests().isEmpty()); + } + + @ParameterizedTest(name = "{index}: {0} - {1}") + @MethodSource("io.trino.plugin.opa.FilteringTestHelpers#prepopulatedErrorCases") + public void testIllegalResponseThrows( + BiFunction callable, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> callable.apply(authorizer, requestingSecurityContext)); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + assertEquals(mockClient.getRequests().size(), 1); + } + + private Function buildHandler(String jsonPath, Set resourcesToAccept) + { + return request -> { + try { + JsonNode parsedRequest = this.jsonMapper.readTree(request); + String requestedItem = parsedRequest.at(jsonPath).asText(); + if (resourcesToAccept.contains(requestedItem)) { + return OK_RESPONSE; + } + } + catch (IOException e) { + fail("Could not parse request"); + } + return NO_ACCESS_RESPONSE; + }; + } + private Function buildHandler(String jsonPath, String resourceToAccept) + { + return buildHandler(jsonPath, ImmutableSet.of(resourceToAccept)); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java new file mode 100644 index 0000000000000..d7c3ff6954fe1 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java @@ -0,0 +1,330 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.plugin.blackhole.BlackHolePlugin; +import io.trino.spi.security.Identity; +import io.trino.testing.DistributedQueryRunner; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketTimeoutException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.plugin.opa.FunctionalHelpers.Pair; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@Testcontainers +@TestInstance(PER_CLASS) +public class OpaAccessControlSystemTest +{ + private URI opaServerUri; + private DistributedQueryRunner runner; + + private static final int OPA_PORT = 8181; + @Container + public static GenericContainer opaContainer = new GenericContainer<>(DockerImageName.parse("openpolicyagent/opa:latest-rootless")) + .withCommand("run", "--server", "--addr", ":%d".formatted(OPA_PORT)) + .withExposedPorts(OPA_PORT); + + @Nested + @TestInstance(PER_CLASS) + @DisplayName("Unbatched Authorizer Tests") + class UnbatchedAuthorizerTests + { + @BeforeAll + public void setupTrino() + throws Exception + { + setupTrinoWithOpa("v1/data/trino/allow", Optional.empty()); + } + + @AfterAll + public void teardown() + { + if (runner != null) { + runner.close(); + } + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlSystemTest#filterSchemaTests") + public void testAllowsQueryAndFilters(String userName, Set expectedCatalogs) + throws IOException, InterruptedException + { + submitPolicy(""" + package trino + import future.keywords.in + import future.keywords.if + + default allow = false + allow { + is_bob + can_be_accessed_by_bob + } + allow if is_admin + + is_admin { + input.context.identity.user == "admin" + } + is_bob { + input.context.identity.user == "bob" + } + can_be_accessed_by_bob { + input.action.operation in ["ImpersonateUser", "ExecuteQuery"] + } + can_be_accessed_by_bob { + input.action.operation in ["FilterCatalogs", "AccessCatalog"] + input.action.resource.catalog.name == "catalog_one" + } + """); + Set catalogs = querySetOfStrings(user(userName), "SHOW CATALOGS"); + assertEquals(expectedCatalogs, catalogs); + } + + @Test + public void testShouldDenyQueryIfDirected() + throws IOException, InterruptedException + { + submitPolicy(""" + package trino + import future.keywords.in + default allow = false + + allow { + input.context.identity.user in ["someone", "admin"] + } + """); + RuntimeException error = assertThrows(RuntimeException.class, () -> { + runner.execute(user("bob"), "SHOW CATALOGS"); + }); + assertTrue(error.getMessage().contains("Access Denied"), + "Error must mention 'Access Denied': " + error.getMessage()); + // smoke test: we can still query if we are the right user + runner.execute(user("admin"), "SHOW CATALOGS"); + } + } + + @Nested + @TestInstance(PER_CLASS) + @DisplayName("Batched Authorizer Tests") + class BatchedAuthorizerTests + { + @BeforeAll + public void setupTrino() + throws Exception + { + setupTrinoWithOpa("v1/data/trino/allow", Optional.of("v1/data/trino/batchAllow")); + } + + @AfterAll + public void teardown() + { + if (runner != null) { + runner.close(); + } + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlSystemTest#filterSchemaTests") + public void testFilterOutItemsBatch(String userName, Set expectedCatalogs) + throws IOException, InterruptedException + { + submitPolicy(""" + package trino + import future.keywords.in + import future.keywords.if + default allow = false + + allow if is_admin + + allow { + is_bob + input.action.operation in ["AccessCatalog", "ExecuteQuery", "ImpersonateUser", "ShowSchemas", "SelectFromColumns"] + } + + is_bob { + input.context.identity.user == "bob" + } + + is_admin { + input.context.identity.user == "admin" + } + + batchAllow[i] { + some i + is_bob + input.action.operation == "FilterCatalogs" + input.action.filterResources[i].catalog.name == "catalog_one" + } + + batchAllow[i] { + some i + input.action.filterResources[i] + is_admin + } + """); + Set catalogs = querySetOfStrings(user(userName), "SHOW CATALOGS"); + assertEquals(expectedCatalogs, catalogs); + } + + @Test + public void testDenyUnbatchedQuery() + throws IOException, InterruptedException + { + submitPolicy(""" + package trino + import future.keywords.in + default allow = false + """); + RuntimeException error = assertThrows(RuntimeException.class, () -> { + runner.execute(user("bob"), "SELECT version()"); + }); + assertTrue(error.getMessage().contains("Access Denied"), + "Error must mention 'Access Denied': " + error.getMessage()); + } + + @Test + public void testAllowUnbatchedQuery() + throws IOException, InterruptedException + { + submitPolicy(""" + package trino + import future.keywords.in + default allow = false + allow { + input.context.identity.user == "bob" + input.action.operation in ["ImpersonateUser", "ExecuteFunction", "AccessCatalog", "ExecuteQuery"] + } + """); + Set version = querySetOfStrings(user("bob"), "SELECT version()"); + assertFalse(version.isEmpty()); + } + } + + private void ensureOpaUp() + throws IOException, InterruptedException + { + assertTrue(opaContainer.isRunning()); + InetSocketAddress opaSocket = new InetSocketAddress(opaContainer.getHost(), opaContainer.getMappedPort(OPA_PORT)); + String opaEndpoint = String.format("%s:%d", opaSocket.getHostString(), opaSocket.getPort()); + awaitSocketOpen(opaSocket, 100, 200); + this.opaServerUri = URI.create(String.format("http://%s/", opaEndpoint)); + } + + private void setupTrinoWithOpa(String basePolicyRelativeUri, Optional batchPolicyRelativeUri) + throws Exception + { + ensureOpaUp(); + ImmutableMap.Builder opaConfigBuilder = ImmutableMap.builder(); + opaConfigBuilder.put("opa.policy.uri", opaServerUri.resolve(basePolicyRelativeUri).toString()); + batchPolicyRelativeUri.ifPresent(relativeUri -> opaConfigBuilder.put("opa.policy.batched-uri", opaServerUri.resolve(relativeUri).toString())); + this.runner = DistributedQueryRunner.builder(testSessionBuilder().build()) + .setSystemAccessControl(new OpaAccessControlFactory().create(opaConfigBuilder.buildOrThrow())) + .setNodeCount(1) + .build(); + runner.installPlugin(new BlackHolePlugin()); + runner.createCatalog("catalog_one", "blackhole"); + runner.createCatalog("catalog_two", "blackhole"); + } + + private static void awaitSocketOpen(InetSocketAddress addr, int attempts, int timeoutMs) + throws IOException, InterruptedException + { + for (int i = 0; i < attempts; ++i) { + try (Socket socket = new Socket()) { + socket.connect(addr, timeoutMs); + return; + } + catch (SocketTimeoutException e) { + // ignored + } + catch (IOException e) { + Thread.sleep(timeoutMs); + } + } + throw new SocketTimeoutException("Timed out waiting for addr %s to be available (%d attempts made with a %d ms wait)".formatted(addr, attempts, timeoutMs)); + } + + private static String stringOfLines(String... lines) + { + StringBuilder out = new StringBuilder(); + for (String line : lines) { + out.append(line); + out.append("\r\n"); + } + return out.toString(); + } + + private void submitPolicy(String... policyLines) + throws IOException, InterruptedException + { + HttpClient httpClient = HttpClient.newHttpClient(); + HttpResponse policyResponse = + httpClient.send( + HttpRequest.newBuilder(opaServerUri.resolve("v1/policies/trino")) + .PUT(HttpRequest.BodyPublishers.ofString(stringOfLines(policyLines))) + .header("Content-Type", "text/plain").build(), + HttpResponse.BodyHandlers.ofString()); + assertEquals(policyResponse.statusCode(), 200, "Failed to submit policy: " + policyResponse.body()); + } + + private Session user(String user) + { + return testSessionBuilder().setIdentity(Identity.ofUser(user)).build(); + } + + private Set querySetOfStrings(Session session, String query) + { + return runner.execute(session, query).getMaterializedRows().stream().map(row -> row.getField(0).toString()).collect(toImmutableSet()); + } + + private static Stream filterSchemaTests() + { + Stream>> userAndExpectedCatalogs = Stream.of( + Pair.of("bob", ImmutableSet.of("catalog_one")), + Pair.of("admin", ImmutableSet.of("catalog_one", "catalog_two", "system"))); + return userAndExpectedCatalogs.map(testCase -> Arguments.of(Named.of(testCase.getFirst(), testCase.getFirst()), testCase.getSecond())); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java new file mode 100644 index 0000000000000..1cf2265f6e403 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java @@ -0,0 +1,1459 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.trino.spi.connector.CatalogSchemaName; +import io.trino.spi.connector.CatalogSchemaRoutineName; +import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.security.Identity; +import io.trino.spi.security.PrincipalType; +import io.trino.spi.security.Privilege; +import io.trino.spi.security.SystemSecurityContext; +import io.trino.spi.security.TrinoPrincipal; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; +import static io.trino.plugin.opa.RequestTestUtilities.assertJsonRequestsEqual; +import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; +import static io.trino.plugin.opa.TestHelpers.NO_ACCESS_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.OK_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.convertSystemSecurityContextToIdentityArgument; +import static io.trino.plugin.opa.TestHelpers.createFailingTestCases; +import static io.trino.plugin.opa.TestHelpers.createIllegalResponseTestCases; +import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpaAccessControlUnitTest +{ + private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private InstrumentedHttpClient mockClient; + private OpaAccessControl authorizer; + private final JsonMapper jsonMapper = new JsonMapper(); + private Identity requestingIdentity; + private SystemSecurityContext requestingSecurityContext; + + @BeforeEach + public void setupAuthorizer() + { + this.mockClient = new InstrumentedHttpClient(OPA_SERVER_URI, "POST", JSON_UTF_8.toString(), request -> OK_RESPONSE); + this.authorizer = (OpaAccessControl) new OpaAccessControlFactory().create(ImmutableMap.of("opa.policy.uri", OPA_SERVER_URI.toString()), Optional.of(mockClient)); + this.requestingIdentity = Identity.ofUser("source-user"); + this.requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); + } + + @AfterEach + public void ensureRequestContextCorrect() + throws IOException + { + for (String request : mockClient.getRequests()) { + JsonNode parsedRequest = jsonMapper.readTree(request); + assertEquals(parsedRequest.at("/input/context/identity/user").asText(), requestingIdentity.getUser()); + } + } + + @Test + public void testResponseHasExtraFields() + { + mockClient.setHandler(request -> new HttpClientUtils.MockResponse(""" + { + "result": true, + "decision_id": "foo", + "some_debug_info": {"test": ""} + } + """, + 200)); + authorizer.checkCanShowRoles(requestingSecurityContext); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#noResourceActionTestCases") + public void testNoResourceAction(String actionName, BiConsumer method) + { + method.accept(authorizer, requestingSecurityContext); + ObjectNode expectedRequest = jsonMapper.createObjectNode().put("operation", actionName); + assertJsonRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#noResourceActionFailureTestCases") + public void testNoResourceActionFailure( + String actionName, + BiConsumer method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept(authorizer, requestingSecurityContext)); + ObjectNode expectedRequest = jsonMapper.createObjectNode().put("operation", actionName); + assertJsonRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream tableResourceTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanShowCreateTable, + OpaAccessControl::checkCanDropTable, + OpaAccessControl::checkCanSetTableComment, + OpaAccessControl::checkCanSetViewComment, + OpaAccessControl::checkCanSetColumnComment, + OpaAccessControl::checkCanShowColumns, + OpaAccessControl::checkCanAddColumn, + OpaAccessControl::checkCanDropColumn, + OpaAccessControl::checkCanAlterColumn, + OpaAccessControl::checkCanRenameColumn, + OpaAccessControl::checkCanInsertIntoTable, + OpaAccessControl::checkCanDeleteFromTable, + OpaAccessControl::checkCanTruncateTable, + OpaAccessControl::checkCanCreateView, + OpaAccessControl::checkCanDropView, + OpaAccessControl::checkCanRefreshMaterializedView, + OpaAccessControl::checkCanDropMaterializedView); + Stream actions = Stream.of( + "ShowCreateTable", + "DropTable", + "SetTableComment", + "SetViewComment", + "SetColumnComment", + "ShowColumns", + "AddColumn", + "DropColumn", + "AlterColumn", + "RenameColumn", + "InsertIntoTable", + "DeleteFromTable", + "TruncateTable", + "CreateView", + "DropView", + "RefreshMaterializedView", + "DropMaterializedView"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableResourceTestCases") + public void testTableResourceActions( + String actionName, + FunctionalHelpers.Consumer3 callable) + { + callable.accept( + authorizer, + requestingSecurityContext, + new CatalogSchemaTableName("my_catalog", "my_schema", "my_table")); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table" + } + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + private static Stream tableResourceFailureTestCases() + { + return createFailingTestCases(tableResourceTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {3}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableResourceFailureTestCases") + public void testTableResourceFailure( + String actionName, + FunctionalHelpers.Consumer3 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"))); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream tableWithPropertiesTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanSetTableProperties, + OpaAccessControl::checkCanSetMaterializedViewProperties, + OpaAccessControl::checkCanCreateTable, + OpaAccessControl::checkCanCreateMaterializedView); + Stream actions = Stream.of( + "SetTableProperties", + "SetMaterializedViewProperties", + "CreateTable", + "CreateMaterializedView"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableWithPropertiesTestCases") + public void testTableWithPropertiesActions( + String actionName, + FunctionalHelpers.Consumer4 callable) + { + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + Map> properties = ImmutableMap.>builder() + .put("string_item", Optional.of("string_value")) + .put("empty_item", Optional.empty()) + .put("boxed_number_item", Optional.of(Integer.valueOf(32))) + .buildOrThrow(); + + callable.accept(authorizer, requestingSecurityContext, table, properties); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "table": { + "tableName": "my_table", + "catalogName": "my_catalog", + "schemaName": "my_schema", + "properties": { + "string_item": "string_value", + "empty_item": null, + "boxed_number_item": 32 + } + } + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + private static Stream tableWithPropertiesFailureTestCases() + { + return createFailingTestCases(tableWithPropertiesTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {3}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableWithPropertiesFailureTestCases") + public void testTableWithPropertiesActionFailure( + String actionName, + FunctionalHelpers.Consumer4 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"), + ImmutableMap.of())); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream identityResourceTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanViewQueryOwnedBy, + OpaAccessControl::checkCanKillQueryOwnedBy); + Stream actions = Stream.of( + "ViewQueryOwnedBy", + "KillQueryOwnedBy"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#identityResourceTestCases") + public void testIdentityResourceActions( + String actionName, + FunctionalHelpers.Consumer3 callable) + { + Identity dummyIdentity = Identity.forUser("dummy-user") + .withGroups(ImmutableSet.of("some-group")) + .withExtraCredentials(ImmutableMap.of("some_extra_credential", "value")) + .build(); + callable.accept(authorizer, requestingIdentity, dummyIdentity); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "user": { + "user": "dummy-user", + "groups": ["some-group"], + "extraCredentials": {"some_extra_credential": "value"} + } + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + private static Stream identityResourceFailureTestCases() + { + return createFailingTestCases(identityResourceTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#identityResourceFailureTestCases") + public void testIdentityResourceActionsFailure( + String actionName, + FunctionalHelpers.Consumer3 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingIdentity, + Identity.ofUser("dummy-user"))); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream stringResourceTestCases() + { + Stream> methods = Stream.of( + convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanSetSystemSessionProperty), + OpaAccessControl::checkCanCreateCatalog, + OpaAccessControl::checkCanDropCatalog, + OpaAccessControl::checkCanShowSchemas, + OpaAccessControl::checkCanDropRole); + Stream> actionAndResource = Stream.of( + FunctionalHelpers.Pair.of("SetSystemSessionProperty", "systemSessionProperty"), + FunctionalHelpers.Pair.of("CreateCatalog", "catalog"), + FunctionalHelpers.Pair.of("DropCatalog", "catalog"), + FunctionalHelpers.Pair.of("ShowSchemas", "catalog"), + FunctionalHelpers.Pair.of("DropRole", "role")); + return Streams.zip( + actionAndResource, + methods, + (action, method) -> Arguments.of(Named.of(action.getFirst(), action.getFirst()), action.getSecond(), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#stringResourceTestCases") + public void testStringResourceAction( + String actionName, + String resourceName, + FunctionalHelpers.Consumer3 callable) + { + callable.accept(authorizer, requestingSecurityContext, "resource_name"); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "%s": { + "name": "resource_name" + } + } + } + """.formatted(actionName, resourceName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + public static Stream stringResourceFailureTestCases() + { + return createFailingTestCases(stringResourceTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {3}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#stringResourceFailureTestCases") + public void testStringResourceActionsFailure( + String actionName, + String resourceName, + FunctionalHelpers.Consumer3 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + "dummy_value")); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanImpersonateUser() + { + authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user"); + + String expectedRequest = """ + { + "operation": "ImpersonateUser", + "resource": { + "user": { + "user": "some_other_user" + } + } + } + """; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCanImpersonateUserFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user")); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanAccessCatalog() + { + mockClient.setHandler(request -> OK_RESPONSE); + assertTrue(authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog_one")); + + mockClient.setHandler(request -> NO_ACCESS_RESPONSE); + assertFalse(authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog_two")); + + Set expectedRequests = ImmutableSet.of("my_catalog_one", "my_catalog_two").stream().map(""" + { + "operation": "AccessCatalog", + "resource": { + "catalog": { + "name": "%s" + } + } + } + """::formatted) + .collect(toImmutableSet()); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0} - {3}") + @MethodSource("io.trino.plugin.opa.TestHelpers#illegalResponseArgumentProvider") + public void testCanAccessCatalogIllegalResponses( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog")); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream schemaResourceTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanDropSchema, + OpaAccessControl::checkCanShowCreateSchema, + OpaAccessControl::checkCanShowTables, + OpaAccessControl::checkCanShowFunctions); + Stream actions = Stream.of( + "DropSchema", + "ShowCreateSchema", + "ShowTables", + "ShowFunctions"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaResourceTestCases") + public void testSchemaResourceActions( + String actionName, + FunctionalHelpers.Consumer3 callable) + { + callable.accept(authorizer, requestingSecurityContext, new CatalogSchemaName("my_catalog", "my_schema")); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "my_schema" + } + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + public static Stream schemaResourceFailureTestCases() + { + return createFailingTestCases(schemaResourceTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaResourceFailureTestCases") + public void testSchemaResourceActionsFailure( + String actionName, + FunctionalHelpers.Consumer3 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + new CatalogSchemaName("dummy_catalog", "dummy_schema"))); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCreateSchema() + { + CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); + authorizer.checkCanCreateSchema(requestingSecurityContext, schema, ImmutableMap.of("some_key", "some_value")); + authorizer.checkCanCreateSchema(requestingSecurityContext, schema, ImmutableMap.of()); + + List expectedRequests = ImmutableList.builder() + .add(""" + { + "operation": "CreateSchema", + "resource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "properties": { + "some_key": "some_value" + } + } + } + } + """) + .add(""" + { + "operation": "CreateSchema", + "resource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "properties": {} + } + } + } + """) + .build(); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCreateSchemaFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanCreateSchema( + requestingSecurityContext, + new CatalogSchemaName("my_catalog", "my_schema"), + ImmutableMap.of())); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanRenameSchema() + { + CatalogSchemaName sourceSchema = new CatalogSchemaName("my_catalog", "my_schema"); + authorizer.checkCanRenameSchema(requestingSecurityContext, sourceSchema, "new_schema_name"); + + String expectedRequest = """ + { + "operation": "RenameSchema", + "resource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "my_schema" + } + }, + "targetResource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "new_schema_name" + } + } + } + """; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCanRenameSchemaFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanRenameSchema( + requestingSecurityContext, + new CatalogSchemaName("my_catalog", "my_schema"), + "new_schema_name")); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream renameTableTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanRenameTable, + OpaAccessControl::checkCanRenameView, + OpaAccessControl::checkCanRenameMaterializedView); + Stream actions = Stream.of( + "RenameTable", + "RenameView", + "RenameMaterializedView"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#renameTableTestCases") + public void testRenameTableActions( + String actionName, + FunctionalHelpers.Consumer4 method) + { + CatalogSchemaTableName sourceTable = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + CatalogSchemaTableName targetTable = new CatalogSchemaTableName("my_catalog", "new_schema_name", "new_table_name"); + + method.accept(authorizer, requestingSecurityContext, sourceTable, targetTable); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table" + } + }, + "targetResource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "new_schema_name", + "tableName": "new_table_name" + } + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + public static Stream renameTableFailureTestCases() + { + return createFailingTestCases(renameTableTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {3}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#renameTableFailureTestCases") + public void testRenameTableFailure( + String actionName, + FunctionalHelpers.Consumer4 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + CatalogSchemaTableName sourceTable = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + CatalogSchemaTableName targetTable = new CatalogSchemaTableName("my_catalog", "new_schema_name", "new_table_name"); + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + sourceTable, + targetTable)); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanSetSchemaAuthorization() + { + CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); + + authorizer.checkCanSetSchemaAuthorization(requestingSecurityContext, schema, new TrinoPrincipal(PrincipalType.USER, "my_user")); + + String expectedRequest = """ + { + "operation": "SetSchemaAuthorization", + "resource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "my_schema" + } + }, + "grantee": { + "principals": [ + { + "name": "my_user", + "type": "USER" + } + ] + } + } + """; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCanSetSchemaAuthorizationFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanSetSchemaAuthorization( + requestingSecurityContext, + schema, + new TrinoPrincipal(PrincipalType.USER, "my_user"))); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream setTableAuthorizationTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanSetTableAuthorization, + OpaAccessControl::checkCanSetViewAuthorization); + Stream actions = Stream.of( + "SetTableAuthorization", + "SetViewAuthorization"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#setTableAuthorizationTestCases") + public void testCanSetTableAuthorization( + String actionName, + FunctionalHelpers.Consumer4 method) + { + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + + method.accept(authorizer, requestingSecurityContext, table, new TrinoPrincipal(PrincipalType.USER, "my_user")); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table" + } + }, + "grantee": { + "principals": [ + { + "name": "my_user", + "type": "USER" + } + ] + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + private static Stream setTableAuthorizationFailureTestCases() + { + return createFailingTestCases(setTableAuthorizationTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {3}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#setTableAuthorizationFailureTestCases") + public void testCanSetTableAuthorizationFailure( + String actionName, + FunctionalHelpers.Consumer4 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + table, + new TrinoPrincipal(PrincipalType.USER, "my_user"))); + assertTrue(actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream tableColumnOperationTestCases() + { + Stream>> methods = Stream.of( + OpaAccessControl::checkCanSelectFromColumns, + OpaAccessControl::checkCanUpdateTableColumns, + OpaAccessControl::checkCanCreateViewWithSelectFromColumns); + Stream actionAndResource = Stream.of( + "SelectFromColumns", + "UpdateTableColumns", + "CreateViewWithSelectFromColumns"); + return Streams.zip(actionAndResource, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableColumnOperationTestCases") + public void testTableColumnOperations( + String actionName, + FunctionalHelpers.Consumer4> method) + { + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + Set columns = ImmutableSet.of("my_column"); + + method.accept(authorizer, requestingSecurityContext, table, columns); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table", + "columns": ["my_column"] + } + } + } + """.formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + private static Stream tableColumnOperationFailureTestCases() + { + return createFailingTestCases(tableColumnOperationTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableColumnOperationFailureTestCases") + public void testTableColumnOperationsFailure( + String actionName, + FunctionalHelpers.Consumer4> method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + Set columns = ImmutableSet.of("my_column"); + + Throwable actualError = assertThrows( + expectedException, + () -> method.accept(authorizer, requestingSecurityContext, table, columns)); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanSetCatalogSessionProperty() + { + authorizer.checkCanSetCatalogSessionProperty( + requestingSecurityContext, "my_catalog", "my_property"); + + String expectedRequest = """ + { + "operation": "SetCatalogSessionProperty", + "resource": { + "catalogSessionProperty": { + "catalogName": "my_catalog", + "propertyName": "my_property" + } + } + } + """; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCanSetCatalogSessionPropertyFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanSetCatalogSessionProperty( + requestingSecurityContext, + "my_catalog", + "my_property")); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream schemaPrivilegeTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanDenySchemaPrivilege, + (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanGrantSchemaPrivilege( + context, privilege, catalog, principal, true), + (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanRevokeSchemaPrivilege( + context, privilege, catalog, principal, true)); + Stream actions = Stream.of( + "DenySchemaPrivilege", + "GrantSchemaPrivilege", + "RevokeSchemaPrivilege"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaPrivilegeTestCases") + public void testSchemaPrivileges( + String actionName, + FunctionalHelpers.Consumer5 method) + throws IOException + { + Privilege privilege = Privilege.CREATE; + method.accept( + authorizer, + requestingSecurityContext, + privilege, + new CatalogSchemaName("my_catalog", "my_schema"), + new TrinoPrincipal(PrincipalType.USER, "my_user")); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "schema": { + "catalogName": "my_catalog", + "schemaName": "my_schema" + } + }, + "grantee": { + "principals": [ + { + "name": "my_user", + "type": "USER" + } + ], + "privilege": "CREATE", + "grantOption": true + } + } + """.formatted(actionName); + List actualRequests = mockClient.getRequests(); + assertEquals(actualRequests.size(), 1, "Unexpected number of requests"); + + JsonNode actualRequestInput = jsonMapper.readTree(mockClient.getRequests().get(0)).at("/input/action"); + if (!actualRequestInput.at("/grantee").has("grantOption")) { + // The DenySchemaPrivilege request does not have a grant option, we'll default it to true so we can use the same test + ((ObjectNode) actualRequestInput.at("/grantee")).put("grantOption", true); + } + assertEquals(jsonMapper.readTree(expectedRequest), actualRequestInput); + } + + private static Stream schemaPrivilegeFailureTestCases() + { + return createFailingTestCases(schemaPrivilegeTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaPrivilegeFailureTestCases") + public void testSchemaPrivilegesFailure( + String actionName, + FunctionalHelpers.Consumer5 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Privilege privilege = Privilege.CREATE; + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + privilege, + new CatalogSchemaName("my_catalog", "my_schema"), + new TrinoPrincipal(PrincipalType.USER, "my_user"))); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream tablePrivilegeTestCases() + { + Stream> methods = Stream.of( + OpaAccessControl::checkCanDenyTablePrivilege, + (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanGrantTablePrivilege(context, privilege, catalog, principal, true), + (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanRevokeTablePrivilege(context, privilege, catalog, principal, true)); + Stream actions = Stream.of( + "DenyTablePrivilege", + "GrantTablePrivilege", + "RevokeTablePrivilege"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tablePrivilegeTestCases") + public void testTablePrivileges( + String actionName, + FunctionalHelpers.Consumer5 method) + throws IOException + { + Privilege privilege = Privilege.CREATE; + method.accept( + authorizer, + requestingSecurityContext, + privilege, + new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"), + new TrinoPrincipal(PrincipalType.USER, "my_user")); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "table": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "tableName": "my_table" + } + }, + "grantee": { + "principals": [ + { + "name": "my_user", + "type": "USER" + } + ], + "privilege": "CREATE", + "grantOption": true + } + } + """.formatted(actionName); + List actualRequests = mockClient.getRequests(); + assertEquals(actualRequests.size(), 1, "Unexpected number of requests"); + + JsonNode actualRequestInput = jsonMapper.readTree(mockClient.getRequests().get(0)).at("/input/action"); + if (!actualRequestInput.at("/grantee").has("grantOption")) { + // The DenySchemaPrivilege request does not have a grant option, we'll default it to true so we can use the same test + ((ObjectNode) actualRequestInput.at("/grantee")).put("grantOption", true); + } + assertEquals(jsonMapper.readTree(expectedRequest), actualRequestInput); + } + + private static Stream tablePrivilegeFailureTestCases() + { + return createFailingTestCases(tablePrivilegeTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tablePrivilegeFailureTestCases") + public void testTablePrivilegesFailure( + String actionName, + FunctionalHelpers.Consumer5 method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Privilege privilege = Privilege.CREATE; + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + privilege, + new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"), + new TrinoPrincipal(PrincipalType.USER, "my_user"))); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanCreateRole() + { + authorizer.checkCanCreateRole(requestingSecurityContext, "my_role_without_grantor", Optional.empty()); + TrinoPrincipal grantor = new TrinoPrincipal(PrincipalType.USER, "my_grantor"); + authorizer.checkCanCreateRole(requestingSecurityContext, "my_role_with_grantor", Optional.of(grantor)); + + Set expectedRequests = ImmutableSet.builder() + .add(""" + { + "operation": "CreateRole", + "resource": { + "role": { + "name": "my_role_without_grantor" + } + } + } + """) + .add(""" + { + "operation": "CreateRole", + "resource": { + "role": { + "name": "my_role_with_grantor" + } + }, + "grantor": { + "name": "my_grantor", + "type": "USER" + } + } + """) + .build(); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCanCreateRoleFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanCreateRole( + requestingSecurityContext, + "my_role_without_grantor", + Optional.empty())); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream roleGrantingTestCases() + { + Stream, Set, Boolean, Optional>> methods = Stream.of( + OpaAccessControl::checkCanGrantRoles, + OpaAccessControl::checkCanRevokeRoles); + Stream actions = Stream.of( + "GrantRoles", + "RevokeRoles"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#roleGrantingTestCases") + public void testRoleGranting( + String actionName, + FunctionalHelpers.Consumer6, Set, Boolean, Optional> method) + { + TrinoPrincipal grantee = new TrinoPrincipal(PrincipalType.ROLE, "my_grantee_role"); + method.accept(authorizer, requestingSecurityContext, ImmutableSet.of("my_role_without_grantor"), ImmutableSet.of(grantee), true, Optional.empty()); + + TrinoPrincipal grantor = new TrinoPrincipal(PrincipalType.USER, "my_grantor_user"); + method.accept(authorizer, requestingSecurityContext, ImmutableSet.of("my_role_with_grantor"), ImmutableSet.of(grantee), false, Optional.of(grantor)); + + Set expectedRequests = ImmutableSet.builder() + .add(""" + { + "operation": "%s", + "resource": { + "roles": [ + { + "name": "my_role_with_grantor" + } + ] + }, + "grantor": { + "name": "my_grantor_user", + "type": "USER" + }, + "grantee": { + "principals": [ + { + "name": "my_grantee_role", + "type": "ROLE" + } + ], + "grantOption": false + } + } + """.formatted(actionName)) + .add(""" + { + "operation": "%s", + "resource": { + "roles": [ + { + "name": "my_role_without_grantor" + } + ] + }, + "grantee": { + "principals": [ + { + "name": "my_grantee_role", + "type": "ROLE" + } + ], + "grantOption": true + } + } + """.formatted(actionName)) + .build(); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + } + + private static Stream roleGrantingFailureTestCases() + { + return createFailingTestCases(roleGrantingTestCases()); + } + + @ParameterizedTest(name = "{index}: {0} - {2}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#roleGrantingFailureTestCases") + public void testRoleGrantingFailure( + String actionName, + FunctionalHelpers.Consumer6, Set, Boolean, Optional> method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + TrinoPrincipal grantee = new TrinoPrincipal(PrincipalType.ROLE, "my_grantee_role"); + Throwable actualError = assertThrows( + expectedException, + () -> method.accept( + authorizer, + requestingSecurityContext, + ImmutableSet.of("my_role_without_grantor"), + ImmutableSet.of(grantee), + true, + Optional.empty())); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream functionResourceTestCases() + { + Stream> methods = Stream.of( + new TestHelpers.ThrowingMethodWrapper<>(OpaAccessControl::checkCanExecuteProcedure), + new TestHelpers.ThrowingMethodWrapper<>(OpaAccessControl::checkCanCreateFunction), + new TestHelpers.ThrowingMethodWrapper<>(OpaAccessControl::checkCanDropFunction), + new TestHelpers.ReturningMethodWrapper<>(OpaAccessControl::canExecuteFunction), + new TestHelpers.ReturningMethodWrapper<>(OpaAccessControl::canCreateViewWithExecuteFunction)); + Stream actions = Stream.of( + "ExecuteProcedure", + "CreateFunction", + "DropFunction", + "ExecuteFunction", + "CreateViewWithExecuteFunction"); + return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#functionResourceTestCases") + public void testFunctionResourceAction( + String actionName, + TestHelpers.MethodWrapper method) + { + CatalogSchemaRoutineName routine = new CatalogSchemaRoutineName("my_catalog", "my_schema", "my_routine_name"); + assertTrue(method.isAccessAllowed(authorizer, requestingSecurityContext, routine)); + + mockClient.setHandler(request -> NO_ACCESS_RESPONSE); + assertFalse(method.isAccessAllowed(authorizer, requestingSecurityContext, routine)); + + String expectedRequest = """ + { + "operation": "%s", + "resource": { + "function": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "functionName": "my_routine_name" + } + } + }""".formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + assertEquals(mockClient.getRequests().size(), 2); + } + + private static Stream functionResourceIllegalResponseTestCases() + { + return createIllegalResponseTestCases(functionResourceTestCases()); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#functionResourceIllegalResponseTestCases") + public void testFunctionResourceIllegalResponses( + String actionName, + TestHelpers.MethodWrapper method, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + CatalogSchemaRoutineName routine = new CatalogSchemaRoutineName("my_catalog", "my_schema", "my_routine_name"); + Throwable actualError = assertThrows( + expectedException, + () -> method.isAccessAllowed(authorizer, requestingSecurityContext, routine)); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + @Test + public void testCanExecuteTableProcedure() + { + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + authorizer.checkCanExecuteTableProcedure(requestingSecurityContext, table, "my_procedure"); + + String expectedRequest = """ + { + "operation": "ExecuteTableProcedure", + "resource": { + "table": { + "schemaName": "my_schema", + "catalogName": "my_catalog", + "tableName": "my_table" + }, + "function": { + "functionName": "my_procedure" + } + } + }"""; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") + public void testCanExecuteTableProcedureFailure( + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); + Throwable actualError = assertThrows( + expectedException, + () -> authorizer.checkCanExecuteTableProcedure( + requestingSecurityContext, + table, + "my_procedure")); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + } + + private static Stream noResourceActionTestCases() + { + Stream> methods = Stream.of( + convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanExecuteQuery), + convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanReadSystemInformation), + convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanWriteSystemInformation), + OpaAccessControl::checkCanShowRoles, + OpaAccessControl::checkCanShowCurrentRoles, + OpaAccessControl::checkCanShowRoleGrants); + Stream expectedActions = Stream.of( + "ExecuteQuery", + "ReadSystemInformation", + "WriteSystemInformation", + "ShowRoles", + "ShowCurrentRoles", + "ShowRoleGrants"); + return Streams.zip(expectedActions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + } + + private static Stream noResourceActionFailureTestCases() + { + return createFailingTestCases(noResourceActionTestCases()); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java new file mode 100644 index 0000000000000..4d004e427d159 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java @@ -0,0 +1,470 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; +import io.trino.plugin.opa.schema.TrinoUser; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.SchemaFunctionName; +import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemSecurityContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.net.MediaType.JSON_UTF_8; +import static io.trino.plugin.opa.RequestTestUtilities.assertJsonRequestsEqual; +import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; +import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpaBatchAccessControlFilteringUnitTest +{ + private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private static final URI OPA_BATCH_SERVER_URI = URI.create("http://my-uri/batchAllow"); + private HttpClientUtils.InstrumentedHttpClient mockClient; + private OpaAccessControl authorizer; + private final JsonMapper jsonMapper = new JsonMapper(); + private Identity requestingIdentity; + private SystemSecurityContext requestingSecurityContext; + + @BeforeEach + public void setupAuthorizer() + { + this.jsonMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + this.jsonMapper.registerModule(new Jdk8Module()); + this.mockClient = new HttpClientUtils.InstrumentedHttpClient(OPA_BATCH_SERVER_URI, "POST", JSON_UTF_8.toString(), request -> null); + this.authorizer = (OpaAccessControl) new OpaAccessControlFactory().create( + ImmutableMap.builder() + .put("opa.policy.uri", OPA_SERVER_URI.toString()) + .put("opa.policy.batched-uri", OPA_BATCH_SERVER_URI.toString()) + .buildOrThrow(), + Optional.of(mockClient)); + this.requestingIdentity = Identity.ofUser("source-user"); + this.requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); + } + + @AfterEach + public void ensureRequestContextCorrect() + throws IOException + { + for (String request : mockClient.getRequests()) { + JsonNode parsedRequest = jsonMapper.readTree(request); + assertEquals(parsedRequest.at("/input/context/identity/user").asText(), requestingIdentity.getUser()); + } + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + public void testFilterViewQueryOwnedBy( + HttpClientUtils.MockResponse response, + List expectedItems) + { + this.mockClient.setHandler(request -> response); + + Identity identityOne = Identity.ofUser("user-one"); + Identity identityTwo = Identity.ofUser("user-two"); + Identity identityThree = Identity.ofUser("user-three"); + List requestedIdentities = ImmutableList.of(identityOne, identityTwo, identityThree); + + Collection result = authorizer.filterViewQueryOwnedBy(requestingIdentity, requestedIdentities); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedIdentities, expectedItems))); + + ArrayNode allExpectedUsers = jsonMapper.createArrayNode().addAll( + requestedIdentities.stream() + .map(TrinoUser::new) + .map(user -> encodeObjectWithKey(user, "user")) + .collect(toImmutableList())); + ObjectNode expectedRequest = jsonMapper.createObjectNode() + .put("operation", "FilterViewQueryOwnedBy") + .set("filterResources", allExpectedUsers); + assertJsonRequestsEqual(ImmutableSet.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + public void testFilterCatalogs( + HttpClientUtils.MockResponse response, + List expectedItems) + { + this.mockClient.setHandler(request -> response); + + List requestedCatalogs = ImmutableList.of("catalog_one", "catalog_two", "catalog_three"); + + Set result = authorizer.filterCatalogs( + requestingSecurityContext, + new LinkedHashSet<>(requestedCatalogs)); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedCatalogs, expectedItems))); + + String expectedRequest = """ + { + "operation": "FilterCatalogs", + "filterResources": [ + { + "catalog": { + "name": "catalog_one" + } + }, + { + "catalog": { + "name": "catalog_two" + } + }, + { + "catalog": { + "name": "catalog_three" + } + } + ] + }"""; + assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + public void testFilterSchemas( + HttpClientUtils.MockResponse response, + List expectedItems) + { + this.mockClient.setHandler(request -> response); + List requestedSchemas = ImmutableList.of("schema_one", "schema_two", "schema_three"); + + Set result = authorizer.filterSchemas( + requestingSecurityContext, + "my_catalog", + new LinkedHashSet<>(requestedSchemas)); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedSchemas, expectedItems))); + + String expectedRequest = """ + { + "operation": "FilterSchemas", + "filterResources": [ + { + "schema": { + "schemaName": "schema_one", + "catalogName": "my_catalog" + } + }, + { + "schema": { + "schemaName": "schema_two", + "catalogName": "my_catalog" + } + }, + { + "schema": { + "schemaName": "schema_three", + "catalogName": "my_catalog" + } + } + ] + }"""; + assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + public void testFilterTables( + HttpClientUtils.MockResponse response, + List expectedItems) + { + this.mockClient.setHandler(request -> response); + List tables = ImmutableList.builder() + .add(new SchemaTableName("schema_one", "table_one")) + .add(new SchemaTableName("schema_one", "table_two")) + .add(new SchemaTableName("schema_two", "table_one")) + .build(); + + Set result = authorizer.filterTables( + requestingSecurityContext, + "my_catalog", + new LinkedHashSet<>(tables)); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(tables, expectedItems))); + + String expectedRequest = """ + { + "operation": "FilterTables", + "filterResources": [ + { + "table": { + "tableName": "table_one", + "schemaName": "schema_one", + "catalogName": "my_catalog" + } + }, + { + "table": { + "tableName": "table_two", + "schemaName": "schema_one", + "catalogName": "my_catalog" + } + }, + { + "table": { + "tableName": "table_one", + "schemaName": "schema_two", + "catalogName": "my_catalog" + } + } + ] + }"""; + assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + } + + private static Function buildHandler(Function dataBuilder) + { + return request -> new HttpClientUtils.MockResponse(dataBuilder.apply(request), 200); + } + + @Test + public void testFilterColumns() + { + SchemaTableName tableOne = SchemaTableName.schemaTableName("my_schema", "table_one"); + SchemaTableName tableTwo = SchemaTableName.schemaTableName("my_schema", "table_two"); + SchemaTableName tableThree = SchemaTableName.schemaTableName("my_schema", "table_three"); + Map> requestedColumns = ImmutableMap.>builder() + .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) + .put(tableTwo, ImmutableSet.of("table_two_column_one", "table_two_column_two")) + .put(tableThree, ImmutableSet.of("table_three_column_one", "table_three_column_two")) + .buildOrThrow(); + + // Allow both columns from one table, one column from another one and no columns from the last one + this.mockClient.setHandler( + buildHandler( + request -> { + if (request.contains("table_one")) { + return "{\"result\": [0, 1]}"; + } else if (request.contains("table_two")) { + return "{\"result\": [1]}"; + } + return "{\"result\": []}"; + })); + + Map> result = authorizer.filterColumns( + requestingSecurityContext, + "my_catalog", + requestedColumns); + + List expectedRequests = Stream.of("table_one", "table_two", "table_three") + .map(tableName -> """ + { + "operation": "FilterColumns", + "filterResources": [ + { + "table": { + "tableName": "%s", + "schemaName": "my_schema", + "catalogName": "my_catalog", + "columns": ["%s_column_one", "%s_column_two"] + } + } + ] + } + """.formatted(tableName, tableName, tableName)) + .collect(toImmutableList()); + assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + assertTrue(Maps.difference( + result, + ImmutableMap.builder() + .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) + .put(tableTwo, ImmutableSet.of("table_two_column_two")) + .buildOrThrow()).areEqual()); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + public void testFilterFunctions( + HttpClientUtils.MockResponse response, + List expectedItems) + { + this.mockClient.setHandler(request -> response); + List requestedFunctions = ImmutableList.builder() + .add(new SchemaFunctionName("my_schema", "function_one")) + .add(new SchemaFunctionName("my_schema", "function_two")) + .add(new SchemaFunctionName("my_schema", "function_three")) + .build(); + + Set result = authorizer.filterFunctions( + requestingSecurityContext, + "my_catalog", + new LinkedHashSet<>(requestedFunctions)); + assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedFunctions, expectedItems))); + + String expectedRequest = """ + { + "operation": "FilterFunctions", + "filterResources": [ + { + "function": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "functionName": "function_one" + } + }, + { + "function": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "functionName": "function_two" + } + }, + { + "function": { + "catalogName": "my_catalog", + "schemaName": "my_schema", + "functionName": "function_three" + } + } + ] + }"""; + assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + } + + @Test + public void testEmptyFilterColumns() + { + SchemaTableName tableOne = SchemaTableName.schemaTableName("my_schema", "table_one"); + SchemaTableName tableTwo = SchemaTableName.schemaTableName("my_schema", "table_two"); + Map> requestedColumns = ImmutableMap.>builder() + .put(tableOne, ImmutableSet.of()) + .put(tableTwo, ImmutableSet.of()) + .buildOrThrow(); + + Map> result = authorizer.filterColumns( + requestingSecurityContext, + "my_catalog", + requestedColumns); + assertTrue(mockClient.getRequests().isEmpty()); + assertTrue(result.isEmpty()); + } + + @ParameterizedTest(name = "{index}: {0}") + @MethodSource("io.trino.plugin.opa.FilteringTestHelpers#emptyInputTestCases") + public void testEmptyRequests( + BiFunction callable) + { + Collection result = callable.apply(authorizer, requestingSecurityContext); + assertTrue(result.isEmpty()); + assertTrue(mockClient.getRequests().isEmpty()); + } + + @ParameterizedTest(name = "{index}: {0} - {1}") + @MethodSource("io.trino.plugin.opa.FilteringTestHelpers#prepopulatedErrorCases") + public void testIllegalResponseThrows( + BiFunction callable, + HttpClientUtils.MockResponse failureResponse, + Class expectedException, + String expectedErrorMessage) + { + mockClient.setHandler(request -> failureResponse); + + Throwable actualError = assertThrows( + expectedException, + () -> callable.apply(authorizer, requestingSecurityContext)); + assertTrue( + actualError.getMessage().contains(expectedErrorMessage), + String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + assertFalse(mockClient.getRequests().isEmpty()); + } + + @Test + public void testResponseOutOfBoundsThrows() + { + mockClient.setHandler(request -> new HttpClientUtils.MockResponse("{\"result\": [0, 1, 2]}", 200)); + + assertThrows( + OpaQueryException.QueryFailed.class, + () -> authorizer.filterCatalogs(requestingSecurityContext, ImmutableSet.of("catalog_one", "catalog_two"))); + assertThrows( + OpaQueryException.QueryFailed.class, + () -> authorizer.filterSchemas(requestingSecurityContext, "some_catalog", ImmutableSet.of("schema_one", "schema_two"))); + assertThrows( + OpaQueryException.QueryFailed.class, + () -> authorizer.filterTables( + requestingSecurityContext, + "some_catalog", + ImmutableSet.of( + new SchemaTableName("some_schema", "table_one"), + new SchemaTableName("some_schema", "table_two")))); + assertThrows( + OpaQueryException.QueryFailed.class, + () -> authorizer.filterColumns( + requestingSecurityContext, + "some_catalog", + ImmutableMap.>builder() + .put(new SchemaTableName("some_schema", "some_table"), ImmutableSet.of("column_one", "column_two")) + .buildOrThrow())); + assertThrows( + OpaQueryException.QueryFailed.class, + () -> authorizer.filterViewQueryOwnedBy( + requestingIdentity, + ImmutableSet.of(Identity.ofUser("identity_one"), Identity.ofUser("identity_two")))); + } + + private ObjectNode encodeObjectWithKey(Object inp, String key) + { + return jsonMapper.createObjectNode().set(key, jsonMapper.valueToTree(inp)); + } + + private static Stream subsetProvider() + { + return Stream.of( + Arguments.of(Named.of("All-3-resources", new HttpClientUtils.MockResponse("{\"result\": [0, 1, 2]}", 200)), ImmutableList.of(0, 1, 2)), + Arguments.of(Named.of("First-and-last-resources", new HttpClientUtils.MockResponse("{\"result\": [0, 2]}", 200)), ImmutableList.of(0, 2)), + Arguments.of(Named.of("Only-one-resource", new HttpClientUtils.MockResponse("{\"result\": [2]}", 200)), ImmutableList.of(2)), + Arguments.of(Named.of("No-resources", new HttpClientUtils.MockResponse("{\"result\": []}", 200)), ImmutableList.of())); + } + + private List getSubset(List allItems, List subsetPositions) + { + List result = new ArrayList<>(); + for (int i : subsetPositions) { + if (i < 0 || i >= allItems.size()) { + throw new IllegalArgumentException("Invalid subset of items provided"); + } + result.add(allItems.get(i)); + } + return result; + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java new file mode 100644 index 0000000000000..97c7dff942dbf --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.json.JsonMapper; +import com.google.common.collect.ImmutableSet; + +import java.io.IOException; +import java.util.Collection; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class RequestTestUtilities +{ + private RequestTestUtilities() {} + + private static final JsonMapper jsonMapper = new JsonMapper(); + + public static void assertStringRequestsEqual( + Collection expectedRequests, Collection actualRequests, String extractPath) + { + Set parsedExpectedRequests = expectedRequests.stream() + .map(expectedRequest -> { + try { + return jsonMapper.readTree(expectedRequest); + } + catch (IOException e) { + fail("Could not parse request", e); + return null; + } + }) + .collect(toImmutableSet()); + assertJsonRequestsEqual(parsedExpectedRequests, actualRequests, extractPath); + } + + public static void assertJsonRequestsEqual( + Collection expectedRequests, Collection actualRequests, String extractPath) + { + Set parsedActualRequests = actualRequests.stream() + .map(actualRequest -> { + try { + JsonNode parsed = jsonMapper.readTree(actualRequest); + if (extractPath != null) { + return parsed.at(extractPath); + } + return parsed; + } + catch (IOException e) { + fail("Could not parse request", e); + return null; + } + }) + .collect(toImmutableSet()); + Set expectedRequestSet = ImmutableSet.copyOf(expectedRequests); + assertEquals( + expectedRequestSet.size(), + parsedActualRequests.size(), + "Mismatch in expected vs. actual request count"); + assertEquals(expectedRequestSet, parsedActualRequests, "Requests do not match"); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java new file mode 100644 index 0000000000000..66a2b21f5594b --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableList; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.trino.plugin.opa.schema.OpaBatchQueryResult; +import io.trino.plugin.opa.schema.OpaQueryResult; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ResponseTest +{ + private JsonCodec responseCodec; + private JsonCodec batchResponseCodec; + + @BeforeEach + public void setupParser() + { + this.responseCodec = new JsonCodecFactory().jsonCodec(OpaQueryResult.class); + this.batchResponseCodec = new JsonCodecFactory().jsonCodec(OpaBatchQueryResult.class); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + public void testCanDeserializeOpaSingleResponse(boolean response) + { + OpaQueryResult result = this.responseCodec.fromJson(""" + { + "decision_id": "foo", + "result": %s + }""".formatted(String.valueOf(response))); + assertEquals(response, result.result()); + assertEquals("foo", result.decisionId()); + } + + @ParameterizedTest + @ValueSource(booleans = {false, true}) + public void testCanDeserializeOpaSingleResponseWithNoDecisionId(boolean response) + { + OpaQueryResult result = this.responseCodec.fromJson(""" + { + "result": %s + }""".formatted(String.valueOf(response))); + assertEquals(response, result.result()); + assertNull(result.decisionId()); + } + + @Test + public void testSingleResponseWithExtraFields() + { + OpaQueryResult result = this.responseCodec.fromJson(""" + { + "result": true, + "someExtraInfo": ["foo"] + }"""); + assertTrue(result.result()); + assertNull(result.decisionId()); + } + + @Test + public void testUndefinedDecisionSingleResponseTreatedAsDeny() + { + OpaQueryResult result = this.responseCodec.fromJson("{}"); + assertFalse(result.result()); + assertNull(result.decisionId()); + } + + @ParameterizedTest + @ValueSource(strings = {"{}", "{\"result\": []}"}) + public void testEmptyOrUndefinedResponses(String response) + { + OpaBatchQueryResult result = this.batchResponseCodec.fromJson(response); + assertEquals(ImmutableList.of(), result.result()); + assertNull(result.decisionId()); + } + + @Test + public void testBatchResponseWithItemsNoDecisionId() + { + OpaBatchQueryResult result = this.batchResponseCodec.fromJson(""" + { + "result": [1, 2, 3] + }"""); + assertEquals(ImmutableList.of(1, 2, 3), result.result()); + assertNull(result.decisionId()); + } + + @Test + public void testBatchResponseWithItemsAndDecisionId() + { + OpaBatchQueryResult result = this.batchResponseCodec.fromJson(""" + { + "result": [1, 2, 3], + "decision_id": "foobar" + }"""); + assertEquals(ImmutableList.of(1, 2, 3), result.result()); + assertEquals("foobar", result.decisionId()); + } + + @Test + public void testBatchResponseWithExtraFields() + { + OpaBatchQueryResult result = this.batchResponseCodec.fromJson(""" + { + "result": [1, 2, 3], + "decision_id": "foobar", + "someInfo": "foo", + "andAnObject": {} + }"""); + assertEquals(ImmutableList.of(1, 2, 3), result.result()); + assertEquals("foobar", result.decisionId()); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java new file mode 100644 index 0000000000000..50a81a7aaaeb7 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableMap; +import io.airlift.bootstrap.ApplicationConfigurationException; +import io.trino.spi.security.SystemAccessControl; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class TestFactory +{ + @Test + public void testCreatesSimpleAuthorizerIfNoBatchUriProvided() + { + OpaAccessControlFactory factory = new OpaAccessControlFactory(); + SystemAccessControl opaAuthorizer = factory.create(ImmutableMap.of("opa.policy.uri", "foo")); + + assertInstanceOf(OpaAccessControl.class, opaAuthorizer); + assertFalse(opaAuthorizer instanceof OpaBatchAccessControl); + } + + @Test + public void testCreatesBatchAuthorizerIfBatchUriProvided() + { + OpaAccessControlFactory factory = new OpaAccessControlFactory(); + SystemAccessControl opaAuthorizer = factory.create( + ImmutableMap.builder() + .put("opa.policy.uri", "foo") + .put("opa.policy.batched-uri", "bar") + .buildOrThrow()); + + assertInstanceOf(OpaBatchAccessControl.class, opaAuthorizer); + assertInstanceOf(OpaAccessControl.class, opaAuthorizer); + } + + @Test + public void testBasePolicyUriCannotBeUnset() + { + OpaAccessControlFactory factory = new OpaAccessControlFactory(); + + assertThrows( + ApplicationConfigurationException.class, + () -> factory.create(ImmutableMap.of()), + "may not be null"); + } + + @Test + public void testConfigMayNotBeNull() + { + OpaAccessControlFactory factory = new OpaAccessControlFactory(); + + assertThrows( + NullPointerException.class, + () -> factory.create(null)); + } + + @Test + public void testSupportsAirliftHttpConfigs() + { + OpaAccessControlFactory factory = new OpaAccessControlFactory(); + SystemAccessControl opaAuthorizer = factory.create( + ImmutableMap.builder() + .put("opa.policy.uri", "foo") + .put("opa.http-client.log.enabled", "true") + .buildOrThrow()); + + assertInstanceOf(OpaAccessControl.class, opaAuthorizer); + assertFalse(opaAuthorizer instanceof OpaBatchAccessControl); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java new file mode 100644 index 0000000000000..7e67030834d64 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.Sets; +import io.trino.execution.QueryIdGenerator; +import io.trino.spi.security.AccessDeniedException; +import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemSecurityContext; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.params.provider.Arguments; + +import java.time.Instant; +import java.util.Arrays; +import java.util.function.BiConsumer; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; + +public class TestHelpers +{ + private TestHelpers() {} + + public static final HttpClientUtils.MockResponse OK_RESPONSE = new HttpClientUtils.MockResponse(""" + { + "decision_id": "", + "result": true + } + """, + 200); + public static final HttpClientUtils.MockResponse NO_ACCESS_RESPONSE = new HttpClientUtils.MockResponse(""" + { + "decision_id": "", + "result": false + } + """, + 200); + public static final HttpClientUtils.MockResponse MALFORMED_RESPONSE = new HttpClientUtils.MockResponse(""" + { "this"": is broken_json; } + """, + 200); + public static final HttpClientUtils.MockResponse UNDEFINED_RESPONSE = new HttpClientUtils.MockResponse("{}", 404); + public static final HttpClientUtils.MockResponse BAD_REQUEST_RESPONSE = new HttpClientUtils.MockResponse("{}", 400); + public static final HttpClientUtils.MockResponse SERVER_ERROR_RESPONSE = new HttpClientUtils.MockResponse("", 500); + + public static Stream createFailingTestCases(Stream baseTestCases) + { + return Sets.cartesianProduct( + baseTestCases.collect(toImmutableSet()), + allErrorCasesArgumentProvider().collect(toImmutableSet())) + .stream() + .map(items -> Arguments.of(items.stream().flatMap((args) -> Arrays.stream(args.get())).toArray())); + } + + public static Stream createIllegalResponseTestCases(Stream baseTestCases) + { + return Sets.cartesianProduct( + baseTestCases.collect(toImmutableSet()), + illegalResponseArgumentProvider().collect(toImmutableSet())) + .stream() + .map(items -> Arguments.of(items.stream().flatMap((args) -> Arrays.stream(args.get())).toArray())); + } + + public static Stream illegalResponseArgumentProvider() + { + // Invalid responses from OPA + return Stream.of( + Arguments.of(Named.of("Undefined policy response", UNDEFINED_RESPONSE), OpaQueryException.OpaServerError.PolicyNotFound.class, "did not return a value"), + Arguments.of(Named.of("Bad request response", BAD_REQUEST_RESPONSE), OpaQueryException.OpaServerError.class, "returned status 400"), + Arguments.of(Named.of("Server error response", SERVER_ERROR_RESPONSE), OpaQueryException.OpaServerError.class, "returned status 500"), + Arguments.of(Named.of("Malformed JSON response", MALFORMED_RESPONSE), OpaQueryException.class, "Failed to deserialize")); + } + + public static Stream allErrorCasesArgumentProvider() + { + // All possible failure scenarios, including a well-formed access denied response + return Stream.concat( + illegalResponseArgumentProvider(), + Stream.of(Arguments.of(Named.of("No access response", NO_ACCESS_RESPONSE), AccessDeniedException.class, "Access Denied"))); + } + + public static SystemSecurityContext systemSecurityContextFromIdentity(Identity identity) { + return new SystemSecurityContext(identity, new QueryIdGenerator().createNextQueryId(), Instant.now()); + } + + public static BiConsumer convertSystemSecurityContextToIdentityArgument( + BiConsumer callable) + { + return (accessControl, systemSecurityContext) -> callable.accept(accessControl, systemSecurityContext.getIdentity()); + } + + public static FunctionalHelpers.Consumer3 convertSystemSecurityContextToIdentityArgument( + FunctionalHelpers.Consumer3 callable) { + return (accessControl, systemSecurityContext, argument) -> callable.accept(accessControl, systemSecurityContext.getIdentity(), argument); + } + + public abstract static class MethodWrapper { + public abstract boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument); + } + + public static class ThrowingMethodWrapper extends MethodWrapper { + private final FunctionalHelpers.Consumer3 callable; + + public ThrowingMethodWrapper(FunctionalHelpers.Consumer3 callable) { + this.callable = callable; + } + + @Override + public boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument) { + try { + this.callable.accept(opaAccessControl, systemSecurityContext, argument); + return true; + } catch (AccessDeniedException e) { + return false; + } + } + } + + public static class ReturningMethodWrapper extends MethodWrapper { + private final FunctionalHelpers.Function3 callable; + + public ReturningMethodWrapper(FunctionalHelpers.Function3 callable) { + this.callable = callable; + } + + @Override + public boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument) { + return this.callable.apply(opaAccessControl, systemSecurityContext, argument); + } + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java new file mode 100644 index 0000000000000..87d14a40cb9ee --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestOpaConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(OpaConfig.class) + .setOpaUri(null) + .setOpaBatchUri(null) + .setLogRequests(false) + .setLogResponses(false)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = ImmutableMap.builder() + .put("opa.policy.uri", "https://opa.example.com") + .put("opa.policy.batched-uri", "https://opa-batch.example.com") + .put("opa.log-requests", "true") + .put("opa.log-responses", "true") + .buildOrThrow(); + + OpaConfig expected = new OpaConfig() + .setOpaUri(URI.create("https://opa.example.com")) + .setOpaBatchUri(URI.create("https://opa-batch.example.com")) + .setLogRequests(true) + .setLogResponses(true); + + assertFullMapping(properties, expected); + } +} diff --git a/pom.xml b/pom.xml index 2284eccc29681..f5c4fc8dc8e9f 100644 --- a/pom.xml +++ b/pom.xml @@ -92,6 +92,7 @@ plugin/trino-mongodb plugin/trino-mysql plugin/trino-mysql-event-listener + plugin/trino-opa plugin/trino-oracle plugin/trino-password-authenticators plugin/trino-phoenix5 From 389f4db98dc82a3f47291a4e7473cb9160c37413 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Fri, 3 Nov 2023 12:40:24 +0000 Subject: [PATCH 02/11] Code review: - Simplify tests and remove some parameterization - Make more fields final - Remove unnecessary annotations - Make tests more parallelizable - Remove permissioning operations and replace them by a blanket allow/deny setting And rebase - bump pom version --- plugin/trino-opa/README.md | 3 + plugin/trino-opa/pom.xml | 7 +- .../io/trino/plugin/opa/OpaAccessControl.java | 227 +---- .../plugin/opa/OpaBatchAccessControl.java | 24 +- .../java/io/trino/plugin/opa/OpaConfig.java | 14 + .../trino/plugin/opa/OpaHighLevelClient.java | 8 +- .../io/trino/plugin/opa/OpaHttpClient.java | 15 +- .../opa/schema/OpaQueryInputAction.java | 17 +- .../plugin/opa/schema/OpaQueryInputGrant.java | 76 -- .../opa/schema/OpaQueryInputResource.java | 33 +- .../plugin/opa/schema/PropertiesMapper.java | 32 - .../opa/schema/TrinoGrantPrincipal.java | 7 - .../plugin/opa/schema/TrinoIdentity.java | 9 +- .../plugin/opa/FilteringTestHelpers.java | 4 +- .../trino/plugin/opa/FunctionalHelpers.java | 38 +- .../io/trino/plugin/opa/HttpClientUtils.java | 48 +- .../plugin/opa/RequestTestUtilities.java | 60 +- .../java/io/trino/plugin/opa/TestHelpers.java | 52 +- ...nitTest.java => TestOpaAccessControl.java} | 836 +++++------------- ....java => TestOpaAccessControlFactory.java} | 28 +- ...ava => TestOpaAccessControlFiltering.java} | 185 ++-- ...aAccessControlPermissioningOperations.java | 147 +++ .../opa/TestOpaAccessControlPlugin.java | 32 + ...t.java => TestOpaAccessControlSystem.java} | 44 +- ...> TestOpaBatchAccessControlFiltering.java} | 287 +++--- .../io/trino/plugin/opa/TestOpaConfig.java | 7 +- ...Test.java => TestOpaResponseDecoding.java} | 81 +- 27 files changed, 903 insertions(+), 1418 deletions(-) delete mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java delete mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java rename plugin/trino-opa/src/test/java/io/trino/plugin/opa/{OpaAccessControlUnitTest.java => TestOpaAccessControl.java} (54%) rename plugin/trino-opa/src/test/java/io/trino/plugin/opa/{TestFactory.java => TestOpaAccessControlFactory.java} (71%) rename plugin/trino-opa/src/test/java/io/trino/plugin/opa/{OpaAccessControlFilteringUnitTest.java => TestOpaAccessControlFiltering.java} (62%) create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPlugin.java rename plugin/trino-opa/src/test/java/io/trino/plugin/opa/{OpaAccessControlSystemTest.java => TestOpaAccessControlSystem.java} (86%) rename plugin/trino-opa/src/test/java/io/trino/plugin/opa/{OpaBatchAccessControlFilteringUnitTest.java => TestOpaBatchAccessControlFiltering.java} (57%) rename plugin/trino-opa/src/test/java/io/trino/plugin/opa/{ResponseTest.java => TestOpaResponseDecoding.java} (57%) diff --git a/plugin/trino-opa/README.md b/plugin/trino-opa/README.md index 25878d80d60b8..f4724973c7f88 100644 --- a/plugin/trino-opa/README.md +++ b/plugin/trino-opa/README.md @@ -4,6 +4,9 @@ This plugin enables Trino to use Open Policy Agent (OPA) as an authorization eng For more information on OPA, please refer to the Open Policy Agent [documentation](https://www.openpolicyagent.org/). +> While every attempt will be made to keep backwards compatibility, this plugin is a recent addition +> and as such the API may change. + ## Configuration You will need to configure Trino to use the OPA plugin as its access control engine, then configure the diff --git a/plugin/trino-opa/pom.xml b/plugin/trino-opa/pom.xml index 3964a1dd30da9..5b9f3fa36a180 100644 --- a/plugin/trino-opa/pom.xml +++ b/plugin/trino-opa/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 431-SNAPSHOT + 434-SNAPSHOT ../../pom.xml @@ -126,6 +126,11 @@ trino-testing test + + org.assertj + assertj-core + test + org.junit.jupiter diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java index 0d8d9bd881e57..ca19d45981dc1 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java @@ -20,7 +20,6 @@ import io.trino.plugin.opa.schema.OpaQueryContext; import io.trino.plugin.opa.schema.OpaQueryInput; import io.trino.plugin.opa.schema.OpaQueryInputAction; -import io.trino.plugin.opa.schema.OpaQueryInputGrant; import io.trino.plugin.opa.schema.OpaQueryInputResource; import io.trino.plugin.opa.schema.TrinoCatalogSessionProperty; import io.trino.plugin.opa.schema.TrinoFunction; @@ -48,34 +47,23 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; -import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.plugin.opa.OpaHighLevelClient.buildQueryInputForSimpleResource; -import static io.trino.plugin.opa.schema.PropertiesMapper.convertProperties; import static io.trino.spi.security.AccessDeniedException.denyCreateCatalog; import static io.trino.spi.security.AccessDeniedException.denyCreateFunction; -import static io.trino.spi.security.AccessDeniedException.denyCreateRole; import static io.trino.spi.security.AccessDeniedException.denyCreateSchema; import static io.trino.spi.security.AccessDeniedException.denyCreateViewWithSelect; -import static io.trino.spi.security.AccessDeniedException.denyDenySchemaPrivilege; -import static io.trino.spi.security.AccessDeniedException.denyDenyTablePrivilege; import static io.trino.spi.security.AccessDeniedException.denyDropCatalog; import static io.trino.spi.security.AccessDeniedException.denyDropFunction; -import static io.trino.spi.security.AccessDeniedException.denyDropRole; import static io.trino.spi.security.AccessDeniedException.denyDropSchema; import static io.trino.spi.security.AccessDeniedException.denyExecuteProcedure; import static io.trino.spi.security.AccessDeniedException.denyExecuteTableProcedure; -import static io.trino.spi.security.AccessDeniedException.denyGrantRoles; -import static io.trino.spi.security.AccessDeniedException.denyGrantSchemaPrivilege; -import static io.trino.spi.security.AccessDeniedException.denyGrantTablePrivilege; import static io.trino.spi.security.AccessDeniedException.denyImpersonateUser; import static io.trino.spi.security.AccessDeniedException.denyRenameMaterializedView; import static io.trino.spi.security.AccessDeniedException.denyRenameSchema; import static io.trino.spi.security.AccessDeniedException.denyRenameTable; import static io.trino.spi.security.AccessDeniedException.denyRenameView; -import static io.trino.spi.security.AccessDeniedException.denyRevokeRoles; -import static io.trino.spi.security.AccessDeniedException.denyRevokeSchemaPrivilege; -import static io.trino.spi.security.AccessDeniedException.denyRevokeTablePrivilege; import static io.trino.spi.security.AccessDeniedException.denySetCatalogSessionProperty; import static io.trino.spi.security.AccessDeniedException.denySetSchemaAuthorization; import static io.trino.spi.security.AccessDeniedException.denySetSystemSessionProperty; @@ -90,11 +78,13 @@ public sealed class OpaAccessControl permits OpaBatchAccessControl { private final OpaHighLevelClient opaHighLevelClient; + private final boolean allowPermissioningOperations; @Inject - public OpaAccessControl(OpaHighLevelClient opaHighLevelClient) + public OpaAccessControl(OpaHighLevelClient opaHighLevelClient, OpaConfig config) { this.opaHighLevelClient = opaHighLevelClient; + this.allowPermissioningOperations = config.getAllowPermissioningOperations(); } @Override @@ -239,11 +229,10 @@ public void checkCanRenameSchema(SystemSecurityContext context, CatalogSchemaNam public void checkCanSetSchemaAuthorization(SystemSecurityContext context, CatalogSchemaName schema, TrinoPrincipal principal) { OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); - OpaQueryInputGrant grantee = OpaQueryInputGrant.builder().principal(TrinoGrantPrincipal.fromTrinoPrincipal(principal)).build(); OpaQueryInputAction action = OpaQueryInputAction.builder() .operation("SetSchemaAuthorization") .resource(resource) - .grantee(grantee) + .grantee(TrinoGrantPrincipal.fromTrinoPrincipal(principal)) .build(); OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); @@ -413,11 +402,10 @@ public void checkCanDropColumn(SystemSecurityContext context, CatalogSchemaTable public void checkCanSetTableAuthorization(SystemSecurityContext context, CatalogSchemaTableName table, TrinoPrincipal principal) { OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); - OpaQueryInputGrant grantee = OpaQueryInputGrant.builder().principal(TrinoGrantPrincipal.fromTrinoPrincipal(principal)).build(); OpaQueryInputAction action = OpaQueryInputAction.builder() .operation("SetTableAuthorization") .resource(resource) - .grantee(grantee) + .grantee(TrinoGrantPrincipal.fromTrinoPrincipal(principal)) .build(); OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); @@ -484,11 +472,10 @@ public void checkCanRenameView(SystemSecurityContext context, CatalogSchemaTable public void checkCanSetViewAuthorization(SystemSecurityContext context, CatalogSchemaTableName view, TrinoPrincipal principal) { OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(view)).build(); - OpaQueryInputGrant grantee = OpaQueryInputGrant.builder().principal(TrinoGrantPrincipal.fromTrinoPrincipal(principal)).build(); OpaQueryInputAction action = OpaQueryInputAction.builder() .operation("SetViewAuthorization") .resource(resource) - .grantee(grantee) + .grantee(TrinoGrantPrincipal.fromTrinoPrincipal(principal)) .build(); OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); @@ -558,224 +545,79 @@ public void checkCanSetCatalogSessionProperty(SystemSecurityContext context, Str @Override public void checkCanGrantSchemaPrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaName schema, TrinoPrincipal grantee, boolean grantOption) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); - OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() - .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) - .grantOption(grantOption) - .privilege(privilege) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("GrantSchemaPrivilege") - .resource(resource) - .grantee(opaGrantee) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyGrantSchemaPrivilege(privilege.toString(), schema.toString()); - } + enforcePermissioningOperation(AccessDeniedException::denyGrantSchemaPrivilege, privilege.toString(), schema.toString()); } @Override public void checkCanDenySchemaPrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaName schema, TrinoPrincipal grantee) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); - OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() - .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) - .privilege(privilege) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("DenySchemaPrivilege") - .resource(resource) - .grantee(opaGrantee) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyDenySchemaPrivilege(privilege.toString(), schema.toString()); - } + enforcePermissioningOperation(AccessDeniedException::denyDenySchemaPrivilege, privilege.toString(), schema.toString()); } @Override public void checkCanRevokeSchemaPrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaName schema, TrinoPrincipal revokee, boolean grantOption) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); - OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() - .principal(TrinoGrantPrincipal.fromTrinoPrincipal(revokee)) - .grantOption(grantOption) - .privilege(privilege) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("RevokeSchemaPrivilege") - .resource(resource) - .grantee(opaGrantee) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyRevokeSchemaPrivilege(privilege.toString(), schema.toString()); - } + enforcePermissioningOperation(AccessDeniedException::denyRevokeSchemaPrivilege, privilege.toString(), schema.toString()); } @Override public void checkCanGrantTablePrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaTableName table, TrinoPrincipal grantee, boolean grantOption) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); - OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() - .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) - .grantOption(grantOption) - .privilege(privilege) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("GrantTablePrivilege") - .resource(resource) - .grantee(opaGrantee) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyGrantTablePrivilege(privilege.toString(), table.toString()); - } + enforcePermissioningOperation(AccessDeniedException::denyGrantTablePrivilege, privilege.toString(), table.toString()); } @Override public void checkCanDenyTablePrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaTableName table, TrinoPrincipal grantee) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); - OpaQueryInputGrant opaGrantee = OpaQueryInputGrant.builder() - .principal(TrinoGrantPrincipal.fromTrinoPrincipal(grantee)) - .privilege(privilege) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("DenyTablePrivilege") - .resource(resource) - .grantee(opaGrantee) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyDenyTablePrivilege(privilege.toString(), table.toString()); - } + enforcePermissioningOperation(AccessDeniedException::denyDenyTablePrivilege, privilege.toString(), table.toString()); } @Override public void checkCanRevokeTablePrivilege(SystemSecurityContext context, Privilege privilege, CatalogSchemaTableName table, TrinoPrincipal revokee, boolean grantOption) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); - OpaQueryInputGrant opaRevokee = OpaQueryInputGrant.builder() - .principal(TrinoGrantPrincipal.fromTrinoPrincipal(revokee)) - .privilege(privilege) - .grantOption(grantOption) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("RevokeTablePrivilege") - .resource(resource) - .grantee(opaRevokee) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyRevokeTablePrivilege(privilege.toString(), table.toString()); - } - } - - @Override - public void checkCanShowRoles(SystemSecurityContext context) - { - opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), - "ShowRoles", - AccessDeniedException::denyShowRoles); + enforcePermissioningOperation(AccessDeniedException::denyRevokeTablePrivilege, privilege.toString(), table.toString()); } @Override public void checkCanCreateRole(SystemSecurityContext context, String role, Optional grantor) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().role(role).build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("CreateRole") - .resource(resource) - .grantor(TrinoGrantPrincipal.fromTrinoPrincipal(grantor)) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyCreateRole(role); - } + enforcePermissioningOperation(AccessDeniedException::denyCreateRole, role); } @Override public void checkCanDropRole(SystemSecurityContext context, String role) { - opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), - "DropRole", - () -> denyDropRole(role), - OpaQueryInputResource.builder().role(role).build()); + enforcePermissioningOperation(AccessDeniedException::denyDropRole, role); } @Override public void checkCanGrantRoles(SystemSecurityContext context, Set roles, Set grantees, boolean adminOption, Optional grantor) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().roles(roles).build(); - OpaQueryInputGrant opaGrantees = OpaQueryInputGrant.builder() - .grantOption(adminOption) - .principals(grantees.stream() - .map(TrinoGrantPrincipal::fromTrinoPrincipal) - .collect(toImmutableSet())) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("GrantRoles") - .resource(resource) - .grantee(opaGrantees) - .grantor(TrinoGrantPrincipal.fromTrinoPrincipal(grantor)) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); - - if (!opaHighLevelClient.queryOpa(input)) { - denyGrantRoles(roles, grantees); - } + enforcePermissioningOperation(AccessDeniedException::denyGrantRoles, roles, grantees); } @Override public void checkCanRevokeRoles(SystemSecurityContext context, Set roles, Set grantees, boolean adminOption, Optional grantor) { - OpaQueryInputResource resource = OpaQueryInputResource.builder().roles(roles).build(); - OpaQueryInputGrant opaGrantees = OpaQueryInputGrant.builder() - .grantOption(adminOption) - .principals(grantees.stream() - .map(TrinoGrantPrincipal::fromTrinoPrincipal) - .collect(toImmutableSet())) - .build(); - OpaQueryInputAction action = OpaQueryInputAction.builder() - .operation("RevokeRoles") - .resource(resource) - .grantee(opaGrantees) - .grantor(TrinoGrantPrincipal.fromTrinoPrincipal(grantor)) - .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + enforcePermissioningOperation(AccessDeniedException::denyRevokeRoles, roles, grantees); + } - if (!opaHighLevelClient.queryOpa(input)) { - denyRevokeRoles(roles, grantees); - } + @Override + public void checkCanShowRoles(SystemSecurityContext context) + { + // We always want to allow users to query their current roles, since OPA does not deal with role information } @Override public void checkCanShowCurrentRoles(SystemSecurityContext context) { - opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), - "ShowCurrentRoles", - AccessDeniedException::denyShowCurrentRoles); + // We always want to allow users to query their current roles, since OPA does not deal with role information } @Override public void checkCanShowRoleGrants(SystemSecurityContext context) { - opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), - "ShowRoleGrants", - AccessDeniedException::denyShowRoleGrants); + // We always want to allow users to query their current roles, since OPA does not deal with role information } @Override @@ -888,4 +730,25 @@ private void checkTableAndColumnsOperation(SystemSecurityContext context, String () -> deny.accept(table.toString(), columns), OpaQueryInputResource.builder().table(new TrinoTable(table).withColumns(columns)).build()); } + + private void enforcePermissioningOperation(Consumer deny, T arg) + { + if (!allowPermissioningOperations) { + deny.accept(arg); + } + } + + private void enforcePermissioningOperation(BiConsumer deny, T arg1, U arg2) + { + if (!allowPermissioningOperations) { + deny.accept(arg1, arg2); + } + } + + private static Map> convertProperties(Map properties) + { + return properties.entrySet().stream() + .map(propertiesEntry -> Map.entry(propertiesEntry.getKey(), Optional.ofNullable(propertiesEntry.getValue()))) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java index 266068cbb8a2d..1ac684fd38d9e 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java @@ -55,7 +55,7 @@ public OpaBatchAccessControl( OpaHttpClient opaHttpClient, OpaConfig config) { - super(opaHighLevelClient); + super(opaHighLevelClient, config); this.opaBatchedPolicyUri = config.getOpaBatchUri().orElseThrow(); this.batchResultCodec = batchResultCodec; this.opaHttpClient = opaHttpClient; @@ -129,7 +129,16 @@ public Set filterFunctions(SystemSecurityContext context, St .build()); } - private Function, OpaQueryInput> batchRequestBuilder(OpaQueryContext context, String operation, Function resourceMapper) + private Set batchFilterFromOpa(OpaQueryContext context, String operation, Collection items, Function converter) + { + return opaHttpClient.batchFilterFromOpa( + items, + batchRequestBuilder(context, operation, converter), + opaBatchedPolicyUri, + batchResultCodec); + } + + private static Function, OpaQueryInput> batchRequestBuilder(OpaQueryContext context, String operation, Function resourceMapper) { return items -> new OpaQueryInput( context, @@ -139,7 +148,7 @@ private Function, OpaQueryInput> batchRequestBuilder(OpaQueryContext .build()); } - private BiFunction, OpaQueryInput> batchRequestBuilder(OpaQueryContext context, String operation, BiFunction, OpaQueryInputResource> resourceMapper) + private static BiFunction, OpaQueryInput> batchRequestBuilder(OpaQueryContext context, String operation, BiFunction, OpaQueryInputResource> resourceMapper) { return (resourcesKey, resourcesList) -> new OpaQueryInput( context, @@ -148,13 +157,4 @@ private BiFunction, OpaQueryInput> batchRequestBuilder(OpaQuer .filterResources(ImmutableList.of(resourceMapper.apply(resourcesKey, resourcesList))) .build()); } - - private Set batchFilterFromOpa(OpaQueryContext context, String operation, Collection items, Function converter) - { - return opaHttpClient.batchFilterFromOpa( - items, - batchRequestBuilder(context, operation, converter), - opaBatchedPolicyUri, - batchResultCodec); - } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java index 6169cd2748e38..675e02b8c7dcc 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java @@ -27,6 +27,7 @@ public class OpaConfig private Optional opaBatchUri = Optional.empty(); private boolean logRequests; private boolean logResponses; + private boolean allowPermissioningOperations; @NotNull public URI getOpaUri() @@ -80,4 +81,17 @@ public OpaConfig setLogResponses(boolean logResponses) this.logResponses = logResponses; return this; } + + public boolean getAllowPermissioningOperations() + { + return this.allowPermissioningOperations; + } + + @Config("opa.allow-permissioning-operations") + @ConfigDescription("Whether to allow permissioning operations (GRANT, DENY, ...) as well as role management - OPA will not be queried for any such operations, they will be bulk allowed or denied depending on this setting") + public OpaConfig setAllowPermissioningOperations(boolean allowPermissioningOperations) + { + this.allowPermissioningOperations = allowPermissioningOperations; + return this; + } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java index 5270c1b7f02da..81a99ddde5dc4 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java @@ -105,13 +105,13 @@ public Set parallelFilterFromOpa( return opaHttpClient.parallelFilterFromOpa(items, requestBuilder, opaPolicyUri, queryResultCodec); } - public static OpaQueryInput buildQueryInputForSimpleAction(OpaQueryContext context, String operation) + public static OpaQueryInput buildQueryInputForSimpleResource(OpaQueryContext context, String operation, OpaQueryInputResource resource) { - return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).build()); + return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).resource(resource).build()); } - public static OpaQueryInput buildQueryInputForSimpleResource(OpaQueryContext context, String operation, OpaQueryInputResource resource) + private static OpaQueryInput buildQueryInputForSimpleAction(OpaQueryContext context, String operation) { - return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).resource(resource).build()); + return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).build()); } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java index 5cdc40f2e1c74..d207f5d54e70d 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java @@ -97,7 +97,7 @@ public FluentFuture submitOpaRequest(OpaQueryInput input, URI uri, JsonCo log.debug( "Sending OPA request to URI \"%s\" ; request body = %s ; request headers = %s", uri.toString(), - tryConvertBytesToString(requestBodyGenerator.getBody()), + new String(requestBodyGenerator.getBody(), UTF_8), request.getHeaders()); } return FluentFuture.from(httpClient.executeAsync(request, createFullJsonResponseHandler(deserializer))) @@ -203,20 +203,9 @@ private T parseOpaResponse(FullJsonResponseHandler.JsonResponse response, "OPA response for URI \"%s\" received: status code = %d ; response payload = %s ; response headers = %s", uriString, statusCode, - tryConvertBytesToString(response.getJsonBytes()), + new String(response.getJsonBytes(), UTF_8), response.getHeaders()); } return response.getValue(); } - - private static String tryConvertBytesToString(byte[] bytes) - { - try { - return new String(bytes, UTF_8); - } - catch (Exception e) { - log.error(e, "Failed to convert JSON bytes to string for logging"); - return ""; - } - } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java index 67f47e5077af4..f398529385186 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputAction.java @@ -29,8 +29,7 @@ public record OpaQueryInputAction( OpaQueryInputResource resource, List filterResources, OpaQueryInputResource targetResource, - OpaQueryInputGrant grantee, - TrinoGrantPrincipal grantor) + TrinoGrantPrincipal grantee) { public OpaQueryInputAction { @@ -54,8 +53,7 @@ public static class Builder private OpaQueryInputResource resource; private List filterResources; private OpaQueryInputResource targetResource; - private OpaQueryInputGrant grantee; - private TrinoGrantPrincipal grantor; + private TrinoGrantPrincipal grantee; private Builder() {} @@ -83,18 +81,12 @@ public Builder targetResource(OpaQueryInputResource targetResource) return this; } - public Builder grantee(OpaQueryInputGrant grantee) + public Builder grantee(TrinoGrantPrincipal grantee) { this.grantee = grantee; return this; } - public Builder grantor(TrinoGrantPrincipal grantor) - { - this.grantor = grantor; - return this; - } - public OpaQueryInputAction build() { return new OpaQueryInputAction( @@ -102,8 +94,7 @@ public OpaQueryInputAction build() this.resource, this.filterResources, this.targetResource, - this.grantee, - this.grantor); + this.grantee); } } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java deleted file mode 100644 index 93650035bbb34..0000000000000 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputGrant.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.opa.schema; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.google.common.collect.ImmutableSet; -import io.trino.spi.security.Privilege; -import jakarta.validation.constraints.NotNull; - -import java.util.Set; - -import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; -import static java.util.Objects.requireNonNull; - -@JsonInclude(NON_NULL) -public record OpaQueryInputGrant(@NotNull Set principals, Boolean grantOption, String privilege) -{ - public OpaQueryInputGrant - { - principals = ImmutableSet.copyOf(requireNonNull(principals, "principals is null")); - } - - public static Builder builder() - { - return new Builder(); - } - - public static class Builder - { - private Set principals; - private Boolean grantOption; - private String privilege; - - private Builder() {} - - public Builder principal(TrinoGrantPrincipal principal) - { - this.principals = ImmutableSet.of(principal); - return this; - } - - public Builder principals(Set principals) - { - this.principals = principals; - return this; - } - - public Builder grantOption(boolean grantOption) - { - this.grantOption = grantOption; - return this; - } - - public Builder privilege(Privilege privilege) - { - this.privilege = privilege.name(); - return this; - } - - public OpaQueryInputGrant build() - { - return new OpaQueryInputGrant(this.principals, this.grantOption, this.privilege); - } - } -} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java index 67cdb010edc15..61820030882ef 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java @@ -14,13 +14,9 @@ package io.trino.plugin.opa.schema; import com.fasterxml.jackson.annotation.JsonInclude; -import com.google.common.collect.ImmutableSet; import jakarta.validation.constraints.NotNull; -import java.util.Set; - import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Objects.requireNonNull; @JsonInclude(NON_NULL) @@ -31,17 +27,8 @@ public record OpaQueryInputResource( TrinoFunction function, NamedEntity catalog, TrinoSchema schema, - TrinoTable table, - NamedEntity role, - Set roles) + TrinoTable table) { - public OpaQueryInputResource - { - if (roles != null) { - roles = ImmutableSet.copyOf(roles); - } - } - public record NamedEntity(@NotNull String name) { public NamedEntity @@ -63,8 +50,6 @@ public static class Builder private NamedEntity catalog; private TrinoSchema schema; private TrinoTable table; - private NamedEntity role; - private Set roles; private TrinoFunction function; private Builder() {} @@ -105,18 +90,6 @@ public Builder table(TrinoTable table) return this; } - public Builder role(String role) - { - this.role = new NamedEntity(role); - return this; - } - - public Builder roles(Set roles) - { - this.roles = roles.stream().map(NamedEntity::new).collect(toImmutableSet()); - return this; - } - public Builder function(TrinoFunction function) { this.function = function; @@ -138,9 +111,7 @@ public OpaQueryInputResource build() this.function, this.catalog, this.schema, - this.table, - this.role, - this.roles); + this.table); } } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java deleted file mode 100644 index 7d38279e97a92..0000000000000 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/PropertiesMapper.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.plugin.opa.schema; - -import java.util.Map; -import java.util.Optional; - -import static com.google.common.collect.ImmutableMap.toImmutableMap; - -public class PropertiesMapper -{ - private PropertiesMapper() - {} - - public static Map> convertProperties(Map properties) - { - return properties.entrySet().stream() - .map(propertiesEntry -> Map.entry(propertiesEntry.getKey(), Optional.ofNullable(propertiesEntry.getValue()))) - .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); - } -} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java index 702c2b0311c81..3532e2450c57d 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoGrantPrincipal.java @@ -17,8 +17,6 @@ import io.trino.spi.security.TrinoPrincipal; import jakarta.validation.constraints.NotNull; -import java.util.Optional; - import static com.fasterxml.jackson.annotation.JsonInclude.Include.NON_NULL; import static java.util.Objects.requireNonNull; @@ -30,11 +28,6 @@ public static TrinoGrantPrincipal fromTrinoPrincipal(TrinoPrincipal principal) return new TrinoGrantPrincipal(principal.getType().name(), principal.getName()); } - public static TrinoGrantPrincipal fromTrinoPrincipal(Optional principal) - { - return principal.map(TrinoGrantPrincipal::fromTrinoPrincipal).orElse(null); - } - public TrinoGrantPrincipal { requireNonNull(type, "type is null"); diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java index da104199f7e2c..c1e5500b57bea 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoIdentity.java @@ -13,33 +13,28 @@ */ package io.trino.plugin.opa.schema; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.spi.security.Identity; import jakarta.validation.constraints.NotNull; -import java.util.Map; import java.util.Set; import static java.util.Objects.requireNonNull; public record TrinoIdentity( @NotNull String user, - @NotNull Set groups, - @NotNull Map extraCredentials) + @NotNull Set groups) { public static TrinoIdentity fromTrinoIdentity(Identity identity) { return new TrinoIdentity( identity.getUser(), - identity.getGroups(), - identity.getExtraCredentials()); + identity.getGroups()); } public TrinoIdentity { requireNonNull(user, "user is null"); groups = ImmutableSet.copyOf(requireNonNull(groups, "groups is null")); - extraCredentials = ImmutableMap.copyOf(requireNonNull(extraCredentials, "extraCredentials is null")); } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java index 38ca687c02927..75c9b628aa5e7 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java @@ -35,8 +35,8 @@ private FilteringTestHelpers() {} public static Stream emptyInputTestCases() { - Stream> callables = Stream.of( - (authorizer, context) -> authorizer.filterViewQueryOwnedBy(context.getIdentity(), ImmutableSet.of()), + Stream>> callables = Stream.of( + (authorizer, context) -> authorizer.filterViewQueryOwnedBy(context.getIdentity(), ImmutableSet.of()), (authorizer, context) -> authorizer.filterCatalogs(context, ImmutableSet.of()), (authorizer, context) -> authorizer.filterSchemas(context, "my_catalog", ImmutableSet.of()), (authorizer, context) -> authorizer.filterTables(context, "my_catalog", ImmutableSet.of()), diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java index 9a4cabccd79b8..e3fbdc2c78a30 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java @@ -15,60 +15,36 @@ public class FunctionalHelpers { - @FunctionalInterface - public static interface Consumer3 + public interface Consumer3 { void accept(T1 t1, T2 t2, T3 t3); } - @FunctionalInterface - public static interface Function3 + public interface Function3 { R apply(T1 t1, T2 t2, T3 t3); } - @FunctionalInterface - public static interface Consumer4 + public interface Consumer4 { void accept(T1 t1, T2 t2, T3 t3, T4 t4); } - @FunctionalInterface - public static interface Consumer5 + public interface Consumer5 { void accept(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5); } - @FunctionalInterface - public static interface Consumer6 + public interface Consumer6 { void accept(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6); } - public static class Pair + public record Pair(T first, U second) { - private T first; - private U second; - - public T getFirst() - { - return this.first; - } - - public U getSecond() - { - return this.second; - } - - public Pair(T first, U second) - { - this.first = first; - this.second = second; - } - public static Pair of(T first, U second) { - return new Pair(first, second); + return new Pair<>(first, second); } } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java index c8e3aec7b9dc3..980ce862c735e 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java @@ -13,6 +13,8 @@ */ package io.trino.plugin.opa; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.json.JsonMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import io.airlift.http.client.HttpStatus; @@ -22,6 +24,7 @@ import io.airlift.http.client.testing.TestingHttpClient; import io.airlift.http.client.testing.TestingResponse; +import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.LinkedList; @@ -39,13 +42,14 @@ private HttpClientUtils() {} public static class RecordingHttpProcessor implements TestingHttpClient.Processor { - private final List requests = new LinkedList<>(); - private Function handler; + private static final JsonMapper jsonMapper = new JsonMapper(); + private final List requests = new LinkedList<>(); + private final Function handler; private final URI expectedURI; private final String expectedMethod; private final String expectedContentType; - public RecordingHttpProcessor(URI expectedURI, String expectedMethod, String expectedContentType, Function handler) + public RecordingHttpProcessor(URI expectedURI, String expectedMethod, String expectedContentType, Function handler) { this.expectedMethod = requireNonNull(expectedMethod, "expectedMethod is null"); this.expectedContentType = requireNonNull(expectedContentType, "expectedContentType is null"); @@ -54,7 +58,7 @@ public RecordingHttpProcessor(URI expectedURI, String expectedMethod, String exp } @Override - public Response handle(Request request) + public synchronized Response handle(Request request) { if (!requireNonNull(request.getMethod()).equalsIgnoreCase(expectedMethod)) { throw new IllegalArgumentException("Unexpected method: %s".formatted(request.getMethod())); @@ -67,10 +71,14 @@ public Response handle(Request request) throw new IllegalArgumentException("Unexpected URI: %s".formatted(request.getUri().toString())); } if (requireNonNull(request.getBodyGenerator()) instanceof StaticBodyGenerator bodyGenerator) { - synchronized (this.requests) { - String requestContents = new String(bodyGenerator.getBody(), StandardCharsets.UTF_8); - requests.add(requestContents); - return handler.apply(requestContents).buildResponse(); + String requestContents = new String(bodyGenerator.getBody(), StandardCharsets.UTF_8); + try { + JsonNode parsedRequest = jsonMapper.readTree(requestContents); + requests.add(parsedRequest); + return handler.apply(parsedRequest).buildResponse(); + } + catch (IOException e) { + throw new IllegalArgumentException("Request has illegal JSON", e); } } else { @@ -78,16 +86,9 @@ public Response handle(Request request) } } - public List getRequests() - { - synchronized (this.requests) { - return ImmutableList.copyOf(this.requests); - } - } - - public void setHandler(Function handler) + public synchronized List getRequests() { - this.handler = handler; + return ImmutableList.copyOf(requests); } } @@ -96,7 +97,7 @@ public static class InstrumentedHttpClient { private final RecordingHttpProcessor httpProcessor; - public InstrumentedHttpClient(URI expectedURI, String expectedMethod, String expectedContentType, Function handler) + public InstrumentedHttpClient(URI expectedURI, String expectedMethod, String expectedContentType, Function handler) { this(new RecordingHttpProcessor(expectedURI, expectedMethod, expectedContentType, handler)); } @@ -107,14 +108,9 @@ public InstrumentedHttpClient(RecordingHttpProcessor processor) this.httpProcessor = processor; } - public void setHandler(Function handler) + public List getRequests() { - this.httpProcessor.setHandler(handler); - } - - public List getRequests() - { - return this.httpProcessor.getRequests(); + return httpProcessor.getRequests(); } } @@ -127,5 +123,5 @@ public TestingResponse buildResponse() ImmutableListMultimap.of(CONTENT_TYPE, JSON_UTF_8.toString()), this.contents.getBytes(StandardCharsets.UTF_8)); } - }; + } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java index 97c7dff942dbf..73bd79f09a356 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java @@ -16,14 +16,16 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.json.JsonMapper; import com.google.common.collect.ImmutableSet; +import io.trino.plugin.opa.HttpClientUtils.MockResponse; +import io.trino.spi.security.Identity; import java.io.IOException; import java.util.Collection; import java.util.Set; +import java.util.function.Function; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; +import static org.assertj.core.api.Assertions.assertThat; public class RequestTestUtilities { @@ -31,8 +33,7 @@ private RequestTestUtilities() {} private static final JsonMapper jsonMapper = new JsonMapper(); - public static void assertStringRequestsEqual( - Collection expectedRequests, Collection actualRequests, String extractPath) + public static void assertStringRequestsEqual(Set expectedRequests, Collection actualRequests, String extractPath) { Set parsedExpectedRequests = expectedRequests.stream() .map(expectedRequest -> { @@ -40,37 +41,36 @@ public static void assertStringRequestsEqual( return jsonMapper.readTree(expectedRequest); } catch (IOException e) { - fail("Could not parse request", e); - return null; + throw new AssertionError("Cannot parse expected request", e); } }) .collect(toImmutableSet()); - assertJsonRequestsEqual(parsedExpectedRequests, actualRequests, extractPath); + Set extractedActualRequests = actualRequests.stream().map(node -> node.at(extractPath)).collect(toImmutableSet()); + assertThat(extractedActualRequests).containsExactlyInAnyOrderElementsOf(parsedExpectedRequests); } - public static void assertJsonRequestsEqual( - Collection expectedRequests, Collection actualRequests, String extractPath) + public static Function buildValidatingRequestHandler(Identity expectedUser, int statusCode, String responseContents) { - Set parsedActualRequests = actualRequests.stream() - .map(actualRequest -> { - try { - JsonNode parsed = jsonMapper.readTree(actualRequest); - if (extractPath != null) { - return parsed.at(extractPath); - } - return parsed; - } - catch (IOException e) { - fail("Could not parse request", e); - return null; - } - }) - .collect(toImmutableSet()); - Set expectedRequestSet = ImmutableSet.copyOf(expectedRequests); - assertEquals( - expectedRequestSet.size(), - parsedActualRequests.size(), - "Mismatch in expected vs. actual request count"); - assertEquals(expectedRequestSet, parsedActualRequests, "Requests do not match"); + return buildValidatingRequestHandler(expectedUser, new MockResponse(responseContents, statusCode)); + } + + public static Function buildValidatingRequestHandler(Identity expectedUser, MockResponse response) + { + return buildValidatingRequestHandler(expectedUser, jsonNode -> response); + } + + public static Function buildValidatingRequestHandler(Identity expectedUser, Function customHandler) + { + return parsedRequest -> { + if (!parsedRequest.at("/input/context/identity/user").asText().equals(expectedUser.getUser())) { + throw new AssertionError("Request had invalid user in the identity block"); + } + ImmutableSet.Builder groupsInRequestBuilder = ImmutableSet.builder(); + parsedRequest.at("/input/context/identity/groups").iterator().forEachRemaining(node -> groupsInRequestBuilder.add(node.asText())); + if (!groupsInRequestBuilder.build().equals(expectedUser.getGroups())) { + throw new AssertionError("Request had invalid set of groups in the identity block"); + } + return customHandler.apply(parsedRequest); + }; } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java index 7e67030834d64..7061498fe826c 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -13,46 +13,53 @@ */ package io.trino.plugin.opa; +import com.fasterxml.jackson.databind.JsonNode; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import io.trino.execution.QueryIdGenerator; +import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; +import io.trino.plugin.opa.HttpClientUtils.MockResponse; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.Identity; import io.trino.spi.security.SystemSecurityContext; import org.junit.jupiter.api.Named; import org.junit.jupiter.params.provider.Arguments; +import java.net.URI; import java.time.Instant; import java.util.Arrays; -import java.util.function.BiConsumer; +import java.util.Optional; +import java.util.function.Function; import java.util.stream.Stream; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.net.MediaType.JSON_UTF_8; public class TestHelpers { private TestHelpers() {} - public static final HttpClientUtils.MockResponse OK_RESPONSE = new HttpClientUtils.MockResponse(""" + public static final MockResponse OK_RESPONSE = new MockResponse(""" { "decision_id": "", "result": true } """, 200); - public static final HttpClientUtils.MockResponse NO_ACCESS_RESPONSE = new HttpClientUtils.MockResponse(""" + public static final MockResponse NO_ACCESS_RESPONSE = new MockResponse(""" { "decision_id": "", "result": false } """, 200); - public static final HttpClientUtils.MockResponse MALFORMED_RESPONSE = new HttpClientUtils.MockResponse(""" + public static final MockResponse MALFORMED_RESPONSE = new MockResponse(""" { "this"": is broken_json; } """, 200); - public static final HttpClientUtils.MockResponse UNDEFINED_RESPONSE = new HttpClientUtils.MockResponse("{}", 404); - public static final HttpClientUtils.MockResponse BAD_REQUEST_RESPONSE = new HttpClientUtils.MockResponse("{}", 400); - public static final HttpClientUtils.MockResponse SERVER_ERROR_RESPONSE = new HttpClientUtils.MockResponse("", 500); + public static final MockResponse UNDEFINED_RESPONSE = new MockResponse("{}", 404); + public static final MockResponse BAD_REQUEST_RESPONSE = new MockResponse("{}", 400); + public static final MockResponse SERVER_ERROR_RESPONSE = new MockResponse("", 500); public static Stream createFailingTestCases(Stream baseTestCases) { @@ -94,17 +101,6 @@ public static SystemSecurityContext systemSecurityContextFromIdentity(Identity i return new SystemSecurityContext(identity, new QueryIdGenerator().createNextQueryId(), Instant.now()); } - public static BiConsumer convertSystemSecurityContextToIdentityArgument( - BiConsumer callable) - { - return (accessControl, systemSecurityContext) -> callable.accept(accessControl, systemSecurityContext.getIdentity()); - } - - public static FunctionalHelpers.Consumer3 convertSystemSecurityContextToIdentityArgument( - FunctionalHelpers.Consumer3 callable) { - return (accessControl, systemSecurityContext, argument) -> callable.accept(accessControl, systemSecurityContext.getIdentity(), argument); - } - public abstract static class MethodWrapper { public abstract boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument); } @@ -139,4 +135,24 @@ public boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurity return this.callable.apply(opaAccessControl, systemSecurityContext, argument); } } + + public static InstrumentedHttpClient createMockHttpClient(URI expectedUri, Function handler) + { + return new InstrumentedHttpClient(expectedUri, "POST", JSON_UTF_8.toString(), handler); + } + + public static OpaAccessControl createOpaAuthorizer(URI opaUri, InstrumentedHttpClient mockHttpClient) + { + return (OpaAccessControl) OpaAccessControlFactory.create(ImmutableMap.of("opa.policy.uri", opaUri.toString()), Optional.of(mockHttpClient)); + } + + public static OpaAccessControl createOpaAuthorizer(URI opaUri, URI opaBatchUri, InstrumentedHttpClient mockHttpClient) + { + return (OpaAccessControl) OpaAccessControlFactory.create( + ImmutableMap.builder() + .put("opa.policy.uri", opaUri.toString()) + .put("opa.policy.batched-uri", opaBatchUri.toString()) + .buildOrThrow(), + Optional.of(mockHttpClient)); + } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java similarity index 54% rename from plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java rename to plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java index 1cf2265f6e403..027f851bc3bb9 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlUnitTest.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java @@ -13,123 +13,97 @@ */ package io.trino.plugin.opa; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.json.JsonMapper; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; +import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; +import io.trino.plugin.opa.HttpClientUtils.MockResponse; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaRoutineName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.security.Identity; import io.trino.spi.security.PrincipalType; -import io.trino.spi.security.Privilege; import io.trino.spi.security.SystemSecurityContext; import io.trino.spi.security.TrinoPrincipal; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import java.io.IOException; import java.net.URI; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.BiConsumer; import java.util.stream.Stream; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.net.MediaType.JSON_UTF_8; -import static io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; -import static io.trino.plugin.opa.RequestTestUtilities.assertJsonRequestsEqual; import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; +import static io.trino.plugin.opa.RequestTestUtilities.buildValidatingRequestHandler; import static io.trino.plugin.opa.TestHelpers.NO_ACCESS_RESPONSE; import static io.trino.plugin.opa.TestHelpers.OK_RESPONSE; -import static io.trino.plugin.opa.TestHelpers.convertSystemSecurityContextToIdentityArgument; import static io.trino.plugin.opa.TestHelpers.createFailingTestCases; import static io.trino.plugin.opa.TestHelpers.createIllegalResponseTestCases; +import static io.trino.plugin.opa.TestHelpers.createMockHttpClient; +import static io.trino.plugin.opa.TestHelpers.createOpaAuthorizer; import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; -public class OpaAccessControlUnitTest +public class TestOpaAccessControl { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); - private InstrumentedHttpClient mockClient; - private OpaAccessControl authorizer; - private final JsonMapper jsonMapper = new JsonMapper(); - private Identity requestingIdentity; - private SystemSecurityContext requestingSecurityContext; - - @BeforeEach - public void setupAuthorizer() - { - this.mockClient = new InstrumentedHttpClient(OPA_SERVER_URI, "POST", JSON_UTF_8.toString(), request -> OK_RESPONSE); - this.authorizer = (OpaAccessControl) new OpaAccessControlFactory().create(ImmutableMap.of("opa.policy.uri", OPA_SERVER_URI.toString()), Optional.of(mockClient)); - this.requestingIdentity = Identity.ofUser("source-user"); - this.requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); - } - - @AfterEach - public void ensureRequestContextCorrect() - throws IOException - { - for (String request : mockClient.getRequests()) { - JsonNode parsedRequest = jsonMapper.readTree(request); - assertEquals(parsedRequest.at("/input/context/identity/user").asText(), requestingIdentity.getUser()); - } - } + private final Identity requestingIdentity = Identity.ofUser("source-user"); + private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @Test public void testResponseHasExtraFields() { - mockClient.setHandler(request -> new HttpClientUtils.MockResponse(""" + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, 200,""" { "result": true, "decision_id": "foo", "some_debug_info": {"test": ""} - } - """, - 200)); - authorizer.checkCanShowRoles(requestingSecurityContext); + }""")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + authorizer.checkCanExecuteQuery(requestingIdentity); } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#noResourceActionTestCases") - public void testNoResourceAction(String actionName, BiConsumer method) + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#noResourceActionTestCases") + public void testNoResourceAction(String actionName, BiConsumer method) { - method.accept(authorizer, requestingSecurityContext); - ObjectNode expectedRequest = jsonMapper.createObjectNode().put("operation", actionName); - assertJsonRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + + method.accept(authorizer, requestingIdentity); + String expectedRequest = """ + { + "operation": "%s" + }""".formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#noResourceActionFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#noResourceActionFailureTestCases") public void testNoResourceActionFailure( String actionName, - BiConsumer method, - HttpClientUtils.MockResponse failureResponse, + BiConsumer method, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); - - Throwable actualError = assertThrows( - expectedException, - () -> method.accept(authorizer, requestingSecurityContext)); - ObjectNode expectedRequest = jsonMapper.createObjectNode().put("operation", actionName); - assertJsonRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + + assertThatThrownBy(() -> method.accept(authorizer, requestingIdentity)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); + String expectedRequest = """ + { + "operation": "%s" + }""".formatted(actionName); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } private static Stream tableResourceTestCases() @@ -174,11 +148,14 @@ private static Stream tableResourceTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableResourceTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#tableResourceTestCases") public void testTableResourceActions( String actionName, FunctionalHelpers.Consumer3 callable) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + callable.accept( authorizer, requestingSecurityContext, @@ -205,24 +182,24 @@ private static Stream tableResourceFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {3}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableResourceFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#tableResourceFailureTestCases") public void testTableResourceFailure( String actionName, FunctionalHelpers.Consumer3 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingSecurityContext, - new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"))); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"))) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream tableWithPropertiesTestCases() @@ -241,11 +218,14 @@ private static Stream tableWithPropertiesTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableWithPropertiesTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#tableWithPropertiesTestCases") public void testTableWithPropertiesActions( String actionName, FunctionalHelpers.Consumer4 callable) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); Map> properties = ImmutableMap.>builder() .put("string_item", Optional.of("string_value")) @@ -281,25 +261,25 @@ private static Stream tableWithPropertiesFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {3}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableWithPropertiesFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#tableWithPropertiesFailureTestCases") public void testTableWithPropertiesActionFailure( String actionName, FunctionalHelpers.Consumer4 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingSecurityContext, new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"), - ImmutableMap.of())); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + ImmutableMap.of())) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream identityResourceTestCases() @@ -314,14 +294,16 @@ private static Stream identityResourceTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#identityResourceTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#identityResourceTestCases") public void testIdentityResourceActions( String actionName, FunctionalHelpers.Consumer3 callable) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + Identity dummyIdentity = Identity.forUser("dummy-user") .withGroups(ImmutableSet.of("some-group")) - .withExtraCredentials(ImmutableMap.of("some_extra_credential", "value")) .build(); callable.accept(authorizer, requestingIdentity, dummyIdentity); @@ -331,8 +313,7 @@ public void testIdentityResourceActions( "resource": { "user": { "user": "dummy-user", - "groups": ["some-group"], - "extraCredentials": {"some_extra_credential": "value"} + "groups": ["some-group"] } } } @@ -346,53 +327,54 @@ private static Stream identityResourceFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#identityResourceFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#identityResourceFailureTestCases") public void testIdentityResourceActionsFailure( String actionName, FunctionalHelpers.Consumer3 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingIdentity, - Identity.ofUser("dummy-user"))); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + Identity.ofUser("dummy-user"))) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream stringResourceTestCases() { Stream> methods = Stream.of( - convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanSetSystemSessionProperty), + (accessControl, systemSecurityContext, argument) -> accessControl.checkCanSetSystemSessionProperty(systemSecurityContext.getIdentity(), argument), OpaAccessControl::checkCanCreateCatalog, OpaAccessControl::checkCanDropCatalog, - OpaAccessControl::checkCanShowSchemas, - OpaAccessControl::checkCanDropRole); + OpaAccessControl::checkCanShowSchemas); Stream> actionAndResource = Stream.of( FunctionalHelpers.Pair.of("SetSystemSessionProperty", "systemSessionProperty"), FunctionalHelpers.Pair.of("CreateCatalog", "catalog"), FunctionalHelpers.Pair.of("DropCatalog", "catalog"), - FunctionalHelpers.Pair.of("ShowSchemas", "catalog"), - FunctionalHelpers.Pair.of("DropRole", "role")); + FunctionalHelpers.Pair.of("ShowSchemas", "catalog")); return Streams.zip( actionAndResource, methods, - (action, method) -> Arguments.of(Named.of(action.getFirst(), action.getFirst()), action.getSecond(), method)); + (action, method) -> Arguments.of(Named.of(action.first(), action.first()), action.second(), method)); } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#stringResourceTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#stringResourceTestCases") public void testStringResourceAction( String actionName, String resourceName, FunctionalHelpers.Consumer3 callable) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + callable.accept(authorizer, requestingSecurityContext, "resource_name"); String expectedRequest = """ @@ -414,30 +396,33 @@ public static Stream stringResourceFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {3}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#stringResourceFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#stringResourceFailureTestCases") public void testStringResourceActionsFailure( String actionName, String resourceName, FunctionalHelpers.Consumer3 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingSecurityContext, - "dummy_value")); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + "dummy_value")) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCanImpersonateUser() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user"); String expectedRequest = """ @@ -456,56 +441,57 @@ public void testCanImpersonateUser() @ParameterizedTest(name = "{index}: {0}") @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") public void testCanImpersonateUserFailure( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, - () -> authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user")); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + assertThatThrownBy( + () -> authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user")) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCanAccessCatalog() { - mockClient.setHandler(request -> OK_RESPONSE); - assertTrue(authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog_one")); + InstrumentedHttpClient permissiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl permissiveAuthorizer = createOpaAuthorizer(OPA_SERVER_URI, permissiveClient); + assertThat(permissiveAuthorizer.canAccessCatalog(requestingSecurityContext, "test_catalog")).isTrue(); - mockClient.setHandler(request -> NO_ACCESS_RESPONSE); - assertFalse(authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog_two")); + InstrumentedHttpClient restrictiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, NO_ACCESS_RESPONSE)); + OpaAccessControl restrictiveAuthorizer = createOpaAuthorizer(OPA_SERVER_URI, restrictiveClient); + assertThat(restrictiveAuthorizer.canAccessCatalog(requestingSecurityContext, "test_catalog")).isFalse(); - Set expectedRequests = ImmutableSet.of("my_catalog_one", "my_catalog_two").stream().map(""" + String expectedRequest = """ { "operation": "AccessCatalog", "resource": { "catalog": { - "name": "%s" + "name": "test_catalog" } } - } - """::formatted) - .collect(toImmutableSet()); - assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + }"""; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), permissiveClient.getRequests(), "/input/action"); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), restrictiveClient.getRequests(), "/input/action"); } @ParameterizedTest(name = "{index}: {0} - {3}") @MethodSource("io.trino.plugin.opa.TestHelpers#illegalResponseArgumentProvider") public void testCanAccessCatalogIllegalResponses( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, - () -> authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog")); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + assertThatThrownBy( + () -> authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog")) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream schemaResourceTestCases() @@ -524,11 +510,14 @@ private static Stream schemaResourceTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaResourceTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#schemaResourceTestCases") public void testSchemaResourceActions( String actionName, FunctionalHelpers.Consumer3 callable) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + callable.accept(authorizer, requestingSecurityContext, new CatalogSchemaName("my_catalog", "my_schema")); String expectedRequest = """ @@ -551,34 +540,37 @@ public static Stream schemaResourceFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaResourceFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#schemaResourceFailureTestCases") public void testSchemaResourceActionsFailure( String actionName, FunctionalHelpers.Consumer3 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingSecurityContext, - new CatalogSchemaName("dummy_catalog", "dummy_schema"))); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + new CatalogSchemaName("dummy_catalog", "dummy_schema"))) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCreateSchema() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); authorizer.checkCanCreateSchema(requestingSecurityContext, schema, ImmutableMap.of("some_key", "some_value")); authorizer.checkCanCreateSchema(requestingSecurityContext, schema, ImmutableMap.of()); - List expectedRequests = ImmutableList.builder() + Set expectedRequests = ImmutableSet.builder() .add(""" { "operation": "CreateSchema", @@ -612,25 +604,28 @@ public void testCreateSchema() @ParameterizedTest(name = "{index}: {0}") @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") public void testCreateSchemaFailure( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> authorizer.checkCanCreateSchema( requestingSecurityContext, new CatalogSchemaName("my_catalog", "my_schema"), - ImmutableMap.of())); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + ImmutableMap.of())) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCanRenameSchema() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaName sourceSchema = new CatalogSchemaName("my_catalog", "my_schema"); authorizer.checkCanRenameSchema(requestingSecurityContext, sourceSchema, "new_schema_name"); @@ -657,20 +652,20 @@ public void testCanRenameSchema() @ParameterizedTest(name = "{index}: {0}") @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") public void testCanRenameSchemaFailure( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> authorizer.checkCanRenameSchema( requestingSecurityContext, new CatalogSchemaName("my_catalog", "my_schema"), - "new_schema_name")); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + "new_schema_name")) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream renameTableTestCases() @@ -687,11 +682,14 @@ private static Stream renameTableTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#renameTableTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#renameTableTestCases") public void testRenameTableActions( String actionName, FunctionalHelpers.Consumer4 method) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaTableName sourceTable = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); CatalogSchemaTableName targetTable = new CatalogSchemaTableName("my_catalog", "new_schema_name", "new_table_name"); @@ -725,32 +723,35 @@ public static Stream renameTableFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {3}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#renameTableFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#renameTableFailureTestCases") public void testRenameTableFailure( String actionName, FunctionalHelpers.Consumer4 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); CatalogSchemaTableName sourceTable = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); CatalogSchemaTableName targetTable = new CatalogSchemaTableName("my_catalog", "new_schema_name", "new_table_name"); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingSecurityContext, sourceTable, - targetTable)); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + targetTable)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCanSetSchemaAuthorization() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); authorizer.checkCanSetSchemaAuthorization(requestingSecurityContext, schema, new TrinoPrincipal(PrincipalType.USER, "my_user")); @@ -765,12 +766,8 @@ public void testCanSetSchemaAuthorization() } }, "grantee": { - "principals": [ - { - "name": "my_user", - "type": "USER" - } - ] + "name": "my_user", + "type": "USER" } } """; @@ -780,21 +777,21 @@ public void testCanSetSchemaAuthorization() @ParameterizedTest(name = "{index}: {0}") @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") public void testCanSetSchemaAuthorizationFailure( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> authorizer.checkCanSetSchemaAuthorization( requestingSecurityContext, schema, - new TrinoPrincipal(PrincipalType.USER, "my_user"))); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + new TrinoPrincipal(PrincipalType.USER, "my_user"))) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream setTableAuthorizationTestCases() @@ -809,11 +806,14 @@ private static Stream setTableAuthorizationTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#setTableAuthorizationTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#setTableAuthorizationTestCases") public void testCanSetTableAuthorization( String actionName, FunctionalHelpers.Consumer4 method) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); method.accept(authorizer, requestingSecurityContext, table, new TrinoPrincipal(PrincipalType.USER, "my_user")); @@ -829,12 +829,8 @@ public void testCanSetTableAuthorization( } }, "grantee": { - "principals": [ - { - "name": "my_user", - "type": "USER" - } - ] + "name": "my_user", + "type": "USER" } } """.formatted(actionName); @@ -847,27 +843,27 @@ private static Stream setTableAuthorizationFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {3}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#setTableAuthorizationFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#setTableAuthorizationFailureTestCases") public void testCanSetTableAuthorizationFailure( String actionName, FunctionalHelpers.Consumer4 method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> method.accept( authorizer, requestingSecurityContext, table, - new TrinoPrincipal(PrincipalType.USER, "my_user"))); - assertTrue(actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + new TrinoPrincipal(PrincipalType.USER, "my_user"))) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream tableColumnOperationTestCases() @@ -884,11 +880,14 @@ private static Stream tableColumnOperationTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableColumnOperationTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#tableColumnOperationTestCases") public void testTableColumnOperations( String actionName, FunctionalHelpers.Consumer4> method) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); Set columns = ImmutableSet.of("my_column"); @@ -916,29 +915,32 @@ private static Stream tableColumnOperationFailureTestCases() } @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tableColumnOperationFailureTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#tableColumnOperationFailureTestCases") public void testTableColumnOperationsFailure( String actionName, FunctionalHelpers.Consumer4> method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); Set columns = ImmutableSet.of("my_column"); - Throwable actualError = assertThrows( - expectedException, - () -> method.accept(authorizer, requestingSecurityContext, table, columns)); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + assertThatThrownBy( + () -> method.accept(authorizer, requestingSecurityContext, table, columns)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCanSetCatalogSessionProperty() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + authorizer.checkCanSetCatalogSessionProperty( requestingSecurityContext, "my_catalog", "my_property"); @@ -959,365 +961,20 @@ public void testCanSetCatalogSessionProperty() @ParameterizedTest(name = "{index}: {0}") @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") public void testCanSetCatalogSessionPropertyFailure( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> authorizer.checkCanSetCatalogSessionProperty( requestingSecurityContext, "my_catalog", - "my_property")); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); - } - - private static Stream schemaPrivilegeTestCases() - { - Stream> methods = Stream.of( - OpaAccessControl::checkCanDenySchemaPrivilege, - (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanGrantSchemaPrivilege( - context, privilege, catalog, principal, true), - (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanRevokeSchemaPrivilege( - context, privilege, catalog, principal, true)); - Stream actions = Stream.of( - "DenySchemaPrivilege", - "GrantSchemaPrivilege", - "RevokeSchemaPrivilege"); - return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaPrivilegeTestCases") - public void testSchemaPrivileges( - String actionName, - FunctionalHelpers.Consumer5 method) - throws IOException - { - Privilege privilege = Privilege.CREATE; - method.accept( - authorizer, - requestingSecurityContext, - privilege, - new CatalogSchemaName("my_catalog", "my_schema"), - new TrinoPrincipal(PrincipalType.USER, "my_user")); - - String expectedRequest = """ - { - "operation": "%s", - "resource": { - "schema": { - "catalogName": "my_catalog", - "schemaName": "my_schema" - } - }, - "grantee": { - "principals": [ - { - "name": "my_user", - "type": "USER" - } - ], - "privilege": "CREATE", - "grantOption": true - } - } - """.formatted(actionName); - List actualRequests = mockClient.getRequests(); - assertEquals(actualRequests.size(), 1, "Unexpected number of requests"); - - JsonNode actualRequestInput = jsonMapper.readTree(mockClient.getRequests().get(0)).at("/input/action"); - if (!actualRequestInput.at("/grantee").has("grantOption")) { - // The DenySchemaPrivilege request does not have a grant option, we'll default it to true so we can use the same test - ((ObjectNode) actualRequestInput.at("/grantee")).put("grantOption", true); - } - assertEquals(jsonMapper.readTree(expectedRequest), actualRequestInput); - } - - private static Stream schemaPrivilegeFailureTestCases() - { - return createFailingTestCases(schemaPrivilegeTestCases()); - } - - @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#schemaPrivilegeFailureTestCases") - public void testSchemaPrivilegesFailure( - String actionName, - FunctionalHelpers.Consumer5 method, - HttpClientUtils.MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) - { - mockClient.setHandler(request -> failureResponse); - - Privilege privilege = Privilege.CREATE; - Throwable actualError = assertThrows( - expectedException, - () -> method.accept( - authorizer, - requestingSecurityContext, - privilege, - new CatalogSchemaName("my_catalog", "my_schema"), - new TrinoPrincipal(PrincipalType.USER, "my_user"))); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); - } - - private static Stream tablePrivilegeTestCases() - { - Stream> methods = Stream.of( - OpaAccessControl::checkCanDenyTablePrivilege, - (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanGrantTablePrivilege(context, privilege, catalog, principal, true), - (authorizer, context, privilege, catalog, principal) -> authorizer.checkCanRevokeTablePrivilege(context, privilege, catalog, principal, true)); - Stream actions = Stream.of( - "DenyTablePrivilege", - "GrantTablePrivilege", - "RevokeTablePrivilege"); - return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tablePrivilegeTestCases") - public void testTablePrivileges( - String actionName, - FunctionalHelpers.Consumer5 method) - throws IOException - { - Privilege privilege = Privilege.CREATE; - method.accept( - authorizer, - requestingSecurityContext, - privilege, - new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"), - new TrinoPrincipal(PrincipalType.USER, "my_user")); - - String expectedRequest = """ - { - "operation": "%s", - "resource": { - "table": { - "catalogName": "my_catalog", - "schemaName": "my_schema", - "tableName": "my_table" - } - }, - "grantee": { - "principals": [ - { - "name": "my_user", - "type": "USER" - } - ], - "privilege": "CREATE", - "grantOption": true - } - } - """.formatted(actionName); - List actualRequests = mockClient.getRequests(); - assertEquals(actualRequests.size(), 1, "Unexpected number of requests"); - - JsonNode actualRequestInput = jsonMapper.readTree(mockClient.getRequests().get(0)).at("/input/action"); - if (!actualRequestInput.at("/grantee").has("grantOption")) { - // The DenySchemaPrivilege request does not have a grant option, we'll default it to true so we can use the same test - ((ObjectNode) actualRequestInput.at("/grantee")).put("grantOption", true); - } - assertEquals(jsonMapper.readTree(expectedRequest), actualRequestInput); - } - - private static Stream tablePrivilegeFailureTestCases() - { - return createFailingTestCases(tablePrivilegeTestCases()); - } - - @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#tablePrivilegeFailureTestCases") - public void testTablePrivilegesFailure( - String actionName, - FunctionalHelpers.Consumer5 method, - HttpClientUtils.MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) - { - mockClient.setHandler(request -> failureResponse); - - Privilege privilege = Privilege.CREATE; - Throwable actualError = assertThrows( - expectedException, - () -> method.accept( - authorizer, - requestingSecurityContext, - privilege, - new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"), - new TrinoPrincipal(PrincipalType.USER, "my_user"))); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); - } - - @Test - public void testCanCreateRole() - { - authorizer.checkCanCreateRole(requestingSecurityContext, "my_role_without_grantor", Optional.empty()); - TrinoPrincipal grantor = new TrinoPrincipal(PrincipalType.USER, "my_grantor"); - authorizer.checkCanCreateRole(requestingSecurityContext, "my_role_with_grantor", Optional.of(grantor)); - - Set expectedRequests = ImmutableSet.builder() - .add(""" - { - "operation": "CreateRole", - "resource": { - "role": { - "name": "my_role_without_grantor" - } - } - } - """) - .add(""" - { - "operation": "CreateRole", - "resource": { - "role": { - "name": "my_role_with_grantor" - } - }, - "grantor": { - "name": "my_grantor", - "type": "USER" - } - } - """) - .build(); - assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") - public void testCanCreateRoleFailure( - HttpClientUtils.MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) - { - mockClient.setHandler(request -> failureResponse); - - Throwable actualError = assertThrows( - expectedException, - () -> authorizer.checkCanCreateRole( - requestingSecurityContext, - "my_role_without_grantor", - Optional.empty())); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); - } - - private static Stream roleGrantingTestCases() - { - Stream, Set, Boolean, Optional>> methods = Stream.of( - OpaAccessControl::checkCanGrantRoles, - OpaAccessControl::checkCanRevokeRoles); - Stream actions = Stream.of( - "GrantRoles", - "RevokeRoles"); - return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#roleGrantingTestCases") - public void testRoleGranting( - String actionName, - FunctionalHelpers.Consumer6, Set, Boolean, Optional> method) - { - TrinoPrincipal grantee = new TrinoPrincipal(PrincipalType.ROLE, "my_grantee_role"); - method.accept(authorizer, requestingSecurityContext, ImmutableSet.of("my_role_without_grantor"), ImmutableSet.of(grantee), true, Optional.empty()); - - TrinoPrincipal grantor = new TrinoPrincipal(PrincipalType.USER, "my_grantor_user"); - method.accept(authorizer, requestingSecurityContext, ImmutableSet.of("my_role_with_grantor"), ImmutableSet.of(grantee), false, Optional.of(grantor)); - - Set expectedRequests = ImmutableSet.builder() - .add(""" - { - "operation": "%s", - "resource": { - "roles": [ - { - "name": "my_role_with_grantor" - } - ] - }, - "grantor": { - "name": "my_grantor_user", - "type": "USER" - }, - "grantee": { - "principals": [ - { - "name": "my_grantee_role", - "type": "ROLE" - } - ], - "grantOption": false - } - } - """.formatted(actionName)) - .add(""" - { - "operation": "%s", - "resource": { - "roles": [ - { - "name": "my_role_without_grantor" - } - ] - }, - "grantee": { - "principals": [ - { - "name": "my_grantee_role", - "type": "ROLE" - } - ], - "grantOption": true - } - } - """.formatted(actionName)) - .build(); - assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); - } - - private static Stream roleGrantingFailureTestCases() - { - return createFailingTestCases(roleGrantingTestCases()); - } - - @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#roleGrantingFailureTestCases") - public void testRoleGrantingFailure( - String actionName, - FunctionalHelpers.Consumer6, Set, Boolean, Optional> method, - HttpClientUtils.MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) - { - mockClient.setHandler(request -> failureResponse); - - TrinoPrincipal grantee = new TrinoPrincipal(PrincipalType.ROLE, "my_grantee_role"); - Throwable actualError = assertThrows( - expectedException, - () -> method.accept( - authorizer, - requestingSecurityContext, - ImmutableSet.of("my_role_without_grantor"), - ImmutableSet.of(grantee), - true, - Optional.empty())); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + "my_property")) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream functionResourceTestCases() @@ -1338,16 +995,18 @@ private static Stream functionResourceTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#functionResourceTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#functionResourceTestCases") public void testFunctionResourceAction( String actionName, TestHelpers.MethodWrapper method) { + InstrumentedHttpClient permissiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + InstrumentedHttpClient restrictiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, NO_ACCESS_RESPONSE)); + CatalogSchemaRoutineName routine = new CatalogSchemaRoutineName("my_catalog", "my_schema", "my_routine_name"); - assertTrue(method.isAccessAllowed(authorizer, requestingSecurityContext, routine)); + assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, permissiveClient), requestingSecurityContext, routine)).isTrue(); - mockClient.setHandler(request -> NO_ACCESS_RESPONSE); - assertFalse(method.isAccessAllowed(authorizer, requestingSecurityContext, routine)); + assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, restrictiveClient), requestingSecurityContext, routine)).isFalse(); String expectedRequest = """ { @@ -1360,8 +1019,8 @@ public void testFunctionResourceAction( } } }""".formatted(actionName); - assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); - assertEquals(mockClient.getRequests().size(), 2); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), permissiveClient.getRequests(), "/input/action"); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), restrictiveClient.getRequests(), "/input/action"); } private static Stream functionResourceIllegalResponseTestCases() @@ -1370,28 +1029,30 @@ private static Stream functionResourceIllegalResponseTestCases() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlUnitTest#functionResourceIllegalResponseTestCases") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#functionResourceIllegalResponseTestCases") public void testFunctionResourceIllegalResponses( String actionName, TestHelpers.MethodWrapper method, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); CatalogSchemaRoutineName routine = new CatalogSchemaRoutineName("my_catalog", "my_schema", "my_routine_name"); - Throwable actualError = assertThrows( - expectedException, - () -> method.isAccessAllowed(authorizer, requestingSecurityContext, routine)); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + assertThatThrownBy( + () -> method.isAccessAllowed(authorizer, requestingSecurityContext, routine)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } @Test public void testCanExecuteTableProcedure() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); authorizer.checkCanExecuteTableProcedure(requestingSecurityContext, table, "my_procedure"); @@ -1415,40 +1076,33 @@ public void testCanExecuteTableProcedure() @ParameterizedTest(name = "{index}: {0}") @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") public void testCanExecuteTableProcedureFailure( - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); - Throwable actualError = assertThrows( - expectedException, + assertThatThrownBy( () -> authorizer.checkCanExecuteTableProcedure( requestingSecurityContext, table, - "my_procedure")); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); + "my_procedure")) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } private static Stream noResourceActionTestCases() { - Stream> methods = Stream.of( - convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanExecuteQuery), - convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanReadSystemInformation), - convertSystemSecurityContextToIdentityArgument(OpaAccessControl::checkCanWriteSystemInformation), - OpaAccessControl::checkCanShowRoles, - OpaAccessControl::checkCanShowCurrentRoles, - OpaAccessControl::checkCanShowRoleGrants); + Stream> methods = Stream.of( + OpaAccessControl::checkCanExecuteQuery, + OpaAccessControl::checkCanReadSystemInformation, + OpaAccessControl::checkCanWriteSystemInformation); Stream expectedActions = Stream.of( "ExecuteQuery", "ReadSystemInformation", - "WriteSystemInformation", - "ShowRoles", - "ShowCurrentRoles", - "ShowRoleGrants"); + "WriteSystemInformation"); return Streams.zip(expectedActions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFactory.java similarity index 71% rename from plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java rename to plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFactory.java index 50a81a7aaaeb7..e24fb0dfc5419 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestFactory.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFactory.java @@ -18,11 +18,10 @@ import io.trino.spi.security.SystemAccessControl; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; -public class TestFactory +public class TestOpaAccessControlFactory { @Test public void testCreatesSimpleAuthorizerIfNoBatchUriProvided() @@ -30,8 +29,8 @@ public void testCreatesSimpleAuthorizerIfNoBatchUriProvided() OpaAccessControlFactory factory = new OpaAccessControlFactory(); SystemAccessControl opaAuthorizer = factory.create(ImmutableMap.of("opa.policy.uri", "foo")); - assertInstanceOf(OpaAccessControl.class, opaAuthorizer); - assertFalse(opaAuthorizer instanceof OpaBatchAccessControl); + assertThat(opaAuthorizer).isInstanceOf(OpaAccessControl.class); + assertThat(opaAuthorizer).isNotInstanceOf(OpaBatchAccessControl.class); } @Test @@ -44,8 +43,8 @@ public void testCreatesBatchAuthorizerIfBatchUriProvided() .put("opa.policy.batched-uri", "bar") .buildOrThrow()); - assertInstanceOf(OpaBatchAccessControl.class, opaAuthorizer); - assertInstanceOf(OpaAccessControl.class, opaAuthorizer); + assertThat(opaAuthorizer).isInstanceOf(OpaBatchAccessControl.class); + assertThat(opaAuthorizer).isInstanceOf(OpaAccessControl.class); } @Test @@ -53,10 +52,7 @@ public void testBasePolicyUriCannotBeUnset() { OpaAccessControlFactory factory = new OpaAccessControlFactory(); - assertThrows( - ApplicationConfigurationException.class, - () -> factory.create(ImmutableMap.of()), - "may not be null"); + assertThatThrownBy(() -> factory.create(ImmutableMap.of())).isInstanceOf(ApplicationConfigurationException.class); } @Test @@ -64,9 +60,7 @@ public void testConfigMayNotBeNull() { OpaAccessControlFactory factory = new OpaAccessControlFactory(); - assertThrows( - NullPointerException.class, - () -> factory.create(null)); + assertThatThrownBy(() -> factory.create(null)).isInstanceOf(NullPointerException.class); } @Test @@ -79,7 +73,7 @@ public void testSupportsAirliftHttpConfigs() .put("opa.http-client.log.enabled", "true") .buildOrThrow()); - assertInstanceOf(OpaAccessControl.class, opaAuthorizer); - assertFalse(opaAuthorizer instanceof OpaBatchAccessControl); + assertThat(opaAuthorizer).isInstanceOf(OpaAccessControl.class); + assertThat(opaAuthorizer).isNotInstanceOf(OpaBatchAccessControl.class); } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java similarity index 62% rename from plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java rename to plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java index 85064653d6f2c..bd97a034e52b6 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlFilteringUnitTest.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java @@ -14,93 +14,67 @@ package io.trino.plugin.opa; import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.json.JsonMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; +import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; +import io.trino.plugin.opa.HttpClientUtils.MockResponse; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.security.Identity; import io.trino.spi.security.SystemSecurityContext; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import java.io.IOException; import java.net.URI; import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; -import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.net.MediaType.JSON_UTF_8; import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; +import static io.trino.plugin.opa.RequestTestUtilities.buildValidatingRequestHandler; import static io.trino.plugin.opa.TestHelpers.NO_ACCESS_RESPONSE; import static io.trino.plugin.opa.TestHelpers.OK_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.createMockHttpClient; +import static io.trino.plugin.opa.TestHelpers.createOpaAuthorizer; import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; -public class OpaAccessControlFilteringUnitTest +public class TestOpaAccessControlFiltering { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); - private HttpClientUtils.InstrumentedHttpClient mockClient; - private OpaAccessControl authorizer; - private final JsonMapper jsonMapper = new JsonMapper(); - private Identity requestingIdentity; - private SystemSecurityContext requestingSecurityContext; - - @BeforeEach - public void setupAuthorizer() - { - this.mockClient = new HttpClientUtils.InstrumentedHttpClient(OPA_SERVER_URI, "POST", JSON_UTF_8.toString(), request -> OK_RESPONSE); - this.authorizer = (OpaAccessControl) OpaAccessControlFactory.create(ImmutableMap.of("opa.policy.uri", OPA_SERVER_URI.toString()), Optional.of(mockClient)); - this.requestingIdentity = Identity.ofUser("source-user"); - this.requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); - } - - @AfterEach - public void ensureRequestContextCorrect() - throws IOException - { - for (String request : mockClient.getRequests()) { - JsonNode parsedRequest = jsonMapper.readTree(request); - assertEquals(parsedRequest.at("/input/context/identity/user").asText(), requestingIdentity.getUser()); - } - } + private final Identity requestingIdentity = Identity.ofUser("source-user"); + private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @Test public void testFilterViewQueryOwnedBy() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/user/user", "user-one")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + Identity userOne = Identity.ofUser("user-one"); Identity userTwo = Identity.ofUser("user-two"); List requestedIdentities = ImmutableList.of(userOne, userTwo); - this.mockClient.setHandler(buildHandler("/input/action/resource/user/user", "user-one")); Collection result = authorizer.filterViewQueryOwnedBy( requestingIdentity, requestedIdentities); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of(userOne)); + assertThat(result).containsExactly(userOne); - List expectedRequests = ImmutableList.builder() + Set expectedRequests = ImmutableSet.builder() .add(""" { "operation": "FilterViewQueryOwnedBy", "resource": { "user": { "user": "user-one", - "groups": [], - "extraCredentials": {} + "groups": [] } } } @@ -111,28 +85,28 @@ public void testFilterViewQueryOwnedBy() "resource": { "user": { "user": "user-two", - "groups": [], - "extraCredentials": {} + "groups": [] } } } """) .build(); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); } @Test public void testFilterCatalogs() { - Set requestedCatalogs = ImmutableSet.of("catalog_one", "catalog_two"); - this.mockClient.setHandler(buildHandler("/input/action/resource/catalog/name", "catalog_two")); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/catalog/name", "catalog_two")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + Set requestedCatalogs = ImmutableSet.of("catalog_one", "catalog_two"); Set result = authorizer.filterCatalogs( requestingSecurityContext, requestedCatalogs); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of("catalog_two")); + assertThat(result).containsExactly("catalog_two"); - List expectedRequests = ImmutableList.builder() + Set expectedRequests = ImmutableSet.builder() .add(""" { "operation": "FilterCatalogs", @@ -154,22 +128,24 @@ public void testFilterCatalogs() } """) .build(); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); } @Test public void testFilterSchemas() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/schema/schemaName", "schema_one")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + Set requestedSchemas = ImmutableSet.of("schema_one", "schema_two"); - this.mockClient.setHandler(buildHandler("/input/action/resource/schema/schemaName", "schema_one")); Set result = authorizer.filterSchemas( requestingSecurityContext, "my_catalog", requestedSchemas); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of("schema_one")); + assertThat(result).containsExactly("schema_one"); - List expectedRequests = requestedSchemas.stream() + Set expectedRequests = requestedSchemas.stream() .map(""" { "operation": "FilterSchemas", @@ -181,8 +157,8 @@ public void testFilterSchemas() } } """::formatted) - .collect(toImmutableList()); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + .collect(toImmutableSet()); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); } @Test @@ -194,12 +170,13 @@ public void testFilterTables() .add(new SchemaTableName("schema_two", "table_one")) .add(new SchemaTableName("schema_two", "table_two")) .build(); - this.mockClient.setHandler(buildHandler("/input/action/resource/table/tableName", "table_one")); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/table/tableName", "table_one")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); Set result = authorizer.filterTables(requestingSecurityContext, "my_catalog", tables); - assertEquals(ImmutableSet.copyOf(result), tables.stream().filter(table -> table.getTableName().equals("table_one")).collect(toImmutableSet())); + assertThat(result).containsExactlyInAnyOrderElementsOf(tables.stream().filter(table -> table.getTableName().equals("table_one")).collect(toImmutableSet())); - List expectedRequests = tables.stream() + Set expectedRequests = tables.stream() .map(table -> """ { "operation": "FilterTables", @@ -212,8 +189,8 @@ public void testFilterTables() } } """.formatted(table.getTableName(), table.getSchemaName())) - .collect(toImmutableList()); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + .collect(toImmutableSet()); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); } @Test @@ -234,11 +211,12 @@ public void testFilterColumns() .add("table_two_column_two") .build(); - this.mockClient.setHandler(buildHandler("/input/action/resource/table/columns/0", columnsToAllow)); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/table/columns/0", columnsToAllow)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); Map> result = authorizer.filterColumns(requestingSecurityContext, "my_catalog", requestedColumns); - List expectedRequests = requestedColumns.entrySet().stream() + Set expectedRequests = requestedColumns.entrySet().stream() .mapMulti( (requestedColumnsForTable, accepter) -> requestedColumnsForTable.getValue().forEach( column -> accepter.accept(""" @@ -254,20 +232,21 @@ public void testFilterColumns() } } """.formatted(requestedColumnsForTable.getKey().getTableName(), column)))) - .collect(toImmutableList()); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); - assertTrue( - Maps.difference( - result, - ImmutableMap.builder() - .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) - .put(tableTwo, ImmutableSet.of("table_two_column_two")) - .buildOrThrow()).areEqual()); + .collect(toImmutableSet()); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + assertThat(result).containsExactlyInAnyOrderEntriesOf( + ImmutableMap.>builder() + .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) + .put(tableTwo, ImmutableSet.of("table_two_column_two")) + .buildOrThrow()); } @Test public void testEmptyFilterColumns() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> OK_RESPONSE); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + SchemaTableName someTable = SchemaTableName.schemaTableName("my_schema", "my_table"); Map> requestedColumns = ImmutableMap.of(someTable, ImmutableSet.of()); @@ -276,8 +255,8 @@ public void testEmptyFilterColumns() "my_catalog", requestedColumns); - assertTrue(mockClient.getRequests().isEmpty()); - assertTrue(result.isEmpty()); + assertThat(mockClient.getRequests()).isEmpty(); + assertThat(result).isEmpty(); } @Test @@ -286,15 +265,17 @@ public void testFilterFunctions() SchemaFunctionName functionOne = new SchemaFunctionName("my_schema", "function_one"); SchemaFunctionName functionTwo = new SchemaFunctionName("my_schema", "function_two"); Set requestedFunctions = ImmutableSet.of(functionOne, functionTwo); - this.mockClient.setHandler(buildHandler("/input/action/resource/function/functionName", "function_two")); + + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/function/functionName", "function_two")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); Set result = authorizer.filterFunctions( requestingSecurityContext, "my_catalog", requestedFunctions); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.of(functionTwo)); + assertThat(result).containsExactly(functionTwo); - List expectedRequests = requestedFunctions.stream() + Set expectedRequests = requestedFunctions.stream() .map(function -> """ { "operation": "FilterFunctions", @@ -306,8 +287,8 @@ public void testFilterFunctions() } } }""".formatted(function.getSchemaName(), function.getFunctionName())) - .collect(toImmutableList()); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); + .collect(toImmutableSet()); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); } @ParameterizedTest(name = "{index}: {0}") @@ -315,47 +296,43 @@ public void testFilterFunctions() public void testEmptyRequests( BiFunction callable) { - Collection result = callable.apply(authorizer, requestingSecurityContext); - assertTrue(result.isEmpty()); - assertTrue(mockClient.getRequests().isEmpty()); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> OK_RESPONSE); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + + Collection result = callable.apply(authorizer, requestingSecurityContext); + assertThat(result).isEmpty(); + assertThat(mockClient.getRequests()).isEmpty(); } @ParameterizedTest(name = "{index}: {0} - {1}") @MethodSource("io.trino.plugin.opa.FilteringTestHelpers#prepopulatedErrorCases") public void testIllegalResponseThrows( BiFunction callable, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); - - Throwable actualError = assertThrows( - expectedException, - () -> callable.apply(authorizer, requestingSecurityContext)); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); - assertEquals(mockClient.getRequests().size(), 1); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + + assertThatThrownBy(() -> callable.apply(authorizer, requestingSecurityContext)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); + assertThat(mockClient.getRequests()).hasSize(1); } - private Function buildHandler(String jsonPath, Set resourcesToAccept) + private Function buildHandler(String jsonPath, Set resourcesToAccept) { - return request -> { - try { - JsonNode parsedRequest = this.jsonMapper.readTree(request); - String requestedItem = parsedRequest.at(jsonPath).asText(); - if (resourcesToAccept.contains(requestedItem)) { - return OK_RESPONSE; - } - } - catch (IOException e) { - fail("Could not parse request"); + return buildValidatingRequestHandler(requestingIdentity, parsedRequest -> { + String requestedItem = parsedRequest.at(jsonPath).asText(); + if (resourcesToAccept.contains(requestedItem)) { + return OK_RESPONSE; } return NO_ACCESS_RESPONSE; - }; + }); } - private Function buildHandler(String jsonPath, String resourceToAccept) + + private Function buildHandler(String jsonPath, String resourceToAccept) { return buildHandler(jsonPath, ImmutableSet.of(resourceToAccept)); } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java new file mode 100644 index 0000000000000..a3678e9ea8728 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; +import io.trino.spi.connector.CatalogSchemaName; +import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.security.AccessDeniedException; +import io.trino.spi.security.Identity; +import io.trino.spi.security.PrincipalType; +import io.trino.spi.security.Privilege; +import io.trino.spi.security.SystemSecurityContext; +import io.trino.spi.security.TrinoPrincipal; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.Optional; +import java.util.Set; +import java.util.function.Consumer; + +import static io.trino.plugin.opa.TestHelpers.createMockHttpClient; +import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestOpaAccessControlPermissioningOperations +{ + private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private static final Identity REQUESTING_IDENTITY = Identity.ofUser("source-user"); + private static final SystemSecurityContext REQUESTING_SECURITY_CONTEXT = systemSecurityContextFromIdentity(REQUESTING_IDENTITY); + + @Test + public void testTablePrivilegeGrantingOperationsDeniedOrAllowedByConfig() + { + CatalogSchemaTableName sampleTableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); + TrinoPrincipal samplePrincipal = new TrinoPrincipal(PrincipalType.USER, "some_user"); + + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanGrantTablePrivilege(REQUESTING_SECURITY_CONTEXT, Privilege.CREATE, sampleTableName, samplePrincipal, false)); + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanRevokeTablePrivilege(REQUESTING_SECURITY_CONTEXT, Privilege.CREATE, sampleTableName, samplePrincipal, false)); + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanDenyTablePrivilege(REQUESTING_SECURITY_CONTEXT, Privilege.CREATE, sampleTableName, samplePrincipal)); + } + + @Test + public void testSchemaPrivilegeGrantingOperationsDeniedOrAllowedByConfig() + { + CatalogSchemaName sampleSchemaName = new CatalogSchemaName("some_catalog", "some_schema"); + TrinoPrincipal samplePrincipal = new TrinoPrincipal(PrincipalType.USER, "some_user"); + + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanGrantSchemaPrivilege(REQUESTING_SECURITY_CONTEXT, Privilege.CREATE, sampleSchemaName, samplePrincipal, false)); + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanRevokeSchemaPrivilege(REQUESTING_SECURITY_CONTEXT, Privilege.CREATE, sampleSchemaName, samplePrincipal, false)); + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanDenySchemaPrivilege(REQUESTING_SECURITY_CONTEXT, Privilege.CREATE, sampleSchemaName, samplePrincipal)); + } + + @Test + public void testCanCreateRoleAllowedOrDeniedByConfig() + { + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanCreateRole(REQUESTING_SECURITY_CONTEXT, "some_role", Optional.empty())); + } + + @Test + public void testCanDropRoleAllowedOrDeniedByConfig() + { + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanDropRole(REQUESTING_SECURITY_CONTEXT, "some_role")); + } + + @Test + public void testCanGrantRolesAllowedOrDeniedByConfig() + { + Set roles = ImmutableSet.of("role_one", "role_two"); + Set grantees = ImmutableSet.of(new TrinoPrincipal(PrincipalType.USER, "some_principal")); + testOperationAllowedOrDeniedByConfig( + authorizer -> authorizer.checkCanGrantRoles(REQUESTING_SECURITY_CONTEXT, roles, grantees, true, Optional.empty())); + } + + private static void testOperationAllowedOrDeniedByConfig(Consumer methodToTest) + { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> null); + OpaAccessControl permissiveAuthorizer = createAuthorizer(true, mockClient); + OpaAccessControl restrictiveAuthorizer = createAuthorizer(false, mockClient); + + methodToTest.accept(permissiveAuthorizer); + assertThatThrownBy(() -> methodToTest.accept(restrictiveAuthorizer)) + .isInstanceOf(AccessDeniedException.class) + .hasMessageContaining("Access Denied:"); + assertThat(mockClient.getRequests()).isEmpty(); + } + + @Test + public void testShowRolesAlwaysAllowedRegardlessOfConfig() + { + testOperationAlwaysAllowedRegardlessOfConfig(authorizer -> authorizer.checkCanShowRoles(REQUESTING_SECURITY_CONTEXT)); + } + + @Test + public void testShowCurrentRolesAlwaysAllowedRegardlessOfConfig() + { + testOperationAlwaysAllowedRegardlessOfConfig(authorizer -> authorizer.checkCanShowCurrentRoles(REQUESTING_SECURITY_CONTEXT)); + } + + @Test + public void testShowRoleGrantsAlwaysAllowedRegardlessOfConfig() + { + testOperationAlwaysAllowedRegardlessOfConfig(authorizer -> authorizer.checkCanShowRoleGrants(REQUESTING_SECURITY_CONTEXT)); + } + + private static void testOperationAlwaysAllowedRegardlessOfConfig(Consumer methodToTest) + { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> null); + OpaAccessControl permissiveAuthorizer = createAuthorizer(true, mockClient); + OpaAccessControl restrictiveAuthorizer = createAuthorizer(false, mockClient); + methodToTest.accept(permissiveAuthorizer); + methodToTest.accept(restrictiveAuthorizer); + + assertThat(mockClient.getRequests()).isEmpty(); + } + + private static OpaAccessControl createAuthorizer(boolean allowPermissioningOperations, InstrumentedHttpClient mockClient) + { + return (OpaAccessControl) OpaAccessControlFactory.create( + ImmutableMap.builder() + .put("opa.policy.uri", OPA_SERVER_URI.toString()) + .put("opa.allow-permissioning-operations", String.valueOf(allowPermissioningOperations)) + .buildOrThrow(), + Optional.of(mockClient)); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPlugin.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPlugin.java new file mode 100644 index 0000000000000..37ef92a11c980 --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPlugin.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableMap; +import io.trino.spi.Plugin; +import io.trino.spi.security.SystemAccessControlFactory; +import org.junit.jupiter.api.Test; + +import static com.google.common.collect.Iterables.getOnlyElement; + +public class TestOpaAccessControlPlugin +{ + @Test + public void testCreatePlugin() + { + Plugin opaPlugin = new OpaAccessControlPlugin(); + SystemAccessControlFactory factory = getOnlyElement(opaPlugin.getSystemAccessControlFactories()); + factory.create(ImmutableMap.of("opa.policy.uri", "http://test/")); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java similarity index 86% rename from plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java rename to plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java index d7c3ff6954fe1..b174efd29e0a0 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaAccessControlSystemTest.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java @@ -49,22 +49,20 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.opa.FunctionalHelpers.Pair; import static io.trino.testing.TestingSession.testSessionBuilder; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @Testcontainers @TestInstance(PER_CLASS) -public class OpaAccessControlSystemTest +public class TestOpaAccessControlSystem { private URI opaServerUri; private DistributedQueryRunner runner; private static final int OPA_PORT = 8181; @Container - public static GenericContainer opaContainer = new GenericContainer<>(DockerImageName.parse("openpolicyagent/opa:latest-rootless")) + private static final GenericContainer OPA_CONTAINER = new GenericContainer<>(DockerImageName.parse("openpolicyagent/opa:latest-rootless")) .withCommand("run", "--server", "--addr", ":%d".formatted(OPA_PORT)) .withExposedPorts(OPA_PORT); @@ -89,7 +87,7 @@ public void teardown() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlSystemTest#filterSchemaTests") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControlSystem#filterSchemaTests") public void testAllowsQueryAndFilters(String userName, Set expectedCatalogs) throws IOException, InterruptedException { @@ -120,7 +118,7 @@ public void testAllowsQueryAndFilters(String userName, Set expectedCatal } """); Set catalogs = querySetOfStrings(user(userName), "SHOW CATALOGS"); - assertEquals(expectedCatalogs, catalogs); + assertThat(catalogs).containsExactlyInAnyOrderElementsOf(expectedCatalogs); } @Test @@ -136,11 +134,9 @@ public void testShouldDenyQueryIfDirected() input.context.identity.user in ["someone", "admin"] } """); - RuntimeException error = assertThrows(RuntimeException.class, () -> { - runner.execute(user("bob"), "SHOW CATALOGS"); - }); - assertTrue(error.getMessage().contains("Access Denied"), - "Error must mention 'Access Denied': " + error.getMessage()); + assertThatThrownBy(() -> runner.execute(user("bob"), "SHOW CATALOGS")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Access Denied"); // smoke test: we can still query if we are the right user runner.execute(user("admin"), "SHOW CATALOGS"); } @@ -167,7 +163,7 @@ public void teardown() } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaAccessControlSystemTest#filterSchemaTests") + @MethodSource("io.trino.plugin.opa.TestOpaAccessControlSystem#filterSchemaTests") public void testFilterOutItemsBatch(String userName, Set expectedCatalogs) throws IOException, InterruptedException { @@ -206,7 +202,7 @@ public void testFilterOutItemsBatch(String userName, Set expectedCatalog } """); Set catalogs = querySetOfStrings(user(userName), "SHOW CATALOGS"); - assertEquals(expectedCatalogs, catalogs); + assertThat(catalogs).containsExactlyInAnyOrderElementsOf(expectedCatalogs); } @Test @@ -218,11 +214,9 @@ public void testDenyUnbatchedQuery() import future.keywords.in default allow = false """); - RuntimeException error = assertThrows(RuntimeException.class, () -> { - runner.execute(user("bob"), "SELECT version()"); - }); - assertTrue(error.getMessage().contains("Access Denied"), - "Error must mention 'Access Denied': " + error.getMessage()); + assertThatThrownBy(() -> runner.execute(user("bob"), "SELECT version()")) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("Access Denied"); } @Test @@ -239,15 +233,15 @@ public void testAllowUnbatchedQuery() } """); Set version = querySetOfStrings(user("bob"), "SELECT version()"); - assertFalse(version.isEmpty()); + assertThat(version).isNotEmpty(); } } private void ensureOpaUp() throws IOException, InterruptedException { - assertTrue(opaContainer.isRunning()); - InetSocketAddress opaSocket = new InetSocketAddress(opaContainer.getHost(), opaContainer.getMappedPort(OPA_PORT)); + assertThat(OPA_CONTAINER.isRunning()).isTrue(); + InetSocketAddress opaSocket = new InetSocketAddress(OPA_CONTAINER.getHost(), OPA_CONTAINER.getMappedPort(OPA_PORT)); String opaEndpoint = String.format("%s:%d", opaSocket.getHostString(), opaSocket.getPort()); awaitSocketOpen(opaSocket, 100, 200); this.opaServerUri = URI.create(String.format("http://%s/", opaEndpoint)); @@ -307,7 +301,7 @@ private void submitPolicy(String... policyLines) .PUT(HttpRequest.BodyPublishers.ofString(stringOfLines(policyLines))) .header("Content-Type", "text/plain").build(), HttpResponse.BodyHandlers.ofString()); - assertEquals(policyResponse.statusCode(), 200, "Failed to submit policy: " + policyResponse.body()); + assertThat(policyResponse.statusCode()).withFailMessage("Failed to submit policy: %s", policyResponse.body()).isEqualTo(200); } private Session user(String user) @@ -325,6 +319,6 @@ private static Stream filterSchemaTests() Stream>> userAndExpectedCatalogs = Stream.of( Pair.of("bob", ImmutableSet.of("catalog_one")), Pair.of("admin", ImmutableSet.of("catalog_one", "catalog_two", "system"))); - return userAndExpectedCatalogs.map(testCase -> Arguments.of(Named.of(testCase.getFirst(), testCase.getFirst()), testCase.getSecond())); + return userAndExpectedCatalogs.map(testCase -> Arguments.of(Named.of(testCase.first(), testCase.first()), testCase.second())); } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java similarity index 57% rename from plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java rename to plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java index 4d004e427d159..6bfa78c190d55 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaBatchAccessControlFilteringUnitTest.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java @@ -13,95 +13,57 @@ */ package io.trino.plugin.opa; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.json.JsonMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; -import com.fasterxml.jackson.databind.node.ObjectNode; -import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Maps; -import io.trino.plugin.opa.schema.TrinoUser; +import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; +import io.trino.plugin.opa.HttpClientUtils.MockResponse; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.security.Identity; import io.trino.spi.security.SystemSecurityContext; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import java.io.IOException; import java.net.URI; import java.util.ArrayList; import java.util.Collection; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Stream; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.net.MediaType.JSON_UTF_8; -import static io.trino.plugin.opa.RequestTestUtilities.assertJsonRequestsEqual; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; +import static io.trino.plugin.opa.RequestTestUtilities.buildValidatingRequestHandler; +import static io.trino.plugin.opa.TestHelpers.OK_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.createMockHttpClient; +import static io.trino.plugin.opa.TestHelpers.createOpaAuthorizer; import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; -public class OpaBatchAccessControlFilteringUnitTest +public class TestOpaBatchAccessControlFiltering { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); private static final URI OPA_BATCH_SERVER_URI = URI.create("http://my-uri/batchAllow"); - private HttpClientUtils.InstrumentedHttpClient mockClient; - private OpaAccessControl authorizer; - private final JsonMapper jsonMapper = new JsonMapper(); - private Identity requestingIdentity; - private SystemSecurityContext requestingSecurityContext; - - @BeforeEach - public void setupAuthorizer() - { - this.jsonMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); - this.jsonMapper.registerModule(new Jdk8Module()); - this.mockClient = new HttpClientUtils.InstrumentedHttpClient(OPA_BATCH_SERVER_URI, "POST", JSON_UTF_8.toString(), request -> null); - this.authorizer = (OpaAccessControl) new OpaAccessControlFactory().create( - ImmutableMap.builder() - .put("opa.policy.uri", OPA_SERVER_URI.toString()) - .put("opa.policy.batched-uri", OPA_BATCH_SERVER_URI.toString()) - .buildOrThrow(), - Optional.of(mockClient)); - this.requestingIdentity = Identity.ofUser("source-user"); - this.requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); - } - - @AfterEach - public void ensureRequestContextCorrect() - throws IOException - { - for (String request : mockClient.getRequests()) { - JsonNode parsedRequest = jsonMapper.readTree(request); - assertEquals(parsedRequest.at("/input/context/identity/user").asText(), requestingIdentity.getUser()); - } - } + private final Identity requestingIdentity = Identity.ofUser("source-user"); + private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + @MethodSource("io.trino.plugin.opa.TestOpaBatchAccessControlFiltering#subsetProvider") public void testFilterViewQueryOwnedBy( - HttpClientUtils.MockResponse response, + MockResponse response, List expectedItems) { - this.mockClient.setHandler(request -> response); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); Identity identityOne = Identity.ofUser("user-one"); Identity identityTwo = Identity.ofUser("user-two"); @@ -109,33 +71,50 @@ public void testFilterViewQueryOwnedBy( List requestedIdentities = ImmutableList.of(identityOne, identityTwo, identityThree); Collection result = authorizer.filterViewQueryOwnedBy(requestingIdentity, requestedIdentities); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedIdentities, expectedItems))); - - ArrayNode allExpectedUsers = jsonMapper.createArrayNode().addAll( - requestedIdentities.stream() - .map(TrinoUser::new) - .map(user -> encodeObjectWithKey(user, "user")) - .collect(toImmutableList())); - ObjectNode expectedRequest = jsonMapper.createObjectNode() - .put("operation", "FilterViewQueryOwnedBy") - .set("filterResources", allExpectedUsers); - assertJsonRequestsEqual(ImmutableSet.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + assertThat(result).containsExactlyInAnyOrderElementsOf(getSubset(requestedIdentities, expectedItems)); + + String expectedRequest = """ + { + "operation": "FilterViewQueryOwnedBy", + "filterResources": [ + { + "user": { + "user": "user-one", + "groups": [] + } + }, + { + "user": { + "user": "user-two", + "groups": [] + } + }, + { + "user": { + "user": "user-three", + "groups": [] + } + } + ] + }"""; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + @MethodSource("io.trino.plugin.opa.TestOpaBatchAccessControlFiltering#subsetProvider") public void testFilterCatalogs( - HttpClientUtils.MockResponse response, + MockResponse response, List expectedItems) { - this.mockClient.setHandler(request -> response); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); List requestedCatalogs = ImmutableList.of("catalog_one", "catalog_two", "catalog_three"); Set result = authorizer.filterCatalogs( requestingSecurityContext, new LinkedHashSet<>(requestedCatalogs)); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedCatalogs, expectedItems))); + assertThat(result).containsExactlyInAnyOrderElementsOf(getSubset(requestedCatalogs, expectedItems)); String expectedRequest = """ { @@ -158,23 +137,24 @@ public void testFilterCatalogs( } ] }"""; - assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + @MethodSource("io.trino.plugin.opa.TestOpaBatchAccessControlFiltering#subsetProvider") public void testFilterSchemas( - HttpClientUtils.MockResponse response, + MockResponse response, List expectedItems) { - this.mockClient.setHandler(request -> response); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); List requestedSchemas = ImmutableList.of("schema_one", "schema_two", "schema_three"); Set result = authorizer.filterSchemas( requestingSecurityContext, "my_catalog", new LinkedHashSet<>(requestedSchemas)); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedSchemas, expectedItems))); + assertThat(result).containsExactlyInAnyOrderElementsOf(getSubset(requestedSchemas, expectedItems)); String expectedRequest = """ { @@ -200,16 +180,17 @@ public void testFilterSchemas( } ] }"""; - assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + @MethodSource("io.trino.plugin.opa.TestOpaBatchAccessControlFiltering#subsetProvider") public void testFilterTables( - HttpClientUtils.MockResponse response, + MockResponse response, List expectedItems) { - this.mockClient.setHandler(request -> response); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); List tables = ImmutableList.builder() .add(new SchemaTableName("schema_one", "table_one")) .add(new SchemaTableName("schema_one", "table_two")) @@ -220,7 +201,7 @@ public void testFilterTables( requestingSecurityContext, "my_catalog", new LinkedHashSet<>(tables)); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(tables, expectedItems))); + assertThat(result).containsExactlyInAnyOrderElementsOf(getSubset(tables, expectedItems)); String expectedRequest = """ { @@ -249,12 +230,12 @@ public void testFilterTables( } ] }"""; - assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } - private static Function buildHandler(Function dataBuilder) + private static Function buildHandler(Function dataBuilder) { - return request -> new HttpClientUtils.MockResponse(dataBuilder.apply(request), 200); + return request -> new MockResponse(dataBuilder.apply(request), 200); } @Test @@ -270,23 +251,26 @@ public void testFilterColumns() .buildOrThrow(); // Allow both columns from one table, one column from another one and no columns from the last one - this.mockClient.setHandler( - buildHandler( - request -> { - if (request.contains("table_one")) { - return "{\"result\": [0, 1]}"; - } else if (request.contains("table_two")) { - return "{\"result\": [1]}"; - } - return "{\"result\": []}"; + InstrumentedHttpClient mockClient = createMockHttpClient( + OPA_BATCH_SERVER_URI, + buildValidatingRequestHandler( + requestingIdentity, + parsedRequest -> { + String tableName = parsedRequest.at("/input/action/filterResources/0/table/tableName").asText(); + String responseContents = switch(tableName) { + case "table_one" -> "{\"result\": [0, 1]}"; + case "table_two" -> "{\"result\": [1]}"; + default -> "{\"result\": []}"; + }; + return new MockResponse(responseContents, 200); })); - + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); Map> result = authorizer.filterColumns( requestingSecurityContext, "my_catalog", requestedColumns); - List expectedRequests = Stream.of("table_one", "table_two", "table_three") + Set expectedRequests = Stream.of("table_one", "table_two", "table_three") .map(tableName -> """ { "operation": "FilterColumns", @@ -302,23 +286,23 @@ public void testFilterColumns() ] } """.formatted(tableName, tableName, tableName)) - .collect(toImmutableList()); - assertStringRequestsEqual(expectedRequests, this.mockClient.getRequests(), "/input/action"); - assertTrue(Maps.difference( - result, - ImmutableMap.builder() + .collect(toImmutableSet()); + assertStringRequestsEqual(expectedRequests, mockClient.getRequests(), "/input/action"); + assertThat(result).containsExactlyInAnyOrderEntriesOf( + ImmutableMap.>builder() .put(tableOne, ImmutableSet.of("table_one_column_one", "table_one_column_two")) .put(tableTwo, ImmutableSet.of("table_two_column_two")) - .buildOrThrow()).areEqual()); + .buildOrThrow()); } @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.OpaBatchAccessControlFilteringUnitTest#subsetProvider") + @MethodSource("io.trino.plugin.opa.TestOpaBatchAccessControlFiltering#subsetProvider") public void testFilterFunctions( - HttpClientUtils.MockResponse response, + MockResponse response, List expectedItems) { - this.mockClient.setHandler(request -> response); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); List requestedFunctions = ImmutableList.builder() .add(new SchemaFunctionName("my_schema", "function_one")) .add(new SchemaFunctionName("my_schema", "function_two")) @@ -329,7 +313,7 @@ public void testFilterFunctions( requestingSecurityContext, "my_catalog", new LinkedHashSet<>(requestedFunctions)); - assertEquals(ImmutableSet.copyOf(result), ImmutableSet.copyOf(getSubset(requestedFunctions, expectedItems))); + assertThat(result).containsExactlyInAnyOrderElementsOf(getSubset(requestedFunctions, expectedItems)); String expectedRequest = """ { @@ -358,12 +342,15 @@ public void testFilterFunctions( } ] }"""; - assertStringRequestsEqual(ImmutableList.of(expectedRequest), this.mockClient.getRequests(), "/input/action"); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); } @Test public void testEmptyFilterColumns() { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, request -> OK_RESPONSE); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + SchemaTableName tableOne = SchemaTableName.schemaTableName("my_schema", "table_one"); SchemaTableName tableTwo = SchemaTableName.schemaTableName("my_schema", "table_two"); Map> requestedColumns = ImmutableMap.>builder() @@ -375,8 +362,8 @@ public void testEmptyFilterColumns() requestingSecurityContext, "my_catalog", requestedColumns); - assertTrue(mockClient.getRequests().isEmpty()); - assertTrue(result.isEmpty()); + assertThat(mockClient.getRequests()).isEmpty(); + assertThat(result).isEmpty(); } @ParameterizedTest(name = "{index}: {0}") @@ -384,76 +371,68 @@ public void testEmptyFilterColumns() public void testEmptyRequests( BiFunction callable) { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, request -> OK_RESPONSE); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + Collection result = callable.apply(authorizer, requestingSecurityContext); - assertTrue(result.isEmpty()); - assertTrue(mockClient.getRequests().isEmpty()); + assertThat(result).isEmpty(); + assertThat(mockClient.getRequests()).isEmpty(); } @ParameterizedTest(name = "{index}: {0} - {1}") @MethodSource("io.trino.plugin.opa.FilteringTestHelpers#prepopulatedErrorCases") public void testIllegalResponseThrows( BiFunction callable, - HttpClientUtils.MockResponse failureResponse, + MockResponse failureResponse, Class expectedException, String expectedErrorMessage) { - mockClient.setHandler(request -> failureResponse); - - Throwable actualError = assertThrows( - expectedException, - () -> callable.apply(authorizer, requestingSecurityContext)); - assertTrue( - actualError.getMessage().contains(expectedErrorMessage), - String.format("Error must contain '%s': %s", expectedErrorMessage, actualError.getMessage())); - assertFalse(mockClient.getRequests().isEmpty()); + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + + assertThatThrownBy(() -> callable.apply(authorizer, requestingSecurityContext)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); + assertThat(mockClient.getRequests()).isNotEmpty(); } @Test public void testResponseOutOfBoundsThrows() { - mockClient.setHandler(request -> new HttpClientUtils.MockResponse("{\"result\": [0, 1, 2]}", 200)); - - assertThrows( - OpaQueryException.QueryFailed.class, - () -> authorizer.filterCatalogs(requestingSecurityContext, ImmutableSet.of("catalog_one", "catalog_two"))); - assertThrows( - OpaQueryException.QueryFailed.class, - () -> authorizer.filterSchemas(requestingSecurityContext, "some_catalog", ImmutableSet.of("schema_one", "schema_two"))); - assertThrows( - OpaQueryException.QueryFailed.class, - () -> authorizer.filterTables( - requestingSecurityContext, - "some_catalog", + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, 200, "{\"result\": [0, 1, 2]}")); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + + assertThatThrownBy(() -> authorizer.filterCatalogs(requestingSecurityContext, ImmutableSet.of("catalog_one", "catalog_two"))) + .isInstanceOf(OpaQueryException.QueryFailed.class); + assertThatThrownBy(() -> authorizer.filterSchemas(requestingSecurityContext, "some_catalog", ImmutableSet.of("schema_one", "schema_two"))) + .isInstanceOf(OpaQueryException.QueryFailed.class); + assertThatThrownBy(() -> authorizer.filterTables( + requestingSecurityContext, + "some_catalog", ImmutableSet.of( new SchemaTableName("some_schema", "table_one"), - new SchemaTableName("some_schema", "table_two")))); - assertThrows( - OpaQueryException.QueryFailed.class, - () -> authorizer.filterColumns( - requestingSecurityContext, - "some_catalog", - ImmutableMap.>builder() - .put(new SchemaTableName("some_schema", "some_table"), ImmutableSet.of("column_one", "column_two")) - .buildOrThrow())); - assertThrows( - OpaQueryException.QueryFailed.class, - () -> authorizer.filterViewQueryOwnedBy( - requestingIdentity, - ImmutableSet.of(Identity.ofUser("identity_one"), Identity.ofUser("identity_two")))); - } - - private ObjectNode encodeObjectWithKey(Object inp, String key) - { - return jsonMapper.createObjectNode().set(key, jsonMapper.valueToTree(inp)); + new SchemaTableName("some_schema", "table_two")))) + .isInstanceOf(OpaQueryException.QueryFailed.class); + assertThatThrownBy(() -> authorizer.filterColumns( + requestingSecurityContext, + "some_catalog", + ImmutableMap.>builder() + .put(new SchemaTableName("some_schema", "some_table"), ImmutableSet.of("column_one", "column_two")) + .buildOrThrow())) + .isInstanceOf(OpaQueryException.QueryFailed.class); + assertThatThrownBy(() -> authorizer.filterViewQueryOwnedBy( + requestingIdentity, + ImmutableSet.of(Identity.ofUser("identity_one"), Identity.ofUser("identity_two")))) + .isInstanceOf(OpaQueryException.QueryFailed.class); } private static Stream subsetProvider() { return Stream.of( - Arguments.of(Named.of("All-3-resources", new HttpClientUtils.MockResponse("{\"result\": [0, 1, 2]}", 200)), ImmutableList.of(0, 1, 2)), - Arguments.of(Named.of("First-and-last-resources", new HttpClientUtils.MockResponse("{\"result\": [0, 2]}", 200)), ImmutableList.of(0, 2)), - Arguments.of(Named.of("Only-one-resource", new HttpClientUtils.MockResponse("{\"result\": [2]}", 200)), ImmutableList.of(2)), - Arguments.of(Named.of("No-resources", new HttpClientUtils.MockResponse("{\"result\": []}", 200)), ImmutableList.of())); + Arguments.of(Named.of("All-3-resources", new MockResponse("{\"result\": [0, 1, 2]}", 200)), ImmutableList.of(0, 1, 2)), + Arguments.of(Named.of("First-and-last-resources", new MockResponse("{\"result\": [0, 2]}", 200)), ImmutableList.of(0, 2)), + Arguments.of(Named.of("Only-one-resource", new MockResponse("{\"result\": [2]}", 200)), ImmutableList.of(2)), + Arguments.of(Named.of("No-resources", new MockResponse("{\"result\": []}", 200)), ImmutableList.of())); } private List getSubset(List allItems, List subsetPositions) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java index 87d14a40cb9ee..37184bdefd3e8 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java @@ -32,7 +32,8 @@ public void testDefaults() .setOpaUri(null) .setOpaBatchUri(null) .setLogRequests(false) - .setLogResponses(false)); + .setLogResponses(false) + .setAllowPermissioningOperations(false)); } @Test @@ -43,13 +44,15 @@ public void testExplicitPropertyMappings() .put("opa.policy.batched-uri", "https://opa-batch.example.com") .put("opa.log-requests", "true") .put("opa.log-responses", "true") + .put("opa.allow-permissioning-operations", "true") .buildOrThrow(); OpaConfig expected = new OpaConfig() .setOpaUri(URI.create("https://opa.example.com")) .setOpaBatchUri(URI.create("https://opa-batch.example.com")) .setLogRequests(true) - .setLogResponses(true); + .setLogResponses(true) + .setAllowPermissioningOperations(true); assertFullMapping(properties, expected); } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java similarity index 57% rename from plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java rename to plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java index 66a2b21f5594b..63b7948f06cc8 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/ResponseTest.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java @@ -13,56 +13,52 @@ */ package io.trino.plugin.opa; -import com.google.common.collect.ImmutableList; import io.airlift.json.JsonCodec; import io.airlift.json.JsonCodecFactory; import io.trino.plugin.opa.schema.OpaBatchQueryResult; import io.trino.plugin.opa.schema.OpaQueryResult; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; -public class ResponseTest +public class TestOpaResponseDecoding { - private JsonCodec responseCodec; - private JsonCodec batchResponseCodec; + private final JsonCodec responseCodec = new JsonCodecFactory().jsonCodec(OpaQueryResult.class); + private final JsonCodec batchResponseCodec = new JsonCodecFactory().jsonCodec(OpaBatchQueryResult.class); - @BeforeEach - public void setupParser() + @Test + public void testCanDeserializeOpaSingleResponse() { - this.responseCodec = new JsonCodecFactory().jsonCodec(OpaQueryResult.class); - this.batchResponseCodec = new JsonCodecFactory().jsonCodec(OpaBatchQueryResult.class); + testCanDeserializeOpaSingleResponse(true); + testCanDeserializeOpaSingleResponse(false); } - @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void testCanDeserializeOpaSingleResponse(boolean response) + private void testCanDeserializeOpaSingleResponse(boolean response) { OpaQueryResult result = this.responseCodec.fromJson(""" { "decision_id": "foo", "result": %s }""".formatted(String.valueOf(response))); - assertEquals(response, result.result()); - assertEquals("foo", result.decisionId()); + assertThat(response).isEqualTo(result.result()); + assertThat(result.decisionId()).isEqualTo("foo"); + } + + @Test + public void testCanDeserializeOpaSingleResponseWithNoDecisionId() + { + testCanDeserializeOpaSingleResponseWithNoDecisionId(true); + testCanDeserializeOpaSingleResponseWithNoDecisionId(false); } - @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void testCanDeserializeOpaSingleResponseWithNoDecisionId(boolean response) + private void testCanDeserializeOpaSingleResponseWithNoDecisionId(boolean response) { OpaQueryResult result = this.responseCodec.fromJson(""" { "result": %s }""".formatted(String.valueOf(response))); - assertEquals(response, result.result()); - assertNull(result.decisionId()); + assertThat(response).isEqualTo(result.result()); + assertThat(result.decisionId()).isNull(); } @Test @@ -73,25 +69,30 @@ public void testSingleResponseWithExtraFields() "result": true, "someExtraInfo": ["foo"] }"""); - assertTrue(result.result()); - assertNull(result.decisionId()); + assertThat(result.result()).isTrue(); + assertThat(result.decisionId()).isNull(); } @Test public void testUndefinedDecisionSingleResponseTreatedAsDeny() { OpaQueryResult result = this.responseCodec.fromJson("{}"); - assertFalse(result.result()); - assertNull(result.decisionId()); + assertThat(result.result()).isFalse(); + assertThat(result.decisionId()).isNull(); + } + + @Test + public void testEmptyOrUndefinedResponses() + { + testEmptyOrUndefinedResponses("{}"); + testEmptyOrUndefinedResponses("{\"result\": []}"); } - @ParameterizedTest - @ValueSource(strings = {"{}", "{\"result\": []}"}) - public void testEmptyOrUndefinedResponses(String response) + private void testEmptyOrUndefinedResponses(String response) { OpaBatchQueryResult result = this.batchResponseCodec.fromJson(response); - assertEquals(ImmutableList.of(), result.result()); - assertNull(result.decisionId()); + assertThat(result.result()).isEmpty(); + assertThat(result.decisionId()).isNull(); } @Test @@ -101,8 +102,8 @@ public void testBatchResponseWithItemsNoDecisionId() { "result": [1, 2, 3] }"""); - assertEquals(ImmutableList.of(1, 2, 3), result.result()); - assertNull(result.decisionId()); + assertThat(result.result()).containsExactly(1, 2, 3); + assertThat(result.decisionId()).isNull(); } @Test @@ -113,8 +114,8 @@ public void testBatchResponseWithItemsAndDecisionId() "result": [1, 2, 3], "decision_id": "foobar" }"""); - assertEquals(ImmutableList.of(1, 2, 3), result.result()); - assertEquals("foobar", result.decisionId()); + assertThat(result.result()).containsExactly(1, 2, 3); + assertThat(result.decisionId()).isEqualTo("foobar"); } @Test @@ -127,7 +128,7 @@ public void testBatchResponseWithExtraFields() "someInfo": "foo", "andAnObject": {} }"""); - assertEquals(ImmutableList.of(1, 2, 3), result.result()); - assertEquals("foobar", result.decisionId()); + assertThat(result.result()).containsExactly(1, 2, 3); + assertThat(result.decisionId()).isEqualTo("foobar"); } } From 0fa64eb60c6f6b5a7b1dc0bbe0d22b33313b4f2c Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Wed, 29 Nov 2023 17:36:05 +0000 Subject: [PATCH 03/11] Implement softwareStack feature --- .../io/trino/plugin/opa/OpaAccessControl.java | 97 +++++++++++-------- .../plugin/opa/OpaAccessControlFactory.java | 12 ++- .../plugin/opa/OpaBatchAccessControl.java | 23 +++-- .../plugin/opa/schema/OpaPluginContext.java | 24 +++++ .../plugin/opa/schema/OpaQueryContext.java | 16 +-- .../plugin/opa/RequestTestUtilities.java | 4 + .../java/io/trino/plugin/opa/TestHelpers.java | 38 +++++++- .../plugin/opa/TestOpaAccessControl.java | 45 +++++++++ ...aAccessControlPermissioningOperations.java | 4 +- 9 files changed, 192 insertions(+), 71 deletions(-) create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaPluginContext.java diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java index ca19d45981dc1..4da0f63d6c8b8 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.Multimaps; import com.google.inject.Inject; +import io.trino.plugin.opa.schema.OpaPluginContext; import io.trino.plugin.opa.schema.OpaQueryContext; import io.trino.plugin.opa.schema.OpaQueryInput; import io.trino.plugin.opa.schema.OpaQueryInputAction; @@ -24,6 +25,7 @@ import io.trino.plugin.opa.schema.TrinoCatalogSessionProperty; import io.trino.plugin.opa.schema.TrinoFunction; import io.trino.plugin.opa.schema.TrinoGrantPrincipal; +import io.trino.plugin.opa.schema.TrinoIdentity; import io.trino.plugin.opa.schema.TrinoSchema; import io.trino.plugin.opa.schema.TrinoTable; import io.trino.plugin.opa.schema.TrinoUser; @@ -72,6 +74,7 @@ import static io.trino.spi.security.AccessDeniedException.denyShowCreateSchema; import static io.trino.spi.security.AccessDeniedException.denyShowFunctions; import static io.trino.spi.security.AccessDeniedException.denyShowTables; +import static java.util.Objects.requireNonNull; public sealed class OpaAccessControl implements SystemAccessControl @@ -79,19 +82,21 @@ public sealed class OpaAccessControl { private final OpaHighLevelClient opaHighLevelClient; private final boolean allowPermissioningOperations; + private final OpaPluginContext pluginContext; @Inject - public OpaAccessControl(OpaHighLevelClient opaHighLevelClient, OpaConfig config) + public OpaAccessControl(OpaHighLevelClient opaHighLevelClient, OpaConfig config, OpaPluginContext pluginContext) { - this.opaHighLevelClient = opaHighLevelClient; + this.opaHighLevelClient = requireNonNull(opaHighLevelClient, "opaHighLevelClient is null"); this.allowPermissioningOperations = config.getAllowPermissioningOperations(); + this.pluginContext = requireNonNull(pluginContext, "pluginContext is null"); } @Override public void checkCanImpersonateUser(Identity identity, String userName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromIdentity(identity), + buildQueryContext(identity), "ImpersonateUser", () -> denyImpersonateUser(identity.getUser(), userName), OpaQueryInputResource.builder().user(new TrinoUser(userName)).build()); @@ -104,13 +109,13 @@ public void checkCanSetUser(Optional principal, String userName) @Override public void checkCanExecuteQuery(Identity identity) { - opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "ExecuteQuery", AccessDeniedException::denyExecuteQuery); + opaHighLevelClient.queryAndEnforce(buildQueryContext(identity), "ExecuteQuery", AccessDeniedException::denyExecuteQuery); } @Override public void checkCanViewQueryOwnedBy(Identity identity, Identity queryOwner) { - opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "ViewQueryOwnedBy", AccessDeniedException::denyViewQuery, OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build()); + opaHighLevelClient.queryAndEnforce(buildQueryContext(identity), "ViewQueryOwnedBy", AccessDeniedException::denyViewQuery, OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build()); } @Override @@ -119,7 +124,7 @@ public Collection filterViewQueryOwnedBy(Identity identity, Collection return opaHighLevelClient.parallelFilterFromOpa( queryOwners, queryOwner -> buildQueryInputForSimpleResource( - OpaQueryContext.fromIdentity(identity), + buildQueryContext(identity), "FilterViewQueryOwnedBy", OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build())); } @@ -127,26 +132,26 @@ public Collection filterViewQueryOwnedBy(Identity identity, Collection @Override public void checkCanKillQueryOwnedBy(Identity identity, Identity queryOwner) { - opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "KillQueryOwnedBy", AccessDeniedException::denyKillQuery, OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build()); + opaHighLevelClient.queryAndEnforce(buildQueryContext(identity), "KillQueryOwnedBy", AccessDeniedException::denyKillQuery, OpaQueryInputResource.builder().user(new TrinoUser(queryOwner)).build()); } @Override public void checkCanReadSystemInformation(Identity identity) { - opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "ReadSystemInformation", AccessDeniedException::denyReadSystemInformationAccess); + opaHighLevelClient.queryAndEnforce(buildQueryContext(identity), "ReadSystemInformation", AccessDeniedException::denyReadSystemInformationAccess); } @Override public void checkCanWriteSystemInformation(Identity identity) { - opaHighLevelClient.queryAndEnforce(OpaQueryContext.fromIdentity(identity), "WriteSystemInformation", AccessDeniedException::denyWriteSystemInformationAccess); + opaHighLevelClient.queryAndEnforce(buildQueryContext(identity), "WriteSystemInformation", AccessDeniedException::denyWriteSystemInformationAccess); } @Override public void checkCanSetSystemSessionProperty(Identity identity, String propertyName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromIdentity(identity), + buildQueryContext(identity), "SetSystemSessionProperty", () -> denySetSystemSessionProperty(propertyName), OpaQueryInputResource.builder().systemSessionProperty(propertyName).build()); @@ -156,7 +161,7 @@ public void checkCanSetSystemSessionProperty(Identity identity, String propertyN public boolean canAccessCatalog(SystemSecurityContext context, String catalogName) { return opaHighLevelClient.queryOpaWithSimpleResource( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "AccessCatalog", OpaQueryInputResource.builder().catalog(catalogName).build()); } @@ -165,7 +170,7 @@ public boolean canAccessCatalog(SystemSecurityContext context, String catalogNam public void checkCanCreateCatalog(SystemSecurityContext context, String catalog) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "CreateCatalog", () -> denyCreateCatalog(catalog), OpaQueryInputResource.builder().catalog(catalog).build()); @@ -175,7 +180,7 @@ public void checkCanCreateCatalog(SystemSecurityContext context, String catalog) public void checkCanDropCatalog(SystemSecurityContext context, String catalog) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "DropCatalog", () -> denyDropCatalog(catalog), OpaQueryInputResource.builder().catalog(catalog).build()); @@ -187,7 +192,7 @@ public Set filterCatalogs(SystemSecurityContext context, Set cat return opaHighLevelClient.parallelFilterFromOpa( catalogs, catalog -> buildQueryInputForSimpleResource( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterCatalogs", OpaQueryInputResource.builder().catalog(catalog).build())); } @@ -196,7 +201,7 @@ public Set filterCatalogs(SystemSecurityContext context, Set cat public void checkCanCreateSchema(SystemSecurityContext context, CatalogSchemaName schema, Map properties) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "CreateSchema", () -> denyCreateSchema(schema.toString()), OpaQueryInputResource.builder().schema(new TrinoSchema(schema).withProperties(convertProperties(properties))).build()); @@ -206,7 +211,7 @@ public void checkCanCreateSchema(SystemSecurityContext context, CatalogSchemaNam public void checkCanDropSchema(SystemSecurityContext context, CatalogSchemaName schema) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "DropSchema", () -> denyDropSchema(schema.toString()), OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build()); @@ -218,7 +223,7 @@ public void checkCanRenameSchema(SystemSecurityContext context, CatalogSchemaNam OpaQueryInputResource resource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build(); OpaQueryInputResource targetResource = OpaQueryInputResource.builder().schema(new TrinoSchema(schema.getCatalogName(), newSchemaName)).build(); - OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + OpaQueryContext queryContext = buildQueryContext(context); if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameSchema", resource, targetResource)) { denyRenameSchema(schema.toString(), newSchemaName); @@ -234,7 +239,7 @@ public void checkCanSetSchemaAuthorization(SystemSecurityContext context, Catalo .resource(resource) .grantee(TrinoGrantPrincipal.fromTrinoPrincipal(principal)) .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + OpaQueryInput input = new OpaQueryInput(buildQueryContext(context), action); if (!opaHighLevelClient.queryOpa(input)) { denySetSchemaAuthorization(schema.toString(), principal); @@ -245,7 +250,7 @@ public void checkCanSetSchemaAuthorization(SystemSecurityContext context, Catalo public void checkCanShowSchemas(SystemSecurityContext context, String catalogName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "ShowSchemas", AccessDeniedException::denyShowSchemas, OpaQueryInputResource.builder().catalog(catalogName).build()); @@ -257,7 +262,7 @@ public Set filterSchemas(SystemSecurityContext context, String catalogNa return opaHighLevelClient.parallelFilterFromOpa( schemaNames, schema -> buildQueryInputForSimpleResource( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterSchemas", OpaQueryInputResource.builder().schema(new TrinoSchema(catalogName, schema)).build())); } @@ -266,7 +271,7 @@ public Set filterSchemas(SystemSecurityContext context, String catalogNa public void checkCanShowCreateSchema(SystemSecurityContext context, CatalogSchemaName schemaName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "ShowCreateSchema", () -> denyShowCreateSchema(schemaName.toString()), OpaQueryInputResource.builder().schema(new TrinoSchema(schemaName)).build()); @@ -295,7 +300,7 @@ public void checkCanRenameTable(SystemSecurityContext context, CatalogSchemaTabl { OpaQueryInputResource oldResource = OpaQueryInputResource.builder().table(new TrinoTable(table)).build(); OpaQueryInputResource newResource = OpaQueryInputResource.builder().table(new TrinoTable(newTable)).build(); - OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + OpaQueryContext queryContext = buildQueryContext(context); if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameTable", oldResource, newResource)) { denyRenameTable(table.toString(), newTable.toString()); @@ -330,7 +335,7 @@ public void checkCanSetColumnComment(SystemSecurityContext context, CatalogSchem public void checkCanShowTables(SystemSecurityContext context, CatalogSchemaName schema) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "ShowTables", () -> denyShowTables(schema.toString()), OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build()); @@ -342,7 +347,7 @@ public Set filterTables(SystemSecurityContext context, String c return opaHighLevelClient.parallelFilterFromOpa( tableNames, table -> buildQueryInputForSimpleResource( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterTables", OpaQueryInputResource.builder() .table(new TrinoTable(catalogName, table.getSchemaName(), table.getTableName())) @@ -369,7 +374,7 @@ public Map> filterColumns(SystemSecurityContext con Set filteredColumns = opaHighLevelClient.parallelFilterFromOpa( allColumnsBuilder.build(), tableColumn -> buildQueryInputForSimpleResource( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterColumns", OpaQueryInputResource.builder().table(tableColumn).build())); @@ -407,7 +412,7 @@ public void checkCanSetTableAuthorization(SystemSecurityContext context, Catalog .resource(resource) .grantee(TrinoGrantPrincipal.fromTrinoPrincipal(principal)) .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + OpaQueryInput input = new OpaQueryInput(buildQueryContext(context), action); if (!opaHighLevelClient.queryOpa(input)) { denySetTableAuthorization(table.toString(), principal); @@ -461,7 +466,7 @@ public void checkCanRenameView(SystemSecurityContext context, CatalogSchemaTable { OpaQueryInputResource oldResource = OpaQueryInputResource.builder().table(new TrinoTable(view)).build(); OpaQueryInputResource newResource = OpaQueryInputResource.builder().table(new TrinoTable(newView)).build(); - OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + OpaQueryContext queryContext = buildQueryContext(context); if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameView", oldResource, newResource)) { denyRenameView(view.toString(), newView.toString()); @@ -477,7 +482,7 @@ public void checkCanSetViewAuthorization(SystemSecurityContext context, CatalogS .resource(resource) .grantee(TrinoGrantPrincipal.fromTrinoPrincipal(principal)) .build(); - OpaQueryInput input = new OpaQueryInput(OpaQueryContext.fromSystemSecurityContext(context), action); + OpaQueryInput input = new OpaQueryInput(buildQueryContext(context), action); if (!opaHighLevelClient.queryOpa(input)) { denySetViewAuthorization(view.toString(), principal); @@ -525,7 +530,7 @@ public void checkCanRenameMaterializedView(SystemSecurityContext context, Catalo { OpaQueryInputResource oldResource = OpaQueryInputResource.builder().table(new TrinoTable(view)).build(); OpaQueryInputResource newResource = OpaQueryInputResource.builder().table(new TrinoTable(newView)).build(); - OpaQueryContext queryContext = OpaQueryContext.fromSystemSecurityContext(context); + OpaQueryContext queryContext = buildQueryContext(context); if (!opaHighLevelClient.queryOpaWithSourceAndTargetResource(queryContext, "RenameMaterializedView", oldResource, newResource)) { denyRenameMaterializedView(view.toString(), newView.toString()); @@ -536,7 +541,7 @@ public void checkCanRenameMaterializedView(SystemSecurityContext context, Catalo public void checkCanSetCatalogSessionProperty(SystemSecurityContext context, String catalogName, String propertyName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "SetCatalogSessionProperty", () -> denySetCatalogSessionProperty(propertyName), OpaQueryInputResource.builder().catalogSessionProperty(new TrinoCatalogSessionProperty(catalogName, propertyName)).build()); @@ -624,7 +629,7 @@ public void checkCanShowRoleGrants(SystemSecurityContext context) public void checkCanShowFunctions(SystemSecurityContext context, CatalogSchemaName schema) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "ShowFunctions", () -> denyShowFunctions(schema.toString()), OpaQueryInputResource.builder().schema(new TrinoSchema(schema)).build()); @@ -636,7 +641,7 @@ public Set filterFunctions(SystemSecurityContext context, St return opaHighLevelClient.parallelFilterFromOpa( functionNames, function -> buildQueryInputForSimpleResource( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterFunctions", OpaQueryInputResource.builder() .function( @@ -650,7 +655,7 @@ public Set filterFunctions(SystemSecurityContext context, St public void checkCanExecuteProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName procedure) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + buildQueryContext(systemSecurityContext), "ExecuteProcedure", () -> denyExecuteProcedure(procedure.toString()), OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(procedure)).build()); @@ -660,7 +665,7 @@ public void checkCanExecuteProcedure(SystemSecurityContext systemSecurityContext public boolean canExecuteFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) { return opaHighLevelClient.queryOpaWithSimpleResource( - OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + buildQueryContext(systemSecurityContext), "ExecuteFunction", OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); } @@ -669,7 +674,7 @@ public boolean canExecuteFunction(SystemSecurityContext systemSecurityContext, C public boolean canCreateViewWithExecuteFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) { return opaHighLevelClient.queryOpaWithSimpleResource( - OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + buildQueryContext(systemSecurityContext), "CreateViewWithExecuteFunction", OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); } @@ -678,7 +683,7 @@ public boolean canCreateViewWithExecuteFunction(SystemSecurityContext systemSecu public void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityContext, CatalogSchemaTableName table, String procedure) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + buildQueryContext(systemSecurityContext), "ExecuteTableProcedure", () -> denyExecuteTableProcedure(table.toString(), procedure), OpaQueryInputResource.builder().table(new TrinoTable(table)).function(procedure).build()); @@ -688,7 +693,7 @@ public void checkCanExecuteTableProcedure(SystemSecurityContext systemSecurityCo public void checkCanCreateFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + buildQueryContext(systemSecurityContext), "CreateFunction", () -> denyCreateFunction(functionName.toString()), OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); @@ -698,7 +703,7 @@ public void checkCanCreateFunction(SystemSecurityContext systemSecurityContext, public void checkCanDropFunction(SystemSecurityContext systemSecurityContext, CatalogSchemaRoutineName functionName) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(systemSecurityContext), + buildQueryContext(systemSecurityContext), "DropFunction", () -> denyDropFunction(functionName.toString()), OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); @@ -707,7 +712,7 @@ public void checkCanDropFunction(SystemSecurityContext systemSecurityContext, Ca private void checkTableOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Consumer deny) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), actionName, () -> deny.accept(table.toString()), OpaQueryInputResource.builder().table(new TrinoTable(table)).build()); @@ -716,7 +721,7 @@ private void checkTableOperation(SystemSecurityContext context, String actionNam private void checkTableAndPropertiesOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Map> properties, Consumer deny) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), actionName, () -> deny.accept(table.toString()), OpaQueryInputResource.builder().table(new TrinoTable(table).withProperties(properties)).build()); @@ -725,7 +730,7 @@ private void checkTableAndPropertiesOperation(SystemSecurityContext context, Str private void checkTableAndColumnsOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Set columns, BiConsumer> deny) { opaHighLevelClient.queryAndEnforce( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), actionName, () -> deny.accept(table.toString(), columns), OpaQueryInputResource.builder().table(new TrinoTable(table).withColumns(columns)).build()); @@ -751,4 +756,14 @@ private static Map> convertProperties(Map Map.entry(propertiesEntry.getKey(), Optional.ofNullable(propertiesEntry.getValue()))) .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); } + + OpaQueryContext buildQueryContext(Identity trinoIdentity) + { + return new OpaQueryContext(TrinoIdentity.fromTrinoIdentity(trinoIdentity), pluginContext); + } + + OpaQueryContext buildQueryContext(SystemSecurityContext securityContext) + { + return new OpaQueryContext(TrinoIdentity.fromTrinoIdentity(securityContext.getIdentity()), pluginContext); + } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java index 24cc9c45c269f..e23d2838f9656 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java @@ -22,6 +22,7 @@ import io.airlift.concurrent.BoundedExecutor; import io.airlift.http.client.HttpClient; import io.airlift.json.JsonModule; +import io.trino.plugin.opa.schema.OpaPluginContext; import io.trino.plugin.opa.schema.OpaQuery; import io.trino.plugin.opa.schema.OpaQueryResult; import io.trino.spi.security.SystemAccessControl; @@ -49,19 +50,21 @@ public String getName() @Override public SystemAccessControl create(Map config) { - return create(config, Optional.empty()); + return create(config, Optional.empty(), Optional.empty()); } @Override public SystemAccessControl create(Map config, SystemAccessControlContext context) { - return create(config); + return create(config, Optional.empty(), Optional.ofNullable(context)); } @VisibleForTesting - protected static SystemAccessControl create(Map config, Optional httpClient) + protected static SystemAccessControl create(Map config, Optional httpClient, Optional context) { requireNonNull(config, "config is null"); + requireNonNull(httpClient, "httpClient is null"); + requireNonNull(context, "context is null"); Bootstrap app = new Bootstrap( new JsonModule(), @@ -71,6 +74,9 @@ protected static SystemAccessControl create(Map config, Optional httpClient.ifPresentOrElse( client -> binder.bind(Key.get(HttpClient.class, ForOpa.class)).toInstance(client), () -> httpClientBinder(binder).bindHttpClient("opa", ForOpa.class)); + context.ifPresentOrElse( + actualContext -> binder.bind(OpaPluginContext.class).toInstance(new OpaPluginContext(actualContext.getVersion())), + () -> binder.bind(OpaPluginContext.class).toInstance(new OpaPluginContext("UNKNOWN"))); binder.bind(OpaHighLevelClient.class); binder.bind(Key.get(Executor.class, ForOpa.class)) .toProvider(ExecutorProvider.class) diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java index 1ac684fd38d9e..26b78a0f7f8d0 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaBatchAccessControl.java @@ -18,6 +18,7 @@ import com.google.inject.Inject; import io.airlift.json.JsonCodec; import io.trino.plugin.opa.schema.OpaBatchQueryResult; +import io.trino.plugin.opa.schema.OpaPluginContext; import io.trino.plugin.opa.schema.OpaQueryContext; import io.trino.plugin.opa.schema.OpaQueryInput; import io.trino.plugin.opa.schema.OpaQueryInputAction; @@ -40,6 +41,7 @@ import java.util.function.Function; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; public final class OpaBatchAccessControl extends OpaAccessControl @@ -53,19 +55,20 @@ public OpaBatchAccessControl( OpaHighLevelClient opaHighLevelClient, JsonCodec batchResultCodec, OpaHttpClient opaHttpClient, - OpaConfig config) + OpaConfig config, + OpaPluginContext pluginContext) { - super(opaHighLevelClient, config); + super(opaHighLevelClient, config, pluginContext); this.opaBatchedPolicyUri = config.getOpaBatchUri().orElseThrow(); - this.batchResultCodec = batchResultCodec; - this.opaHttpClient = opaHttpClient; + this.batchResultCodec = requireNonNull(batchResultCodec, "batchResultCodec is null"); + this.opaHttpClient = requireNonNull(opaHttpClient, "opaHttpClient is null"); } @Override public Collection filterViewQueryOwnedBy(Identity identity, Collection queryOwners) { return batchFilterFromOpa( - OpaQueryContext.fromIdentity(identity), + buildQueryContext(identity), "FilterViewQueryOwnedBy", queryOwners, queryOwner -> OpaQueryInputResource.builder() @@ -77,7 +80,7 @@ public Collection filterViewQueryOwnedBy(Identity identity, Collection public Set filterCatalogs(SystemSecurityContext context, Set catalogs) { return batchFilterFromOpa( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterCatalogs", catalogs, catalog -> OpaQueryInputResource.builder() @@ -89,7 +92,7 @@ public Set filterCatalogs(SystemSecurityContext context, Set cat public Set filterSchemas(SystemSecurityContext context, String catalogName, Set schemaNames) { return batchFilterFromOpa( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterSchemas", schemaNames, schema -> OpaQueryInputResource.builder().schema(new TrinoSchema(catalogName, schema)).build()); @@ -99,7 +102,7 @@ public Set filterSchemas(SystemSecurityContext context, String catalogNa public Set filterTables(SystemSecurityContext context, String catalogName, Set tableNames) { return batchFilterFromOpa( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterTables", tableNames, table -> OpaQueryInputResource.builder().table(new TrinoTable(catalogName, table.getSchemaName(), table.getTableName())).build()); @@ -109,7 +112,7 @@ public Set filterTables(SystemSecurityContext context, String c public Map> filterColumns(SystemSecurityContext context, String catalogName, Map> tableColumns) { BiFunction, OpaQueryInput> requestBuilder = batchRequestBuilder( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterColumns", (schemaTableName, columns) -> OpaQueryInputResource.builder() .table(new TrinoTable(catalogName, schemaTableName.getSchemaName(), schemaTableName.getTableName()).withColumns(ImmutableSet.copyOf(columns))) @@ -121,7 +124,7 @@ public Map> filterColumns(SystemSecurityContext con public Set filterFunctions(SystemSecurityContext context, String catalogName, Set functionNames) { return batchFilterFromOpa( - OpaQueryContext.fromSystemSecurityContext(context), + buildQueryContext(context), "FilterFunctions", functionNames, function -> OpaQueryInputResource.builder() diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaPluginContext.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaPluginContext.java new file mode 100644 index 0000000000000..341957e4899a4 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaPluginContext.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import static java.util.Objects.requireNonNull; + +public record OpaPluginContext(String trinoVersion) +{ + public OpaPluginContext + { + requireNonNull(trinoVersion, "trinoVersion is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java index 75af04b2071c1..22c9811351262 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryContext.java @@ -13,25 +13,13 @@ */ package io.trino.plugin.opa.schema; -import io.trino.spi.security.Identity; -import io.trino.spi.security.SystemSecurityContext; - import static java.util.Objects.requireNonNull; -public record OpaQueryContext(TrinoIdentity identity) +public record OpaQueryContext(TrinoIdentity identity, OpaPluginContext softwareStack) { - public static OpaQueryContext fromSystemSecurityContext(SystemSecurityContext ctx) - { - return new OpaQueryContext(TrinoIdentity.fromTrinoIdentity(ctx.getIdentity())); - } - - public static OpaQueryContext fromIdentity(Identity identity) - { - return new OpaQueryContext(TrinoIdentity.fromTrinoIdentity(identity)); - } - public OpaQueryContext { requireNonNull(identity, "identity is null"); + requireNonNull(softwareStack, "softwareStack is null"); } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java index 73bd79f09a356..b15a585501a91 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java @@ -25,6 +25,7 @@ import java.util.function.Function; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.plugin.opa.TestHelpers.SYSTEM_ACCESS_CONTROL_CONTEXT; import static org.assertj.core.api.Assertions.assertThat; public class RequestTestUtilities @@ -70,6 +71,9 @@ public static Function buildValidatingRequestHandler(Ide if (!groupsInRequestBuilder.build().equals(expectedUser.getGroups())) { throw new AssertionError("Request had invalid set of groups in the identity block"); } + if (!parsedRequest.at("/input/context/softwareStack/trinoVersion").asText().equals(SYSTEM_ACCESS_CONTROL_CONTEXT.getVersion())) { + throw new AssertionError("Request had invalid trinoVersion"); + } return customHandler.apply(parsedRequest); }; } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java index 7061498fe826c..8733f3142eb6f 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -16,11 +16,14 @@ import com.fasterxml.jackson.databind.JsonNode; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Tracer; import io.trino.execution.QueryIdGenerator; import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; import io.trino.plugin.opa.HttpClientUtils.MockResponse; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.Identity; +import io.trino.spi.security.SystemAccessControlFactory; import io.trino.spi.security.SystemSecurityContext; import org.junit.jupiter.api.Named; import org.junit.jupiter.params.provider.Arguments; @@ -60,6 +63,7 @@ private TestHelpers() {} public static final MockResponse UNDEFINED_RESPONSE = new MockResponse("{}", 404); public static final MockResponse BAD_REQUEST_RESPONSE = new MockResponse("{}", 400); public static final MockResponse SERVER_ERROR_RESPONSE = new MockResponse("", 500); + public static final SystemAccessControlFactory.SystemAccessControlContext SYSTEM_ACCESS_CONTROL_CONTEXT = new TestingSystemAccessControlContext("TEST_VERSION"); public static Stream createFailingTestCases(Stream baseTestCases) { @@ -143,7 +147,7 @@ public static InstrumentedHttpClient createMockHttpClient(URI expectedUri, Funct public static OpaAccessControl createOpaAuthorizer(URI opaUri, InstrumentedHttpClient mockHttpClient) { - return (OpaAccessControl) OpaAccessControlFactory.create(ImmutableMap.of("opa.policy.uri", opaUri.toString()), Optional.of(mockHttpClient)); + return (OpaAccessControl) OpaAccessControlFactory.create(ImmutableMap.of("opa.policy.uri", opaUri.toString()), Optional.of(mockHttpClient), Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT)); } public static OpaAccessControl createOpaAuthorizer(URI opaUri, URI opaBatchUri, InstrumentedHttpClient mockHttpClient) @@ -153,6 +157,36 @@ public static OpaAccessControl createOpaAuthorizer(URI opaUri, URI opaBatchUri, .put("opa.policy.uri", opaUri.toString()) .put("opa.policy.batched-uri", opaBatchUri.toString()) .buildOrThrow(), - Optional.of(mockHttpClient)); + Optional.of(mockHttpClient), + Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT)); + } + + static final class TestingSystemAccessControlContext + implements SystemAccessControlFactory.SystemAccessControlContext + { + private final String trinoVersion; + + public TestingSystemAccessControlContext(String version) + { + this.trinoVersion = version; + } + + @Override + public String getVersion() + { + return this.trinoVersion; + } + + @Override + public OpenTelemetry getOpenTelemetry() + { + return null; + } + + @Override + public Tracer getTracer() + { + return null; + } } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java index 027f851bc3bb9..fe151319fb125 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java @@ -18,11 +18,13 @@ import com.google.common.collect.Streams; import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; import io.trino.plugin.opa.HttpClientUtils.MockResponse; +import io.trino.plugin.opa.TestHelpers.TestingSystemAccessControlContext; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaRoutineName; import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.security.Identity; import io.trino.spi.security.PrincipalType; +import io.trino.spi.security.SystemAccessControlFactory; import io.trino.spi.security.SystemSecurityContext; import io.trino.spi.security.TrinoPrincipal; import org.junit.jupiter.api.Named; @@ -1093,6 +1095,49 @@ public void testCanExecuteTableProcedureFailure( .hasMessageContaining(expectedErrorMessage); } + @Test + public void testRequestContextContentsWithKnownTrinoVersion() + { + testRequestContextContentsForGivenTrinoVersion( + Optional.of(new TestingSystemAccessControlContext("12345.67890")), + "12345.67890"); + } + + @Test + public void testRequestContextContentsWithUnknownTrinoVersion() + { + testRequestContextContentsForGivenTrinoVersion(Optional.empty(), "UNKNOWN"); + } + + private void testRequestContextContentsForGivenTrinoVersion(Optional accessControlContext, String expectedTrinoVersion) + { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> OK_RESPONSE); + OpaAccessControl authorizer = (OpaAccessControl) OpaAccessControlFactory.create( + ImmutableMap.of("opa.policy.uri", OPA_SERVER_URI.toString()), + Optional.of(mockClient), + accessControlContext); + Identity sampleIdentityWithGroups = Identity.forUser("test_user").withGroups(ImmutableSet.of("some_group")).build(); + + authorizer.checkCanExecuteQuery(sampleIdentityWithGroups); + + String expectedRequest = """ + { + "action": { + "operation": "ExecuteQuery" + }, + "context": { + "identity": { + "user": "test_user", + "groups": ["some_group"] + }, + "softwareStack": { + "trinoVersion": "%s" + } + } + }""".formatted(expectedTrinoVersion); + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input"); + } + private static Stream noResourceActionTestCases() { Stream> methods = Stream.of( diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java index a3678e9ea8728..ef3d2c243dac4 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlPermissioningOperations.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.function.Consumer; +import static io.trino.plugin.opa.TestHelpers.SYSTEM_ACCESS_CONTROL_CONTEXT; import static io.trino.plugin.opa.TestHelpers.createMockHttpClient; import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; import static org.assertj.core.api.Assertions.assertThat; @@ -142,6 +143,7 @@ private static OpaAccessControl createAuthorizer(boolean allowPermissioningOpera .put("opa.policy.uri", OPA_SERVER_URI.toString()) .put("opa.allow-permissioning-operations", String.valueOf(allowPermissioningOperations)) .buildOrThrow(), - Optional.of(mockClient)); + Optional.of(mockClient), + Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT)); } } From 10b78fcbc7e1dc2e5b0face2904e47e14e85f6d7 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Wed, 29 Nov 2023 19:44:36 +0000 Subject: [PATCH 04/11] Update README --- plugin/trino-opa/README.md | 58 +++++++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/plugin/trino-opa/README.md b/plugin/trino-opa/README.md index f4724973c7f88..93b668a1b80bf 100644 --- a/plugin/trino-opa/README.md +++ b/plugin/trino-opa/README.md @@ -47,19 +47,43 @@ opa.policy.batched-uri=https://your-opa-endpoint/v1/data/batch ### All configuration entries -| Configuration name | Required | Default | Description | -|--------------------------|:--------:|:-------:|------------------------------------------------------------------------------------------------------------------------------| -| `opa.policy.uri` | Yes | N/A | Endpoint to query OPA | -| `opa.policy.batched-uri` | No | Unset | Endpoint for batch OPA requests | -| `opa.log-requests` | No | `false` | Determines whether requests (URI, headers and entire body) are logged prior to sending them to OPA | -| `opa.log-responses` | No | `false` | Determines whether OPA responses (URI, status code, headers and entire body) are logged | -| `opa.http-client.*` | No | Unset | Additional HTTP client configurations that get passed down. E.g. `opa.http-client.http-proxy` for configuring the HTTP proxy | +| Configuration name | Required | Default | Description | +|--------------------------------------|:--------:|:-------:|----------------------------------------------------------------------------------------------------------------------------------------------------------| +| `opa.policy.uri` | Yes | N/A | Endpoint to query OPA | +| `opa.policy.batched-uri` | No | Unset | Endpoint for batch OPA requests | +| `opa.log-requests` | No | `false` | Determines whether requests (URI, headers and entire body) are logged prior to sending them to OPA | +| `opa.log-responses` | No | `false` | Determines whether OPA responses (URI, status code, headers and entire body) are logged | +| `opa.allow-permissioning-operations` | No | `false` | Determines whether permissioning operations will be allowed. These operations will be allowed or denied based on this setting, no request is sent to OPA | +| `opa.http-client.*` | No | Unset | Additional HTTP client configurations that get passed down. E.g. `opa.http-client.http-proxy` for configuring the HTTP proxy | > When request / response logging is enabled, they will be logged at DEBUG level under the `io.trino.plugin.opa.OpaHttpClient` logger, you will need to update > your log configuration accordingly. > > Be aware that enabling these options will produce very large amounts of logs +##### About permissioning operations + +The following operations are controlled by the `opa.allow-permissioning-operations` setting. If this setting is `true`, these +operations will be allowed; they will otherwise be denied. No request is sent to OPA either way: + +- `GrantSchemaPrivilege` +- `DenySchemaPrivilege` +- `RevokeSchemaPrivilege` +- `GrantTablePrivilege` +- `DenyTablePrivilege` +- `RevokeTablePrivilege` +- `CreateRole` +- `DropRole` +- `GrantRoles` +- `RevokeRoles` + +This is due to the complexity and potential unexpected consequences of having SQL-style grants / roles together with OPA, as per [discussion](https://github.com/trinodb/trino/pull/19532#discussion_r1380776593) +on the initial PR. + +Additionally, users are always allowed to show information about roles (`SHOW ROLES`), regardless of this setting. The following operations are _always_ allowed: +- `ShowRoles` +- `ShowCurrentRoles` +- `ShowRoleGrants` ## OPA queries @@ -77,7 +101,13 @@ A query will contain a `context` and an `action` as its top level fields. #### Query context: -This determines _who_ is performing the operations, and reflects the `SystemSecurityContext` class in Trino. +While the `action` object contains information about _what_ action is being performed, the `context` object +contains all other contextual information about it. The `context` object contains the following fields: +- `identity`: The identity of the user performing the operation, containing the following 2 fields: + - `user` (string): username + - `groups` (array of strings): list of groups this user belongs to +- `softwareStack`: Information about the software stack running in the Trino server, more fields may be added later, currently: + - `trinoVersion` (string): Trino version #### Query action: @@ -86,8 +116,10 @@ This determines _what_ action is being performed and upon what resources, the to - `operation` (string): operation being performed - `resource` (object, nullable): information about the object being operated upon - `targetResource` (object, nullable): information about the _new object_ being created, if applicable -- `grantee` (object, nullable): grantee of a grant operation -- `grantor` (object, nullable): grantor in a grant operation +- `grantee` (object, nullable): grantee of a grant operation. + +Fields that are not applicable for a specific operation (e.g. `targetResource` if not modifying a table/schema/catalog, or `grantee` if not granting +permissions) will be set to null. Any null field will be omitted altogether from the `action` object. #### Examples @@ -99,6 +131,9 @@ Accessing a table will result in a query like the one below: "identity": { "user": "foo", "groups": ["some-group"] + }, + "softwareStack": { + "trinoVersion": "434" } }, "action": { @@ -128,6 +163,9 @@ when renaming a table. "identity": { "user": "foo", "groups": ["some-group"] + }, + "softwareStack": { + "trinoVersion": "434" } }, "action": { From 247d74672f9fa4e4e1a646ca050a0892ab05a5bc Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Thu, 30 Nov 2023 17:41:58 +0000 Subject: [PATCH 05/11] Suggestion to remove parameterized tests --- .../java/io/trino/plugin/opa/TestHelpers.java | 28 ++- .../plugin/opa/TestOpaAccessControl.java | 206 +++++++----------- 2 files changed, 91 insertions(+), 143 deletions(-) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java index 8733f3142eb6f..eb3c6b5e3b107 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -32,6 +32,7 @@ import java.time.Instant; import java.util.Arrays; import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; @@ -105,38 +106,41 @@ public static SystemSecurityContext systemSecurityContextFromIdentity(Identity i return new SystemSecurityContext(identity, new QueryIdGenerator().createNextQueryId(), Instant.now()); } - public abstract static class MethodWrapper { - public abstract boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument); + public abstract static class MethodWrapper { + public abstract boolean isAccessAllowed(OpaAccessControl opaAccessControl); } - public static class ThrowingMethodWrapper extends MethodWrapper { - private final FunctionalHelpers.Consumer3 callable; + public static class ThrowingMethodWrapper extends MethodWrapper { + private final Consumer callable; - public ThrowingMethodWrapper(FunctionalHelpers.Consumer3 callable) { + public ThrowingMethodWrapper(Consumer callable) { this.callable = callable; } @Override - public boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument) { + public boolean isAccessAllowed(OpaAccessControl opaAccessControl) { try { - this.callable.accept(opaAccessControl, systemSecurityContext, argument); + this.callable.accept(opaAccessControl); return true; } catch (AccessDeniedException e) { + if (!e.getMessage().contains("Access Denied")) { + throw new AssertionError("Expected AccessDenied exception to contain 'Access Denied' in the message"); + } return false; } } } - public static class ReturningMethodWrapper extends MethodWrapper { - private final FunctionalHelpers.Function3 callable; + public static class ReturningMethodWrapper extends MethodWrapper { + private final Function callable; - public ReturningMethodWrapper(FunctionalHelpers.Function3 callable) { + public ReturningMethodWrapper(Function callable) { this.callable = callable; } @Override - public boolean isAccessAllowed(OpaAccessControl opaAccessControl, SystemSecurityContext systemSecurityContext, T argument) { - return this.callable.apply(opaAccessControl, systemSecurityContext, argument); + public boolean isAccessAllowed(OpaAccessControl opaAccessControl) { + return this.callable.apply(opaAccessControl); } } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java index fe151319fb125..8c5e90e1f0c55 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java @@ -16,8 +16,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; +import io.trino.plugin.opa.FunctionalHelpers.Pair; import io.trino.plugin.opa.HttpClientUtils.InstrumentedHttpClient; import io.trino.plugin.opa.HttpClientUtils.MockResponse; +import io.trino.plugin.opa.TestHelpers.MethodWrapper; import io.trino.plugin.opa.TestHelpers.TestingSystemAccessControlContext; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaRoutineName; @@ -42,10 +44,13 @@ import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; import static io.trino.plugin.opa.RequestTestUtilities.buildValidatingRequestHandler; +import static io.trino.plugin.opa.TestHelpers.BAD_REQUEST_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.MALFORMED_RESPONSE; import static io.trino.plugin.opa.TestHelpers.NO_ACCESS_RESPONSE; import static io.trino.plugin.opa.TestHelpers.OK_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.SERVER_ERROR_RESPONSE; +import static io.trino.plugin.opa.TestHelpers.UNDEFINED_RESPONSE; import static io.trino.plugin.opa.TestHelpers.createFailingTestCases; -import static io.trino.plugin.opa.TestHelpers.createIllegalResponseTestCases; import static io.trino.plugin.opa.TestHelpers.createMockHttpClient; import static io.trino.plugin.opa.TestHelpers.createOpaAuthorizer; import static io.trino.plugin.opa.TestHelpers.systemSecurityContextFromIdentity; @@ -55,6 +60,9 @@ public class TestOpaAccessControl { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private static final Identity TEST_IDENTITY = Identity.forUser("source-user").withGroups(ImmutableSet.of("some-group")).build(); + private static final SystemSecurityContext TEST_SECURITY_CONTEXT = systemSecurityContextFromIdentity(TEST_IDENTITY); + // The below identity and security ctx would go away if we move all the tests to use their static constant counterparts above private final Identity requestingIdentity = Identity.ofUser("source-user"); private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @@ -71,41 +79,22 @@ public void testResponseHasExtraFields() authorizer.checkCanExecuteQuery(requestingIdentity); } - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#noResourceActionTestCases") - public void testNoResourceAction(String actionName, BiConsumer method) + @Test + public void testNoResourceAction() { - InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - - method.accept(authorizer, requestingIdentity); - String expectedRequest = """ - { - "operation": "%s" - }""".formatted(actionName); - assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + testNoResourceAction("ExecuteQuery", OpaAccessControl::checkCanExecuteQuery); + testNoResourceAction("ReadSystemInformation", OpaAccessControl::checkCanReadSystemInformation); + testNoResourceAction("WriteSystemInformation", OpaAccessControl::checkCanWriteSystemInformation); } - @ParameterizedTest(name = "{index}: {0} - {2}") - @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#noResourceActionFailureTestCases") - public void testNoResourceActionFailure( - String actionName, - BiConsumer method, - MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) + private void testNoResourceAction(String actionName, BiConsumer method) { - InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - - assertThatThrownBy(() -> method.accept(authorizer, requestingIdentity)) - .isInstanceOf(expectedException) - .hasMessageContaining(expectedErrorMessage); - String expectedRequest = """ + Set expectedRequests = ImmutableSet.of(""" { "operation": "%s" - }""".formatted(actionName); - assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); + }""".formatted(actionName)); + TestHelpers.ThrowingMethodWrapper wrappedMethod = new TestHelpers.ThrowingMethodWrapper((accessControl) -> method.accept(accessControl, TEST_IDENTITY)); + assertAccessControlMethodBehaviour(wrappedMethod, expectedRequests); } private static Stream tableResourceTestCases() @@ -356,11 +345,11 @@ private static Stream stringResourceTestCases() OpaAccessControl::checkCanCreateCatalog, OpaAccessControl::checkCanDropCatalog, OpaAccessControl::checkCanShowSchemas); - Stream> actionAndResource = Stream.of( - FunctionalHelpers.Pair.of("SetSystemSessionProperty", "systemSessionProperty"), - FunctionalHelpers.Pair.of("CreateCatalog", "catalog"), - FunctionalHelpers.Pair.of("DropCatalog", "catalog"), - FunctionalHelpers.Pair.of("ShowSchemas", "catalog")); + Stream> actionAndResource = Stream.of( + Pair.of("SetSystemSessionProperty", "systemSessionProperty"), + Pair.of("CreateCatalog", "catalog"), + Pair.of("DropCatalog", "catalog"), + Pair.of("ShowSchemas", "catalog")); return Streams.zip( actionAndResource, methods, @@ -979,38 +968,11 @@ public void testCanSetCatalogSessionPropertyFailure( .hasMessageContaining(expectedErrorMessage); } - private static Stream functionResourceTestCases() - { - Stream> methods = Stream.of( - new TestHelpers.ThrowingMethodWrapper<>(OpaAccessControl::checkCanExecuteProcedure), - new TestHelpers.ThrowingMethodWrapper<>(OpaAccessControl::checkCanCreateFunction), - new TestHelpers.ThrowingMethodWrapper<>(OpaAccessControl::checkCanDropFunction), - new TestHelpers.ReturningMethodWrapper<>(OpaAccessControl::canExecuteFunction), - new TestHelpers.ReturningMethodWrapper<>(OpaAccessControl::canCreateViewWithExecuteFunction)); - Stream actions = Stream.of( - "ExecuteProcedure", - "CreateFunction", - "DropFunction", - "ExecuteFunction", - "CreateViewWithExecuteFunction"); - return Streams.zip(actions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#functionResourceTestCases") - public void testFunctionResourceAction( - String actionName, - TestHelpers.MethodWrapper method) + @Test + public void testFunctionResourceActions() { - InstrumentedHttpClient permissiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - InstrumentedHttpClient restrictiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, NO_ACCESS_RESPONSE)); - CatalogSchemaRoutineName routine = new CatalogSchemaRoutineName("my_catalog", "my_schema", "my_routine_name"); - assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, permissiveClient), requestingSecurityContext, routine)).isTrue(); - - assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, restrictiveClient), requestingSecurityContext, routine)).isFalse(); - - String expectedRequest = """ + String baseRequest = """ { "operation": "%s", "resource": { @@ -1020,51 +982,35 @@ public void testFunctionResourceAction( "functionName": "my_routine_name" } } - }""".formatted(actionName); - assertStringRequestsEqual(ImmutableSet.of(expectedRequest), permissiveClient.getRequests(), "/input/action"); - assertStringRequestsEqual(ImmutableSet.of(expectedRequest), restrictiveClient.getRequests(), "/input/action"); - } - - private static Stream functionResourceIllegalResponseTestCases() - { - return createIllegalResponseTestCases(functionResourceTestCases()); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.TestOpaAccessControl#functionResourceIllegalResponseTestCases") - public void testFunctionResourceIllegalResponses( - String actionName, - TestHelpers.MethodWrapper method, - MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) - { - InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - - CatalogSchemaRoutineName routine = new CatalogSchemaRoutineName("my_catalog", "my_schema", "my_routine_name"); - assertThatThrownBy( - () -> method.isAccessAllowed(authorizer, requestingSecurityContext, routine)) - .isInstanceOf(expectedException) - .hasMessageContaining(expectedErrorMessage); + }"""; + assertAccessControlMethodBehaviour( + new TestHelpers.ThrowingMethodWrapper(authorizer -> authorizer.checkCanExecuteProcedure(TEST_SECURITY_CONTEXT, routine)), + ImmutableSet.of(baseRequest.formatted("ExecuteProcedure"))); + assertAccessControlMethodBehaviour( + new TestHelpers.ThrowingMethodWrapper(authorizer -> authorizer.checkCanCreateFunction(TEST_SECURITY_CONTEXT, routine)), + ImmutableSet.of(baseRequest.formatted("CreateFunction"))); + assertAccessControlMethodBehaviour( + new TestHelpers.ThrowingMethodWrapper(authorizer -> authorizer.checkCanDropFunction(TEST_SECURITY_CONTEXT, routine)), + ImmutableSet.of(baseRequest.formatted("DropFunction"))); + assertAccessControlMethodBehaviour( + new TestHelpers.ReturningMethodWrapper(authorizer -> authorizer.canExecuteFunction(TEST_SECURITY_CONTEXT, routine)), + ImmutableSet.of(baseRequest.formatted("ExecuteFunction"))); + assertAccessControlMethodBehaviour( + new TestHelpers.ReturningMethodWrapper(authorizer -> authorizer.canCreateViewWithExecuteFunction(TEST_SECURITY_CONTEXT, routine)), + ImmutableSet.of(baseRequest.formatted("CreateViewWithExecuteFunction"))); } @Test public void testCanExecuteTableProcedure() { - InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); - authorizer.checkCanExecuteTableProcedure(requestingSecurityContext, table, "my_procedure"); - String expectedRequest = """ { "operation": "ExecuteTableProcedure", "resource": { "table": { - "schemaName": "my_schema", "catalogName": "my_catalog", + "schemaName": "my_schema", "tableName": "my_table" }, "function": { @@ -1072,27 +1018,9 @@ public void testCanExecuteTableProcedure() } } }"""; - assertStringRequestsEqual(ImmutableSet.of(expectedRequest), mockClient.getRequests(), "/input/action"); - } - - @ParameterizedTest(name = "{index}: {0}") - @MethodSource("io.trino.plugin.opa.TestHelpers#allErrorCasesArgumentProvider") - public void testCanExecuteTableProcedureFailure( - MockResponse failureResponse, - Class expectedException, - String expectedErrorMessage) - { - InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - - CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); - assertThatThrownBy( - () -> authorizer.checkCanExecuteTableProcedure( - requestingSecurityContext, - table, - "my_procedure")) - .isInstanceOf(expectedException) - .hasMessageContaining(expectedErrorMessage); + assertAccessControlMethodBehaviour( + new TestHelpers.ThrowingMethodWrapper(authorizer -> authorizer.checkCanExecuteTableProcedure(TEST_SECURITY_CONTEXT, table, "my_procedure")), + ImmutableSet.of(expectedRequest)); } @Test @@ -1138,21 +1066,37 @@ private void testRequestContextContentsForGivenTrinoVersion(Optional noResourceActionTestCases() + private static void assertAccessControlMethodBehaviour(MethodWrapper method, Set expectedRequests) { - Stream> methods = Stream.of( - OpaAccessControl::checkCanExecuteQuery, - OpaAccessControl::checkCanReadSystemInformation, - OpaAccessControl::checkCanWriteSystemInformation); - Stream expectedActions = Stream.of( - "ExecuteQuery", - "ReadSystemInformation", - "WriteSystemInformation"); - return Streams.zip(expectedActions, methods, (action, method) -> Arguments.of(Named.of(action, action), method)); + InstrumentedHttpClient permissiveMockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(TEST_IDENTITY, OK_RESPONSE)); + InstrumentedHttpClient restrictiveMockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(TEST_IDENTITY, NO_ACCESS_RESPONSE)); + + assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, permissiveMockClient))).isTrue(); + assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, restrictiveMockClient))).isFalse(); + assertThat(permissiveMockClient.getRequests()).containsExactlyInAnyOrderElementsOf(restrictiveMockClient.getRequests()); + assertStringRequestsEqual(expectedRequests, permissiveMockClient.getRequests(), "/input/action"); + assertAccessControlMethodThrowsForIllegalResponses(method); } - private static Stream noResourceActionFailureTestCases() + private static void assertAccessControlMethodThrowsForIllegalResponses(MethodWrapper methodToTest) { - return createFailingTestCases(noResourceActionTestCases()); + assertAccessControlMethodThrowsForResponse(methodToTest, UNDEFINED_RESPONSE, OpaQueryException.OpaServerError.PolicyNotFound.class, "did not return a value"); + assertAccessControlMethodThrowsForResponse(methodToTest, BAD_REQUEST_RESPONSE, OpaQueryException.OpaServerError.class, "returned status 400"); + assertAccessControlMethodThrowsForResponse(methodToTest, SERVER_ERROR_RESPONSE, OpaQueryException.OpaServerError.class, "returned status 500"); + assertAccessControlMethodThrowsForResponse(methodToTest, MALFORMED_RESPONSE, OpaQueryException.class, "Failed to deserialize"); + } + + private static void assertAccessControlMethodThrowsForResponse( + MethodWrapper methodToTest, + MockResponse response, + Class expectedException, + String expectedErrorMessage) + { + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(TEST_IDENTITY, response)); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + + assertThatThrownBy(() -> methodToTest.isAccessAllowed(authorizer)) + .isInstanceOf(expectedException) + .hasMessageContaining(expectedErrorMessage); } } From 70a3a1bcc568a0a2e52016a3a5f4e31d558824b8 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Thu, 30 Nov 2023 17:52:26 +0000 Subject: [PATCH 06/11] Code review: use synchronizedList --- .../test/java/io/trino/plugin/opa/HttpClientUtils.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java index 980ce862c735e..1d21b70e7adf6 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java @@ -27,7 +27,8 @@ import java.io.IOException; import java.net.URI; import java.nio.charset.StandardCharsets; -import java.util.LinkedList; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.function.Function; @@ -43,7 +44,7 @@ public static class RecordingHttpProcessor implements TestingHttpClient.Processor { private static final JsonMapper jsonMapper = new JsonMapper(); - private final List requests = new LinkedList<>(); + private final List requests = Collections.synchronizedList(new ArrayList<>()); private final Function handler; private final URI expectedURI; private final String expectedMethod; @@ -58,7 +59,7 @@ public RecordingHttpProcessor(URI expectedURI, String expectedMethod, String exp } @Override - public synchronized Response handle(Request request) + public Response handle(Request request) { if (!requireNonNull(request.getMethod()).equalsIgnoreCase(expectedMethod)) { throw new IllegalArgumentException("Unexpected method: %s".formatted(request.getMethod())); @@ -86,7 +87,7 @@ public synchronized Response handle(Request request) } } - public synchronized List getRequests() + public List getRequests() { return ImmutableList.copyOf(requests); } From 75e7f1e0f02b3c50e017c6029a823839cfffe5bc Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Thu, 30 Nov 2023 17:54:18 +0000 Subject: [PATCH 07/11] Make utility classes final --- .../trino/plugin/opa/FilteringTestHelpers.java | 2 +- .../io/trino/plugin/opa/FunctionalHelpers.java | 17 +---------------- .../io/trino/plugin/opa/HttpClientUtils.java | 4 ++-- .../trino/plugin/opa/RequestTestUtilities.java | 2 +- .../java/io/trino/plugin/opa/TestHelpers.java | 2 +- 5 files changed, 6 insertions(+), 21 deletions(-) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java index 75c9b628aa5e7..a3f24e1554411 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FilteringTestHelpers.java @@ -29,7 +29,7 @@ import static io.trino.plugin.opa.TestHelpers.createIllegalResponseTestCases; -public class FilteringTestHelpers +public final class FilteringTestHelpers { private FilteringTestHelpers() {} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java index e3fbdc2c78a30..25354b7176858 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/FunctionalHelpers.java @@ -13,33 +13,18 @@ */ package io.trino.plugin.opa; -public class FunctionalHelpers +public final class FunctionalHelpers { public interface Consumer3 { void accept(T1 t1, T2 t2, T3 t3); } - public interface Function3 - { - R apply(T1 t1, T2 t2, T3 t3); - } - public interface Consumer4 { void accept(T1 t1, T2 t2, T3 t3, T4 t4); } - public interface Consumer5 - { - void accept(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5); - } - - public interface Consumer6 - { - void accept(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6); - } - public record Pair(T first, U second) { public static Pair of(T first, U second) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java index 1d21b70e7adf6..83ae8c57331f4 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/HttpClientUtils.java @@ -36,7 +36,7 @@ import static com.google.common.net.MediaType.JSON_UTF_8; import static java.util.Objects.requireNonNull; -public class HttpClientUtils +public final class HttpClientUtils { private HttpClientUtils() {} @@ -93,7 +93,7 @@ public List getRequests() } } - public static class InstrumentedHttpClient + public static final class InstrumentedHttpClient extends TestingHttpClient { private final RecordingHttpProcessor httpProcessor; diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java index b15a585501a91..7bf683b5d79cc 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java @@ -28,7 +28,7 @@ import static io.trino.plugin.opa.TestHelpers.SYSTEM_ACCESS_CONTROL_CONTEXT; import static org.assertj.core.api.Assertions.assertThat; -public class RequestTestUtilities +public final class RequestTestUtilities { private RequestTestUtilities() {} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java index eb3c6b5e3b107..d47e2a4280d86 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -39,7 +39,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.net.MediaType.JSON_UTF_8; -public class TestHelpers +public final class TestHelpers { private TestHelpers() {} From fd6d65ce00890a39738407bee549edfd9b213c1c Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Thu, 21 Dec 2023 16:46:12 +0100 Subject: [PATCH 08/11] Bump version --- plugin/trino-opa/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/trino-opa/pom.xml b/plugin/trino-opa/pom.xml index 5b9f3fa36a180..739d4dfe45b48 100644 --- a/plugin/trino-opa/pom.xml +++ b/plugin/trino-opa/pom.xml @@ -4,7 +4,7 @@ io.trino trino-root - 434-SNAPSHOT + 436-SNAPSHOT ../../pom.xml From 701c0e0675a9642a8eea5afc13d35aafac1935e8 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Mon, 18 Dec 2023 10:56:31 +0000 Subject: [PATCH 09/11] Implement row level filtering and masking and add system tests --- plugin/trino-opa/pom.xml | 8 + .../io/trino/plugin/opa/OpaAccessControl.java | 22 ++ .../plugin/opa/OpaAccessControlFactory.java | 4 + .../java/io/trino/plugin/opa/OpaConfig.java | 31 ++ .../trino/plugin/opa/OpaHighLevelClient.java | 47 ++- .../opa/schema/OpaColumnMaskQueryResult.java | 29 ++ .../opa/schema/OpaQueryInputResource.java | 13 +- .../opa/schema/OpaRowFiltersQueryResult.java | 30 ++ .../plugin/opa/schema/OpaViewExpression.java | 40 +++ .../trino/plugin/opa/schema/TrinoColumn.java | 46 +++ .../opa/DistributedQueryRunnerHelper.java | 71 +++++ .../io/trino/plugin/opa/OpaContainer.java | 90 ++++++ .../java/io/trino/plugin/opa/TestHelpers.java | 63 +++- .../plugin/opa/TestOpaAccessControl.java | 296 +++++++++++++++--- ...stOpaAccessControlDataFilteringSystem.java | 286 +++++++++++++++++ .../opa/TestOpaAccessControlFiltering.java | 19 +- .../opa/TestOpaAccessControlSystem.java | 144 ++------- .../TestOpaBatchAccessControlFiltering.java | 26 +- .../io/trino/plugin/opa/TestOpaConfig.java | 6 + .../plugin/opa/TestOpaResponseDecoding.java | 168 +++++++++- 20 files changed, 1250 insertions(+), 189 deletions(-) create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaColumnMaskQueryResult.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaRowFiltersQueryResult.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaViewExpression.java create mode 100644 plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/DistributedQueryRunnerHelper.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java create mode 100644 plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java diff --git a/plugin/trino-opa/pom.xml b/plugin/trino-opa/pom.xml index 739d4dfe45b48..27b1d5526fafb 100644 --- a/plugin/trino-opa/pom.xml +++ b/plugin/trino-opa/pom.xml @@ -121,11 +121,19 @@ test + + io.trino + trino-main + test-jar + test + + io.trino trino-testing test + org.assertj assertj-core diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java index 4da0f63d6c8b8..fafd8a9d8cd16 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControl.java @@ -22,6 +22,7 @@ import io.trino.plugin.opa.schema.OpaQueryInput; import io.trino.plugin.opa.schema.OpaQueryInputAction; import io.trino.plugin.opa.schema.OpaQueryInputResource; +import io.trino.plugin.opa.schema.OpaViewExpression; import io.trino.plugin.opa.schema.TrinoCatalogSessionProperty; import io.trino.plugin.opa.schema.TrinoFunction; import io.trino.plugin.opa.schema.TrinoGrantPrincipal; @@ -40,15 +41,19 @@ import io.trino.spi.security.SystemAccessControl; import io.trino.spi.security.SystemSecurityContext; import io.trino.spi.security.TrinoPrincipal; +import io.trino.spi.security.ViewExpression; +import io.trino.spi.type.Type; import java.security.Principal; import java.util.Collection; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Consumer; +import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.plugin.opa.OpaHighLevelClient.buildQueryInputForSimpleResource; @@ -709,6 +714,23 @@ public void checkCanDropFunction(SystemSecurityContext systemSecurityContext, Ca OpaQueryInputResource.builder().function(TrinoFunction.fromTrinoFunction(functionName)).build()); } + @Override + public List getRowFilters(SystemSecurityContext context, CatalogSchemaTableName tableName) + { + List rowFilterExpressions = opaHighLevelClient.getRowFilterExpressionsFromOpa(buildQueryContext(context), tableName); + return rowFilterExpressions.stream() + .map(expression -> expression.toTrinoViewExpression(tableName.getCatalogName(), tableName.getSchemaTableName().getSchemaName())) + .collect(toImmutableList()); + } + + @Override + public Optional getColumnMask(SystemSecurityContext context, CatalogSchemaTableName tableName, String columnName, Type type) + { + return opaHighLevelClient + .getColumnMaskFromOpa(buildQueryContext(context), tableName, columnName, type) + .map(expression -> expression.toTrinoViewExpression(tableName.getCatalogName(), tableName.getSchemaTableName().getSchemaName())); + } + private void checkTableOperation(SystemSecurityContext context, String actionName, CatalogSchemaTableName table, Consumer deny) { opaHighLevelClient.queryAndEnforce( diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java index e23d2838f9656..6aa6f45dc7739 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaAccessControlFactory.java @@ -22,9 +22,11 @@ import io.airlift.concurrent.BoundedExecutor; import io.airlift.http.client.HttpClient; import io.airlift.json.JsonModule; +import io.trino.plugin.opa.schema.OpaColumnMaskQueryResult; import io.trino.plugin.opa.schema.OpaPluginContext; import io.trino.plugin.opa.schema.OpaQuery; import io.trino.plugin.opa.schema.OpaQueryResult; +import io.trino.plugin.opa.schema.OpaRowFiltersQueryResult; import io.trino.spi.security.SystemAccessControl; import io.trino.spi.security.SystemAccessControlFactory; @@ -71,6 +73,8 @@ protected static SystemAccessControl create(Map config, Optional binder -> { jsonCodecBinder(binder).bindJsonCodec(OpaQuery.class); jsonCodecBinder(binder).bindJsonCodec(OpaQueryResult.class); + jsonCodecBinder(binder).bindJsonCodec(OpaRowFiltersQueryResult.class); + jsonCodecBinder(binder).bindJsonCodec(OpaColumnMaskQueryResult.class); httpClient.ifPresentOrElse( client -> binder.bind(Key.get(HttpClient.class, ForOpa.class)).toInstance(client), () -> httpClientBinder(binder).bindHttpClient("opa", ForOpa.class)); diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java index 675e02b8c7dcc..d215c8c0b046e 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java @@ -28,6 +28,8 @@ public class OpaConfig private boolean logRequests; private boolean logResponses; private boolean allowPermissioningOperations; + private Optional opaRowFiltersUri = Optional.empty(); + private Optional opaColumnMaskingUri = Optional.empty(); @NotNull public URI getOpaUri() @@ -43,6 +45,7 @@ public OpaConfig setOpaUri(@NotNull URI opaUri) return this; } + @NotNull public Optional getOpaBatchUri() { return opaBatchUri; @@ -94,4 +97,32 @@ public OpaConfig setAllowPermissioningOperations(boolean allowPermissioningOpera this.allowPermissioningOperations = allowPermissioningOperations; return this; } + + @NotNull + public Optional getOpaRowFiltersUri() + { + return opaRowFiltersUri; + } + + @Config("opa.policy.row-filters-uri") + @ConfigDescription("URI for fetching row filters - if not set no row filtering will be applied") + public OpaConfig setOpaRowFiltersUri(@NotNull URI opaRowFiltersUri) + { + this.opaRowFiltersUri = Optional.ofNullable(opaRowFiltersUri); + return this; + } + + @NotNull + public Optional getOpaColumnMaskingUri() + { + return opaColumnMaskingUri; + } + + @Config("opa.policy.column-masking-uri") + @ConfigDescription("URI for fetching column masks - if not set no masking will be applied") + public OpaConfig setOpaColumnMaskingUri(URI opaColumnMaskingUri) + { + this.opaColumnMaskingUri = Optional.ofNullable(opaColumnMaskingUri); + return this; + } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java index 81a99ddde5dc4..054e187f97a39 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHighLevelClient.java @@ -13,17 +13,27 @@ */ package io.trino.plugin.opa; +import com.google.common.collect.ImmutableList; import com.google.inject.Inject; import io.airlift.json.JsonCodec; +import io.trino.plugin.opa.schema.OpaColumnMaskQueryResult; import io.trino.plugin.opa.schema.OpaQueryContext; import io.trino.plugin.opa.schema.OpaQueryInput; import io.trino.plugin.opa.schema.OpaQueryInputAction; import io.trino.plugin.opa.schema.OpaQueryInputResource; import io.trino.plugin.opa.schema.OpaQueryResult; +import io.trino.plugin.opa.schema.OpaRowFiltersQueryResult; +import io.trino.plugin.opa.schema.OpaViewExpression; +import io.trino.plugin.opa.schema.TrinoColumn; +import io.trino.plugin.opa.schema.TrinoTable; +import io.trino.spi.connector.CatalogSchemaTableName; import io.trino.spi.security.AccessDeniedException; +import io.trino.spi.type.Type; import java.net.URI; import java.util.Collection; +import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -32,18 +42,28 @@ public class OpaHighLevelClient { private final JsonCodec queryResultCodec; - private final URI opaPolicyUri; + private final JsonCodec rowFiltersQueryResultCodec; + private final JsonCodec columnMaskQueryResultCodec; private final OpaHttpClient opaHttpClient; + private final URI opaPolicyUri; + private final Optional opaRowFiltersUri; + private final Optional opaColumnMaskingUri; @Inject public OpaHighLevelClient( JsonCodec queryResultCodec, + JsonCodec rowFiltersQueryResultCodec, + JsonCodec columnMaskQueryResultCodec, OpaHttpClient opaHttpClient, OpaConfig config) { this.queryResultCodec = requireNonNull(queryResultCodec, "queryResultCodec is null"); + this.rowFiltersQueryResultCodec = requireNonNull(rowFiltersQueryResultCodec, "rowFiltersQueryResultCodec is null"); + this.columnMaskQueryResultCodec = requireNonNull(columnMaskQueryResultCodec, "columnMaskQueryResultCodec is null"); this.opaHttpClient = requireNonNull(opaHttpClient, "opaHttpClient is null"); this.opaPolicyUri = config.getOpaUri(); + this.opaRowFiltersUri = config.getOpaRowFiltersUri(); + this.opaColumnMaskingUri = config.getOpaColumnMaskingUri(); } public boolean queryOpa(OpaQueryInput input) @@ -105,6 +125,31 @@ public Set parallelFilterFromOpa( return opaHttpClient.parallelFilterFromOpa(items, requestBuilder, opaPolicyUri, queryResultCodec); } + public List getRowFilterExpressionsFromOpa(OpaQueryContext context, CatalogSchemaTableName table) + { + OpaQueryInput queryInput = new OpaQueryInput( + context, + OpaQueryInputAction.builder() + .operation("GetRowFilters") + .resource(OpaQueryInputResource.builder().table(new TrinoTable(table)).build()) + .build()); + return opaRowFiltersUri + .map(uri -> opaHttpClient.consumeOpaResponse(opaHttpClient.submitOpaRequest(queryInput, uri, rowFiltersQueryResultCodec)).result()) + .orElse(ImmutableList.of()); + } + + public Optional getColumnMaskFromOpa(OpaQueryContext context, CatalogSchemaTableName table, String columnName, Type type) + { + OpaQueryInput queryInput = new OpaQueryInput( + context, + OpaQueryInputAction.builder() + .operation("GetColumnMask") + .resource(OpaQueryInputResource.builder().column(new TrinoColumn(table, columnName, type)).build()) + .build()); + return opaColumnMaskingUri + .flatMap(uri -> opaHttpClient.consumeOpaResponse(opaHttpClient.submitOpaRequest(queryInput, uri, columnMaskQueryResultCodec)).result()); + } + public static OpaQueryInput buildQueryInputForSimpleResource(OpaQueryContext context, String operation, OpaQueryInputResource resource) { return new OpaQueryInput(context, OpaQueryInputAction.builder().operation(operation).resource(resource).build()); diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaColumnMaskQueryResult.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaColumnMaskQueryResult.java new file mode 100644 index 0000000000000..d19ac46c1c262 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaColumnMaskQueryResult.java @@ -0,0 +1,29 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonProperty; +import jakarta.validation.constraints.NotNull; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record OpaColumnMaskQueryResult(@JsonProperty("decision_id") String decisionId, @NotNull Optional result) +{ + public OpaColumnMaskQueryResult + { + requireNonNull(result, "result is null"); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java index 61820030882ef..a6468fa305a85 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaQueryInputResource.java @@ -27,7 +27,8 @@ public record OpaQueryInputResource( TrinoFunction function, NamedEntity catalog, TrinoSchema schema, - TrinoTable table) + TrinoTable table, + TrinoColumn column) { public record NamedEntity(@NotNull String name) { @@ -51,6 +52,7 @@ public static class Builder private TrinoSchema schema; private TrinoTable table; private TrinoFunction function; + private TrinoColumn column; private Builder() {} @@ -102,6 +104,12 @@ public Builder function(String functionName) return this; } + public Builder column(TrinoColumn column) + { + this.column = column; + return this; + } + public OpaQueryInputResource build() { return new OpaQueryInputResource( @@ -111,7 +119,8 @@ public OpaQueryInputResource build() this.function, this.catalog, this.schema, - this.table); + this.table, + this.column); } } } diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaRowFiltersQueryResult.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaRowFiltersQueryResult.java new file mode 100644 index 0000000000000..98e6f5b5facdd --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaRowFiltersQueryResult.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import jakarta.validation.constraints.NotNull; + +import java.util.List; + +import static java.util.Objects.requireNonNullElse; + +public record OpaRowFiltersQueryResult(@JsonProperty("decision_id") String decisionId, @NotNull List result) +{ + public OpaRowFiltersQueryResult + { + result = ImmutableList.copyOf(requireNonNullElse(result, ImmutableList.of())); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaViewExpression.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaViewExpression.java new file mode 100644 index 0000000000000..5d5bd228aa3f8 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/OpaViewExpression.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import io.trino.spi.security.ViewExpression; +import jakarta.validation.constraints.NotNull; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public record OpaViewExpression(@NotNull String expression, @NotNull Optional identity) +{ + public OpaViewExpression + { + requireNonNull(expression, "expression is null"); + requireNonNull(identity, "identity is null"); + } + + public ViewExpression toTrinoViewExpression(String catalogName, String schemaName) + { + ViewExpression.Builder builder = ViewExpression.builder() + .catalog(catalogName) + .schema(schemaName) + .expression(expression); + identity.ifPresent(builder::identity); + return builder.build(); + } +} diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java new file mode 100644 index 0000000000000..4f21cb04a20b6 --- /dev/null +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java @@ -0,0 +1,46 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa.schema; + +import io.trino.spi.connector.CatalogSchemaTableName; +import io.trino.spi.type.Type; +import jakarta.validation.constraints.NotNull; + +import static java.util.Objects.requireNonNull; + +public record TrinoColumn( + @NotNull String catalogName, + @NotNull String schemaName, + @NotNull String tableName, + @NotNull String columnName, + @NotNull String columnType) +{ + public TrinoColumn + { + requireNonNull(catalogName, "catalogName is null"); + requireNonNull(schemaName, "schemaName is null"); + requireNonNull(tableName, "tableName is null"); + requireNonNull(columnName, "columnName is null"); + requireNonNull(columnType, "columnType is null"); + } + + public TrinoColumn(CatalogSchemaTableName tableName, String columnName, Type type) + { + this(tableName.getCatalogName(), + tableName.getSchemaTableName().getSchemaName(), + tableName.getSchemaTableName().getTableName(), + columnName, + type.getDisplayName()); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/DistributedQueryRunnerHelper.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/DistributedQueryRunnerHelper.java new file mode 100644 index 0000000000000..36a8fa737151e --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/DistributedQueryRunnerHelper.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import io.trino.Session; +import io.trino.spi.security.Identity; +import io.trino.testing.DistributedQueryRunner; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.testing.TestingSession.testSessionBuilder; + +public final class DistributedQueryRunnerHelper +{ + private final DistributedQueryRunner runner; + + private DistributedQueryRunnerHelper(DistributedQueryRunner runner) + { + this.runner = runner; + } + + public static DistributedQueryRunnerHelper withOpaConfig(Map opaConfig) + throws Exception + { + return new DistributedQueryRunnerHelper( + DistributedQueryRunner.builder(testSessionBuilder().build()) + .setSystemAccessControl(new OpaAccessControlFactory().create(opaConfig)) + .setNodeCount(1) + .build()); + } + + public Set querySetOfStrings(String user, String query) + { + return querySetOfStrings(userSession(user), query); + } + + public Set querySetOfStrings(Session session, String query) + { + return runner.execute(session, query).getMaterializedRows().stream().map(row -> row.getField(0) == null ? "" : row.getField(0).toString()).collect(toImmutableSet()); + } + + public DistributedQueryRunner getBaseQueryRunner() + { + return this.runner; + } + + public void teardown() + { + if (this.runner != null) { + this.runner.close(); + } + } + + private static Session userSession(String user) + { + return testSessionBuilder().setIdentity(Identity.ofUser(user)).build(); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java new file mode 100644 index 0000000000000..ddb8c4aaf6fce --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; +import org.testcontainers.lifecycle.Startable; +import org.testcontainers.utility.DockerImageName; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; + +public class OpaContainer + implements Startable +{ + private static final int OPA_PORT = 8181; + private static final String OPA_BASE_PATH = "v1/data/trino/"; + private static final String OPA_POLICY_PUSH_BASE_PATH = "v1/policies/trino"; + + private final GenericContainer container; + private URI resolvedUri; + + public OpaContainer() + { + this.container = new GenericContainer<>(DockerImageName.parse("openpolicyagent/opa:latest-rootless")) + .withCommand("run", "--server", "--addr", ":%d".formatted(OPA_PORT)) + .withExposedPorts(OPA_PORT) + .waitingFor(Wait.forListeningPort()); + } + + @Override + public synchronized void start() + { + this.container.start(); + this.resolvedUri = null; + } + + @Override + public synchronized void stop() + { + this.container.stop(); + this.resolvedUri = null; + } + + public synchronized URI getOpaServerUri() + { + if (!container.isRunning()) { + this.resolvedUri = null; + throw new IllegalStateException("Container is not running"); + } + if (this.resolvedUri == null) { + this.resolvedUri = URI.create(String.format("http://%s:%d/", container.getHost(), container.getMappedPort(OPA_PORT))); + } + return this.resolvedUri; + } + + public URI getOpaUriForPolicyPath(String relativePath) + { + return getOpaServerUri().resolve(OPA_BASE_PATH + relativePath); + } + + public void submitPolicy(String... policyLines) + throws IOException, InterruptedException + { + HttpClient httpClient = HttpClient.newHttpClient(); + HttpResponse policyResponse = + httpClient.send( + HttpRequest.newBuilder(getOpaServerUri().resolve(OPA_POLICY_PUSH_BASE_PATH)) + .PUT(HttpRequest.BodyPublishers.ofString(String.join("\n", policyLines))) + .header("Content-Type", "text/plain").build(), + HttpResponse.BodyHandlers.ofString()); + if (policyResponse.statusCode() != 200) { + throw new RuntimeException("Failed to submit policy: %s".formatted(policyResponse.body())); + } + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java index d47e2a4280d86..ea3b36a5710cb 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -16,6 +16,7 @@ import com.fasterxml.jackson.databind.JsonNode; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; +import io.airlift.configuration.ConfigurationMetadata; import io.opentelemetry.api.OpenTelemetry; import io.opentelemetry.api.trace.Tracer; import io.trino.execution.QueryIdGenerator; @@ -28,9 +29,11 @@ import org.junit.jupiter.api.Named; import org.junit.jupiter.params.provider.Arguments; +import java.lang.reflect.InvocationTargetException; import java.net.URI; import java.time.Instant; import java.util.Arrays; +import java.util.Map; import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; @@ -149,20 +152,60 @@ public static InstrumentedHttpClient createMockHttpClient(URI expectedUri, Funct return new InstrumentedHttpClient(expectedUri, "POST", JSON_UTF_8.toString(), handler); } - public static OpaAccessControl createOpaAuthorizer(URI opaUri, InstrumentedHttpClient mockHttpClient) + public static OpaAccessControl createOpaAuthorizer(Map config, InstrumentedHttpClient mockHttpClient) { - return (OpaAccessControl) OpaAccessControlFactory.create(ImmutableMap.of("opa.policy.uri", opaUri.toString()), Optional.of(mockHttpClient), Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT)); + return (OpaAccessControl) OpaAccessControlFactory.create(config, Optional.of(mockHttpClient), Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT)); } - public static OpaAccessControl createOpaAuthorizer(URI opaUri, URI opaBatchUri, InstrumentedHttpClient mockHttpClient) + public static final class OpaConfigBuilder { - return (OpaAccessControl) OpaAccessControlFactory.create( - ImmutableMap.builder() - .put("opa.policy.uri", opaUri.toString()) - .put("opa.policy.batched-uri", opaBatchUri.toString()) - .buildOrThrow(), - Optional.of(mockHttpClient), - Optional.of(SYSTEM_ACCESS_CONTROL_CONTEXT)); + private final OpaConfig config = new OpaConfig(); + + public OpaConfigBuilder withBasePolicy(URI basePolicy) + { + config.setOpaUri(basePolicy); + return this; + } + + public OpaConfigBuilder withBatchPolicy(URI batchPolicy) + { + config.setOpaBatchUri(batchPolicy); + return this; + } + + public OpaConfigBuilder withRowFiltersPolicy(URI rowFiltersPolicy) + { + config.setOpaRowFiltersUri(rowFiltersPolicy); + return this; + } + + public OpaConfigBuilder withColumnMaskingPolicy(URI columnMaskingPolicy) + { + config.setOpaColumnMaskingUri(columnMaskingPolicy); + return this; + } + + public Map buildConfig() + { + ConfigurationMetadata metadata = ConfigurationMetadata.getValidConfigurationMetadata(OpaConfig.class); + ImmutableMap.Builder opaConfigBuilder = ImmutableMap.builder(); + try { + for (ConfigurationMetadata.AttributeMetadata attribute : metadata.getAttributes().values()) { + convertPropertyToString(attribute.getGetter().invoke(config)).ifPresent( + propertyValue -> opaConfigBuilder.put(attribute.getInjectionPoint().getProperty(), propertyValue)); + } + } catch (InvocationTargetException|IllegalAccessException e) { + throw new AssertionError("Failed to build config map", e); + } + return opaConfigBuilder.buildOrThrow(); + } + + private static Optional convertPropertyToString(Object value) { + if (value instanceof Optional optionalValue) { + return optionalValue.map(Object::toString); + } + return Optional.ofNullable(value).map(Object::toString); + } } static final class TestingSystemAccessControlContext diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java index 8c5e90e1f0c55..45bf09b8e41fe 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java @@ -13,6 +13,7 @@ */ package io.trino.plugin.opa; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Streams; @@ -21,6 +22,7 @@ import io.trino.plugin.opa.HttpClientUtils.MockResponse; import io.trino.plugin.opa.TestHelpers.MethodWrapper; import io.trino.plugin.opa.TestHelpers.TestingSystemAccessControlContext; +import io.trino.plugin.opa.schema.OpaViewExpression; import io.trino.spi.connector.CatalogSchemaName; import io.trino.spi.connector.CatalogSchemaRoutineName; import io.trino.spi.connector.CatalogSchemaTableName; @@ -29,6 +31,8 @@ import io.trino.spi.security.SystemAccessControlFactory; import io.trino.spi.security.SystemSecurityContext; import io.trino.spi.security.TrinoPrincipal; +import io.trino.spi.security.ViewExpression; +import io.trino.spi.type.VarcharType; import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -36,10 +40,12 @@ import org.junit.jupiter.params.provider.MethodSource; import java.net.URI; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.BiConsumer; +import java.util.function.Consumer; import java.util.stream.Stream; import static io.trino.plugin.opa.RequestTestUtilities.assertStringRequestsEqual; @@ -60,8 +66,11 @@ public class TestOpaAccessControl { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private static final URI OPA_SERVER_ROW_FILTERING_URI = URI.create("http://my-row-filtering-uri"); + private static final URI OPA_SERVER_COLUMN_MASK_URI = URI.create("http://my-column-masking-uri"); private static final Identity TEST_IDENTITY = Identity.forUser("source-user").withGroups(ImmutableSet.of("some-group")).build(); private static final SystemSecurityContext TEST_SECURITY_CONTEXT = systemSecurityContextFromIdentity(TEST_IDENTITY); + private static final Map OPA_CONFIG_WITH_ONLY_ALLOW = new TestHelpers.OpaConfigBuilder().withBasePolicy(OPA_SERVER_URI).buildConfig(); // The below identity and security ctx would go away if we move all the tests to use their static constant counterparts above private final Identity requestingIdentity = Identity.ofUser("source-user"); private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @@ -75,7 +84,7 @@ public void testResponseHasExtraFields() "decision_id": "foo", "some_debug_info": {"test": ""} }""")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); authorizer.checkCanExecuteQuery(requestingIdentity); } @@ -145,7 +154,7 @@ public void testTableResourceActions( FunctionalHelpers.Consumer3 callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); callable.accept( authorizer, @@ -182,7 +191,7 @@ public void testTableResourceFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> method.accept( @@ -215,7 +224,7 @@ public void testTableWithPropertiesActions( FunctionalHelpers.Consumer4 callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); Map> properties = ImmutableMap.>builder() @@ -261,7 +270,7 @@ public void testTableWithPropertiesActionFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> method.accept( @@ -291,7 +300,7 @@ public void testIdentityResourceActions( FunctionalHelpers.Consumer3 callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); Identity dummyIdentity = Identity.forUser("dummy-user") .withGroups(ImmutableSet.of("some-group")) @@ -327,7 +336,7 @@ public void testIdentityResourceActionsFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> method.accept( @@ -364,7 +373,7 @@ public void testStringResourceAction( FunctionalHelpers.Consumer3 callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); callable.accept(authorizer, requestingSecurityContext, "resource_name"); @@ -397,7 +406,7 @@ public void testStringResourceActionsFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> method.accept( @@ -412,7 +421,7 @@ public void testStringResourceActionsFailure( public void testCanImpersonateUser() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user"); @@ -437,7 +446,7 @@ public void testCanImpersonateUserFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> authorizer.checkCanImpersonateUser(requestingIdentity, "some_other_user")) @@ -449,11 +458,11 @@ public void testCanImpersonateUserFailure( public void testCanAccessCatalog() { InstrumentedHttpClient permissiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl permissiveAuthorizer = createOpaAuthorizer(OPA_SERVER_URI, permissiveClient); + OpaAccessControl permissiveAuthorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, permissiveClient); assertThat(permissiveAuthorizer.canAccessCatalog(requestingSecurityContext, "test_catalog")).isTrue(); InstrumentedHttpClient restrictiveClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, NO_ACCESS_RESPONSE)); - OpaAccessControl restrictiveAuthorizer = createOpaAuthorizer(OPA_SERVER_URI, restrictiveClient); + OpaAccessControl restrictiveAuthorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, restrictiveClient); assertThat(restrictiveAuthorizer.canAccessCatalog(requestingSecurityContext, "test_catalog")).isFalse(); String expectedRequest = """ @@ -477,7 +486,7 @@ public void testCanAccessCatalogIllegalResponses( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> authorizer.canAccessCatalog(requestingSecurityContext, "my_catalog")) @@ -507,7 +516,7 @@ public void testSchemaResourceActions( FunctionalHelpers.Consumer3 callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); callable.accept(authorizer, requestingSecurityContext, new CatalogSchemaName("my_catalog", "my_schema")); @@ -540,7 +549,7 @@ public void testSchemaResourceActionsFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> method.accept( @@ -555,7 +564,7 @@ public void testSchemaResourceActionsFailure( public void testCreateSchema() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); authorizer.checkCanCreateSchema(requestingSecurityContext, schema, ImmutableMap.of("some_key", "some_value")); @@ -600,7 +609,7 @@ public void testCreateSchemaFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> authorizer.checkCanCreateSchema( @@ -615,7 +624,7 @@ public void testCreateSchemaFailure( public void testCanRenameSchema() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaName sourceSchema = new CatalogSchemaName("my_catalog", "my_schema"); authorizer.checkCanRenameSchema(requestingSecurityContext, sourceSchema, "new_schema_name"); @@ -648,7 +657,7 @@ public void testCanRenameSchemaFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> authorizer.checkCanRenameSchema( @@ -679,7 +688,7 @@ public void testRenameTableActions( FunctionalHelpers.Consumer4 method) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName sourceTable = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); CatalogSchemaTableName targetTable = new CatalogSchemaTableName("my_catalog", "new_schema_name", "new_table_name"); @@ -723,7 +732,7 @@ public void testRenameTableFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName sourceTable = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); CatalogSchemaTableName targetTable = new CatalogSchemaTableName("my_catalog", "new_schema_name", "new_table_name"); @@ -741,7 +750,7 @@ public void testRenameTableFailure( public void testCanSetSchemaAuthorization() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); @@ -773,7 +782,7 @@ public void testCanSetSchemaAuthorizationFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaName schema = new CatalogSchemaName("my_catalog", "my_schema"); assertThatThrownBy( @@ -803,7 +812,7 @@ public void testCanSetTableAuthorization( FunctionalHelpers.Consumer4 method) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); @@ -843,7 +852,7 @@ public void testCanSetTableAuthorizationFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); @@ -877,7 +886,7 @@ public void testTableColumnOperations( FunctionalHelpers.Consumer4> method) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); Set columns = ImmutableSet.of("my_column"); @@ -915,7 +924,7 @@ public void testTableColumnOperationsFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); CatalogSchemaTableName table = new CatalogSchemaTableName("my_catalog", "my_schema", "my_table"); Set columns = ImmutableSet.of("my_column"); @@ -930,7 +939,7 @@ public void testTableColumnOperationsFailure( public void testCanSetCatalogSessionProperty() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, OK_RESPONSE)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); authorizer.checkCanSetCatalogSessionProperty( requestingSecurityContext, "my_catalog", "my_property"); @@ -957,7 +966,7 @@ public void testCanSetCatalogSessionPropertyFailure( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, mockClient); assertThatThrownBy( () -> authorizer.checkCanSetCatalogSessionProperty( @@ -1066,19 +1075,224 @@ private void testRequestContextContentsForGivenTrinoVersion(Optional authorizer.getRowFilters(TEST_SECURITY_CONTEXT, tableName)); + + // Also test a valid JSON response, but containing invalid fields for a row filters request + String validJsonButIllegalSchemaResponseContents = """ + { + "result": ["some-expr"] + }"""; + assertAccessControlMethodThrowsForResponse( + authorizer -> authorizer.getRowFilters(TEST_SECURITY_CONTEXT, tableName), + new MockResponse(validJsonButIllegalSchemaResponseContents, 200), + OpaQueryException.class, + "Failed to deserialize"); + } + + @Test + public void testGetRowFilters() + { + // This example is a bit strange - an undefined policy would in most cases + // result in an access denied situation. However, since this is row-level-filtering + // we will accept this as meaning there are no known filters to be applied. + testGetRowFilters("{}", ImmutableList.of()); + + String noExpressionsResponse = """ + { + "result": [] + }"""; + testGetRowFilters(noExpressionsResponse, ImmutableList.of()); + + String singleExpressionResponse = """ + { + "result": [ + {"expression": "expr1"} + ] + }"""; + testGetRowFilters( + singleExpressionResponse, + ImmutableList.of(new OpaViewExpression("expr1", Optional.empty()))); + + String multipleExpressionsAndIdentitiesResponse = """ + { + "result": [ + {"expression": "expr1"}, + {"expression": "expr2", "identity": "expr2_identity"}, + {"expression": "expr3", "identity": "expr3_identity"} + ] + }"""; + testGetRowFilters( + multipleExpressionsAndIdentitiesResponse, + ImmutableList.builder() + .add(new OpaViewExpression("expr1", Optional.empty())) + .add(new OpaViewExpression("expr2", Optional.of("expr2_identity"))) + .add(new OpaViewExpression("expr3", Optional.of("expr3_identity"))) + .build()); + } + + private void testGetRowFilters(String responseContent, List expectedExpressions) + { + InstrumentedHttpClient httpClient = createMockHttpClient(OPA_SERVER_ROW_FILTERING_URI, buildValidatingRequestHandler(TEST_IDENTITY, new MockResponse(responseContent, 200))); + OpaAccessControl authorizer = createOpaAuthorizer( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_SERVER_URI) + .withRowFiltersPolicy(OPA_SERVER_ROW_FILTERING_URI) + .buildConfig(), + httpClient); + CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); + + List result = authorizer.getRowFilters(TEST_SECURITY_CONTEXT, tableName); + assertThat(result).allSatisfy(expression -> { + assertThat(expression.getCatalog()).contains("some_catalog"); + assertThat(expression.getSchema()).contains("some_schema"); + }); + assertThat(result).map( + viewExpression -> new OpaViewExpression( + viewExpression.getExpression(), + viewExpression.getSecurityIdentity())) + .containsExactlyInAnyOrderElementsOf(expectedExpressions); + + String expectedRequest = """ + { + "operation": "GetRowFilters", + "resource": { + "table": { + "catalogName": "some_catalog", + "schemaName": "some_schema", + "tableName": "some_table" + } + } + }"""; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), httpClient.getRequests(), "/input/action"); + } + + @Test + public void testGetRowFiltersDoesNothingIfNotConfigured() + { + InstrumentedHttpClient httpClient = createMockHttpClient(OPA_SERVER_ROW_FILTERING_URI, request -> {throw new AssertionError("Should not have been called");}); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, httpClient); + CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); + + List result = authorizer.getRowFilters(TEST_SECURITY_CONTEXT, tableName); + assertThat(result).isEmpty(); + assertThat(httpClient.getRequests()).isEmpty(); + } + + @Test + public void testGetColumnMaskThrowsForIllegalResponse() + { + CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); + assertAccessControlMethodThrowsForIllegalResponses(authorizer -> authorizer.getColumnMask(TEST_SECURITY_CONTEXT, tableName, "some_column", VarcharType.VARCHAR)); + + // Also test a valid JSON response, but containing invalid fields for a row filters request + String validJsonButIllegalSchemaResponseContents = """ + { + "result": {"expression": {"foo": "bar"}} + }"""; + assertAccessControlMethodThrowsForResponse( + authorizer -> authorizer.getColumnMask(TEST_SECURITY_CONTEXT, tableName, "some_column", VarcharType.VARCHAR), + new MockResponse(validJsonButIllegalSchemaResponseContents, 200), + OpaQueryException.class, + "Failed to deserialize"); + } + + @Test + public void testGetColumnMask() + { + // Similar note to the test for row level filtering: + // This example is a bit strange - an undefined policy would in most cases + // result in an access denied situation. However, since this is column masking, + // we will accept this as meaning there are no masks to be applied. + testGetColumnMask("{}", Optional.empty()); + + String nullResponse = """ + { + "result": null + }"""; + testGetColumnMask(nullResponse, Optional.empty()); + + String expressionWithoutIdentityResponse = """ + { + "result": {"expression": "expr1"} + }"""; + testGetColumnMask( + expressionWithoutIdentityResponse, + Optional.of(new OpaViewExpression("expr1", Optional.empty()))); + + String expressionWithIdentityResponse = """ + { + "result": {"expression": "expr1", "identity": "some_identity"} + }"""; + testGetColumnMask( + expressionWithIdentityResponse, + Optional.of(new OpaViewExpression("expr1", Optional.of("some_identity")))); + } + + private void testGetColumnMask(String responseContent, Optional expectedExpression) + { + InstrumentedHttpClient httpClient = createMockHttpClient(OPA_SERVER_COLUMN_MASK_URI, buildValidatingRequestHandler(TEST_IDENTITY, new MockResponse(responseContent, 200))); + OpaAccessControl authorizer = createOpaAuthorizer( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_SERVER_URI) + .withColumnMaskingPolicy(OPA_SERVER_COLUMN_MASK_URI) + .buildConfig(), + httpClient); + CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); + + Optional result = authorizer.getColumnMask(TEST_SECURITY_CONTEXT, tableName, "some_column", VarcharType.VARCHAR); + + assertThat(result.isEmpty()).isEqualTo(expectedExpression.isEmpty()); + assertThat(result.map(viewExpression -> { + assertThat(viewExpression.getCatalog()).contains("some_catalog"); + assertThat(viewExpression.getSchema()).contains("some_schema"); + return new OpaViewExpression(viewExpression.getExpression(), viewExpression.getSecurityIdentity()); + })).isEqualTo(expectedExpression); + + String expectedRequest = """ + { + "operation": "GetColumnMask", + "resource": { + "column": { + "catalogName": "some_catalog", + "schemaName": "some_schema", + "tableName": "some_table", + "columnName": "some_column", + "columnType": "varchar" + } + } + }"""; + assertStringRequestsEqual(ImmutableSet.of(expectedRequest), httpClient.getRequests(), "/input/action"); + } + + @Test + public void testGetColumnMaskDoesNothingIfNotConfigured() + { + InstrumentedHttpClient httpClient = createMockHttpClient(OPA_SERVER_COLUMN_MASK_URI, request -> {throw new AssertionError("Should not have been called");}); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, httpClient); + CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); + + Optional result = authorizer.getColumnMask(TEST_SECURITY_CONTEXT, tableName, "some_column", VarcharType.VARCHAR); + assertThat(result).isEmpty(); + assertThat(httpClient.getRequests()).isEmpty(); + } + private static void assertAccessControlMethodBehaviour(MethodWrapper method, Set expectedRequests) { InstrumentedHttpClient permissiveMockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(TEST_IDENTITY, OK_RESPONSE)); InstrumentedHttpClient restrictiveMockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(TEST_IDENTITY, NO_ACCESS_RESPONSE)); - assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, permissiveMockClient))).isTrue(); - assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_SERVER_URI, restrictiveMockClient))).isFalse(); + assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, permissiveMockClient))).isTrue(); + assertThat(method.isAccessAllowed(createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, restrictiveMockClient))).isFalse(); assertThat(permissiveMockClient.getRequests()).containsExactlyInAnyOrderElementsOf(restrictiveMockClient.getRequests()); assertStringRequestsEqual(expectedRequests, permissiveMockClient.getRequests(), "/input/action"); - assertAccessControlMethodThrowsForIllegalResponses(method); + assertAccessControlMethodThrowsForIllegalResponses(method::isAccessAllowed); } - private static void assertAccessControlMethodThrowsForIllegalResponses(MethodWrapper methodToTest) + private static void assertAccessControlMethodThrowsForIllegalResponses(Consumer methodToTest) { assertAccessControlMethodThrowsForResponse(methodToTest, UNDEFINED_RESPONSE, OpaQueryException.OpaServerError.PolicyNotFound.class, "did not return a value"); assertAccessControlMethodThrowsForResponse(methodToTest, BAD_REQUEST_RESPONSE, OpaQueryException.OpaServerError.class, "returned status 400"); @@ -1087,15 +1301,21 @@ private static void assertAccessControlMethodThrowsForIllegalResponses(MethodWra } private static void assertAccessControlMethodThrowsForResponse( - MethodWrapper methodToTest, + Consumer methodToTest, MockResponse response, Class expectedException, String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(TEST_IDENTITY, response)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); - - assertThatThrownBy(() -> methodToTest.isAccessAllowed(authorizer)) + OpaAccessControl authorizer = createOpaAuthorizer( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_SERVER_URI) + .withRowFiltersPolicy(OPA_SERVER_URI) + .withColumnMaskingPolicy(OPA_SERVER_URI) + .buildConfig(), + mockClient); + + assertThatThrownBy(() -> methodToTest.accept(authorizer)) .isInstanceOf(expectedException) .hasMessageContaining(expectedErrorMessage); } diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java new file mode 100644 index 0000000000000..1cced9f230a8c --- /dev/null +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java @@ -0,0 +1,286 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.opa; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorPlugin; +import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.VarcharType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; + +@Testcontainers +@TestInstance(PER_CLASS) +public class TestOpaAccessControlDataFilteringSystem +{ + @Container + private static final OpaContainer OPA_CONTAINER = new OpaContainer(); + private static final String OPA_ALLOW_POLICY_NAME = "allow"; + private static final String OPA_ROW_LEVEL_FILTERING_POLICY_NAME = "rowFilters"; + private static final String OPA_COLUMN_MASKING_POLICY_NAME = "columnMask"; + private static final String SAMPLE_ROW_LEVEL_FILTERING_POLICY = """ + package trino + import future.keywords.in + import future.keywords.if + import future.keywords.contains + + default allow := true + + table_resource := input.action.resource.table + is_admin { + input.context.identity.user == "admin" + } + + rowFilters contains {"expression": "user_type <> 'customer'"} if { + not is_admin + table_resource.catalogName == "sample_catalog" + table_resource.schemaName == "sample_schema" + table_resource.tableName == "restricted_table" + }"""; + private static final String SAMPLE_COLUMN_MASKING_POLICY = """ + package trino + import future.keywords.in + import future.keywords.if + import future.keywords.contains + + default allow := true + + column_resource := input.action.resource.column + is_admin { + input.context.identity.user == "admin" + } + + columnMask := {"expression": "NULL"} if { + not is_admin + column_resource.catalogName == "sample_catalog" + column_resource.schemaName == "sample_schema" + column_resource.tableName == "restricted_table" + column_resource.columnName == "user_phone" + } + + columnMask := {"expression": "'****' || substring(user_name, -3)"} if { + not is_admin + column_resource.catalogName == "sample_catalog" + column_resource.schemaName == "sample_schema" + column_resource.tableName == "restricted_table" + column_resource.columnName == "user_name" + } + """; + + private static final Set DUMMY_CUSTOMERS_IN_TABLE = ImmutableSet.of("customer_one", "customer_two"); + private static final Set DUMMY_INTERNAL_USERS_IN_TABLE = ImmutableSet.of("some_internal_user"); + private static final Set ALL_DUMMY_USERS_IN_TABLE = ImmutableSet.builder() + .addAll(DUMMY_INTERNAL_USERS_IN_TABLE) + .addAll(DUMMY_CUSTOMERS_IN_TABLE) + .build(); + + private DistributedQueryRunnerHelper runner; + + + @AfterEach + public void teardown() + { + if (runner != null) { + runner.teardown(); + } + } + + @Test + public void testRowFilteringEnabled() + throws Exception + { + setupTrinoWithOpa( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .withRowFiltersPolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ROW_LEVEL_FILTERING_POLICY_NAME)) + .buildConfig()); + OPA_CONTAINER.submitPolicy(SAMPLE_ROW_LEVEL_FILTERING_POLICY); + String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table"; + String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table"; + assertResultsForUser("admin", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("admin", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + + assertResultsForUser("bob", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("bob", restrictedTableQuery, DUMMY_INTERNAL_USERS_IN_TABLE); + } + + @Test + public void testRowFilteringDisabledDoesNothing() + throws Exception + { + setupTrinoWithOpa( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .buildConfig()); + OPA_CONTAINER.submitPolicy(SAMPLE_ROW_LEVEL_FILTERING_POLICY); + String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table"; + String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table"; + assertResultsForUser("admin", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("admin", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + + assertResultsForUser("bob", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("bob", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + } + + @Test + public void testColumnMasking() + throws Exception + { + setupTrinoWithOpa( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .withColumnMaskingPolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_COLUMN_MASKING_POLICY_NAME)) + .buildConfig()); + OPA_CONTAINER.submitPolicy(SAMPLE_COLUMN_MASKING_POLICY); + + String userNamesInUnrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table"; + String userNamesInRestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table"; + // No masking is applied to the unrestricted table + assertResultsForUser("admin", userNamesInUnrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("bob", userNamesInUnrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + + // No masking is applied for "admin" even in the restricted table + assertResultsForUser("admin", userNamesInRestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + + // "bob" can only see the last 3 characters of user names for the restricted table + Set expectedMaskedUserNames = ALL_DUMMY_USERS_IN_TABLE.stream().map(userName -> "****" + userName.substring(userName.length() - 3)).collect(toImmutableSet()); + assertResultsForUser("bob", userNamesInRestrictedTableQuery, expectedMaskedUserNames); + + String phoneNumbersInUnrestrictedTableQuery = "SELECT user_phone FROM sample_catalog.sample_schema.unrestricted_table"; + String phoneNumbersInRestrictedTableQuery = "SELECT user_phone FROM sample_catalog.sample_schema.restricted_table"; + + // Phone numbers are derived by hashing the name of the user + Set allExpectedPhoneNumbers = ALL_DUMMY_USERS_IN_TABLE.stream().map(userName -> String.valueOf(userName.hashCode())).collect(toImmutableSet()); + + // No masking is applied to the unrestricted table + assertResultsForUser("admin", phoneNumbersInUnrestrictedTableQuery, allExpectedPhoneNumbers); + assertResultsForUser("bob", phoneNumbersInUnrestrictedTableQuery, allExpectedPhoneNumbers); + + // No masking is applied for "admin" even in the restricted table + assertResultsForUser("admin", phoneNumbersInRestrictedTableQuery, allExpectedPhoneNumbers); + // "bob" cannot see any phone numbers in the restricted table + assertResultsForUser("bob", phoneNumbersInRestrictedTableQuery, ImmutableSet.of("")); + } + + @Test + public void testColumnMaskingDisabledDoesNothing() + throws Exception + { + setupTrinoWithOpa( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .buildConfig()); + OPA_CONTAINER.submitPolicy(SAMPLE_COLUMN_MASKING_POLICY); + String restrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table"; + String unrestrictedTableQuery = "SELECT user_name FROM sample_catalog.sample_schema.unrestricted_table"; + assertResultsForUser("admin", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("admin", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + + assertResultsForUser("bob", unrestrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("bob", restrictedTableQuery, ALL_DUMMY_USERS_IN_TABLE); + } + + @Test + public void testColumnMaskingAndRowFiltering() + throws Exception + { + setupTrinoWithOpa( + new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .withColumnMaskingPolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_COLUMN_MASKING_POLICY_NAME)) + .withRowFiltersPolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ROW_LEVEL_FILTERING_POLICY_NAME)) + .buildConfig()); + // Simpler policy than the previous tests: + // Admin has no restrictions + // Any other user can only see rows where "user_type" is not "customer" + // And cannot see any data for field "user_name" + String policy = """ + package trino + import future.keywords.in + import future.keywords.if + import future.keywords.contains + + default allow := true + + is_admin { + input.context.identity.user == "admin" + } + + table_resource := input.action.resource.table + column_resource := input.action.resource.column + + rowFilters contains {"expression": "user_type <> 'customer'"} if { + not is_admin + } + columnMask := {"expression": "NULL"} if { + not is_admin + column_resource.columnName == "user_name" + }"""; + OPA_CONTAINER.submitPolicy(policy); + + String selectUserNameData = "SELECT user_name FROM sample_catalog.sample_schema.restricted_table"; + String selectUserTypeData = "SELECT user_type FROM sample_catalog.sample_schema.restricted_table"; + Set expectedUserTypes = ImmutableSet.of("internal_user", "customer"); + + assertResultsForUser("admin", selectUserNameData, ALL_DUMMY_USERS_IN_TABLE); + assertResultsForUser("admin", selectUserTypeData, expectedUserTypes); + + assertResultsForUser("bob", selectUserNameData, ImmutableSet.of("")); + assertResultsForUser("bob", selectUserTypeData, ImmutableSet.of("internal_user")); + } + + private void assertResultsForUser(String asUser, String query, Set expectedResults) + { + assertThat(runner.querySetOfStrings(asUser, query)).containsExactlyInAnyOrderElementsOf(expectedResults); + } + + private void setupTrinoWithOpa(Map opaConfig) + throws Exception + { + this.runner = DistributedQueryRunnerHelper.withOpaConfig(opaConfig); + MockConnectorFactory connectorFactory = MockConnectorFactory.builder() + .withListSchemaNames(session -> ImmutableList.of("sample_schema")) + .withListTables((session, schema) -> ImmutableList.builder() + .add("restricted_table") + .add("unrestricted_table") + .build()) + .withGetColumns(schemaTableName -> ImmutableList.builder() + .add(ColumnMetadata.builder().setName("user_type").setType(VarcharType.VARCHAR).build()) + .add(ColumnMetadata.builder().setName("user_name").setType(VarcharType.VARCHAR).build()) + .add(ColumnMetadata.builder().setName("user_phone").setType(IntegerType.INTEGER).build()) + .build()) + .withData(schemaTableName -> ImmutableList.>builder() + .addAll(DUMMY_CUSTOMERS_IN_TABLE.stream().map(customer -> ImmutableList.of("customer", customer, customer.hashCode())).collect(toImmutableSet())) + .addAll(DUMMY_INTERNAL_USERS_IN_TABLE.stream().map(internalUser -> ImmutableList.of("internal_user", internalUser, internalUser.hashCode())).collect(toImmutableSet())) + .build()) + .build(); + + runner.getBaseQueryRunner().installPlugin(new MockConnectorPlugin(connectorFactory)); + runner.getBaseQueryRunner().createCatalog("sample_catalog", "mock"); + } +} diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java index bd97a034e52b6..3ad662efceca5 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlFiltering.java @@ -49,6 +49,7 @@ public class TestOpaAccessControlFiltering { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); + private static final Map OPA_CONFIG = new TestHelpers.OpaConfigBuilder().withBasePolicy(OPA_SERVER_URI).buildConfig(); private final Identity requestingIdentity = Identity.ofUser("source-user"); private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @@ -56,7 +57,7 @@ public class TestOpaAccessControlFiltering public void testFilterViewQueryOwnedBy() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/user/user", "user-one")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Identity userOne = Identity.ofUser("user-one"); Identity userTwo = Identity.ofUser("user-two"); @@ -98,7 +99,7 @@ public void testFilterViewQueryOwnedBy() public void testFilterCatalogs() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/catalog/name", "catalog_two")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Set requestedCatalogs = ImmutableSet.of("catalog_one", "catalog_two"); Set result = authorizer.filterCatalogs( @@ -135,7 +136,7 @@ public void testFilterCatalogs() public void testFilterSchemas() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/schema/schemaName", "schema_one")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Set requestedSchemas = ImmutableSet.of("schema_one", "schema_two"); @@ -171,7 +172,7 @@ public void testFilterTables() .add(new SchemaTableName("schema_two", "table_two")) .build(); InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/table/tableName", "table_one")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Set result = authorizer.filterTables(requestingSecurityContext, "my_catalog", tables); assertThat(result).containsExactlyInAnyOrderElementsOf(tables.stream().filter(table -> table.getTableName().equals("table_one")).collect(toImmutableSet())); @@ -212,7 +213,7 @@ public void testFilterColumns() .build(); InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/table/columns/0", columnsToAllow)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Map> result = authorizer.filterColumns(requestingSecurityContext, "my_catalog", requestedColumns); @@ -245,7 +246,7 @@ public void testFilterColumns() public void testEmptyFilterColumns() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> OK_RESPONSE); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); SchemaTableName someTable = SchemaTableName.schemaTableName("my_schema", "my_table"); Map> requestedColumns = ImmutableMap.of(someTable, ImmutableSet.of()); @@ -267,7 +268,7 @@ public void testFilterFunctions() Set requestedFunctions = ImmutableSet.of(functionOne, functionTwo); InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildHandler("/input/action/resource/function/functionName", "function_two")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Set result = authorizer.filterFunctions( requestingSecurityContext, @@ -297,7 +298,7 @@ public void testEmptyRequests( BiFunction callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, request -> OK_RESPONSE); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Collection result = callable.apply(authorizer, requestingSecurityContext); assertThat(result).isEmpty(); @@ -313,7 +314,7 @@ public void testIllegalResponseThrows( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); assertThatThrownBy(() -> callable.apply(authorizer, requestingSecurityContext)) .isInstanceOf(expectedException) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java index b174efd29e0a0..2e38f0a33a3ae 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlSystem.java @@ -13,12 +13,8 @@ */ package io.trino.plugin.opa; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import io.trino.Session; import io.trino.plugin.blackhole.BlackHolePlugin; -import io.trino.spi.security.Identity; -import io.trino.testing.DistributedQueryRunner; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.DisplayName; @@ -29,26 +25,15 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.testcontainers.containers.GenericContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.testcontainers.utility.DockerImageName; import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.Socket; -import java.net.SocketTimeoutException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.Optional; +import java.util.Map; import java.util.Set; import java.util.stream.Stream; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.plugin.opa.FunctionalHelpers.Pair; -import static io.trino.testing.TestingSession.testSessionBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @@ -57,14 +42,12 @@ @TestInstance(PER_CLASS) public class TestOpaAccessControlSystem { - private URI opaServerUri; - private DistributedQueryRunner runner; + private DistributedQueryRunnerHelper runner; - private static final int OPA_PORT = 8181; + private static final String OPA_ALLOW_POLICY_NAME = "allow"; + private static final String OPA_BATCH_ALLOW_POLICY_NAME = "batchAllow"; @Container - private static final GenericContainer OPA_CONTAINER = new GenericContainer<>(DockerImageName.parse("openpolicyagent/opa:latest-rootless")) - .withCommand("run", "--server", "--addr", ":%d".formatted(OPA_PORT)) - .withExposedPorts(OPA_PORT); + private static final OpaContainer OPA_CONTAINER = new OpaContainer(); @Nested @TestInstance(PER_CLASS) @@ -75,15 +58,15 @@ class UnbatchedAuthorizerTests public void setupTrino() throws Exception { - setupTrinoWithOpa("v1/data/trino/allow", Optional.empty()); + setupTrinoWithOpa(new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .buildConfig()); } @AfterAll public void teardown() { - if (runner != null) { - runner.close(); - } + runner.teardown(); } @ParameterizedTest(name = "{index}: {0}") @@ -91,7 +74,7 @@ public void teardown() public void testAllowsQueryAndFilters(String userName, Set expectedCatalogs) throws IOException, InterruptedException { - submitPolicy(""" + OPA_CONTAINER.submitPolicy(""" package trino import future.keywords.in import future.keywords.if @@ -117,7 +100,7 @@ public void testAllowsQueryAndFilters(String userName, Set expectedCatal input.action.resource.catalog.name == "catalog_one" } """); - Set catalogs = querySetOfStrings(user(userName), "SHOW CATALOGS"); + Set catalogs = runner.querySetOfStrings(userName, "SHOW CATALOGS"); assertThat(catalogs).containsExactlyInAnyOrderElementsOf(expectedCatalogs); } @@ -125,7 +108,7 @@ public void testAllowsQueryAndFilters(String userName, Set expectedCatal public void testShouldDenyQueryIfDirected() throws IOException, InterruptedException { - submitPolicy(""" + OPA_CONTAINER.submitPolicy(""" package trino import future.keywords.in default allow = false @@ -134,11 +117,11 @@ public void testShouldDenyQueryIfDirected() input.context.identity.user in ["someone", "admin"] } """); - assertThatThrownBy(() -> runner.execute(user("bob"), "SHOW CATALOGS")) + assertThatThrownBy(() -> runner.querySetOfStrings("bob", "SHOW CATALOGS")) .isInstanceOf(RuntimeException.class) .hasMessageContaining("Access Denied"); // smoke test: we can still query if we are the right user - runner.execute(user("admin"), "SHOW CATALOGS"); + runner.querySetOfStrings("admin", "SHOW CATALOGS"); } } @@ -151,15 +134,16 @@ class BatchedAuthorizerTests public void setupTrino() throws Exception { - setupTrinoWithOpa("v1/data/trino/allow", Optional.of("v1/data/trino/batchAllow")); + setupTrinoWithOpa(new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_ALLOW_POLICY_NAME)) + .withBatchPolicy(OPA_CONTAINER.getOpaUriForPolicyPath(OPA_BATCH_ALLOW_POLICY_NAME)) + .buildConfig()); } @AfterAll public void teardown() { - if (runner != null) { - runner.close(); - } + runner.teardown(); } @ParameterizedTest(name = "{index}: {0}") @@ -167,7 +151,7 @@ public void teardown() public void testFilterOutItemsBatch(String userName, Set expectedCatalogs) throws IOException, InterruptedException { - submitPolicy(""" + OPA_CONTAINER.submitPolicy(""" package trino import future.keywords.in import future.keywords.if @@ -201,7 +185,7 @@ public void testFilterOutItemsBatch(String userName, Set expectedCatalog is_admin } """); - Set catalogs = querySetOfStrings(user(userName), "SHOW CATALOGS"); + Set catalogs = runner.querySetOfStrings(userName, "SHOW CATALOGS"); assertThat(catalogs).containsExactlyInAnyOrderElementsOf(expectedCatalogs); } @@ -209,12 +193,12 @@ public void testFilterOutItemsBatch(String userName, Set expectedCatalog public void testDenyUnbatchedQuery() throws IOException, InterruptedException { - submitPolicy(""" + OPA_CONTAINER.submitPolicy(""" package trino import future.keywords.in default allow = false """); - assertThatThrownBy(() -> runner.execute(user("bob"), "SELECT version()")) + assertThatThrownBy(() -> runner.querySetOfStrings("bob", "SELECT version()")) .isInstanceOf(RuntimeException.class) .hasMessageContaining("Access Denied"); } @@ -223,7 +207,7 @@ public void testDenyUnbatchedQuery() public void testAllowUnbatchedQuery() throws IOException, InterruptedException { - submitPolicy(""" + OPA_CONTAINER.submitPolicy(""" package trino import future.keywords.in default allow = false @@ -232,86 +216,18 @@ public void testAllowUnbatchedQuery() input.action.operation in ["ImpersonateUser", "ExecuteFunction", "AccessCatalog", "ExecuteQuery"] } """); - Set version = querySetOfStrings(user("bob"), "SELECT version()"); + Set version = runner.querySetOfStrings("bob", "SELECT version()"); assertThat(version).isNotEmpty(); } } - private void ensureOpaUp() - throws IOException, InterruptedException - { - assertThat(OPA_CONTAINER.isRunning()).isTrue(); - InetSocketAddress opaSocket = new InetSocketAddress(OPA_CONTAINER.getHost(), OPA_CONTAINER.getMappedPort(OPA_PORT)); - String opaEndpoint = String.format("%s:%d", opaSocket.getHostString(), opaSocket.getPort()); - awaitSocketOpen(opaSocket, 100, 200); - this.opaServerUri = URI.create(String.format("http://%s/", opaEndpoint)); - } - - private void setupTrinoWithOpa(String basePolicyRelativeUri, Optional batchPolicyRelativeUri) + private void setupTrinoWithOpa(Map opaConfig) throws Exception { - ensureOpaUp(); - ImmutableMap.Builder opaConfigBuilder = ImmutableMap.builder(); - opaConfigBuilder.put("opa.policy.uri", opaServerUri.resolve(basePolicyRelativeUri).toString()); - batchPolicyRelativeUri.ifPresent(relativeUri -> opaConfigBuilder.put("opa.policy.batched-uri", opaServerUri.resolve(relativeUri).toString())); - this.runner = DistributedQueryRunner.builder(testSessionBuilder().build()) - .setSystemAccessControl(new OpaAccessControlFactory().create(opaConfigBuilder.buildOrThrow())) - .setNodeCount(1) - .build(); - runner.installPlugin(new BlackHolePlugin()); - runner.createCatalog("catalog_one", "blackhole"); - runner.createCatalog("catalog_two", "blackhole"); - } - - private static void awaitSocketOpen(InetSocketAddress addr, int attempts, int timeoutMs) - throws IOException, InterruptedException - { - for (int i = 0; i < attempts; ++i) { - try (Socket socket = new Socket()) { - socket.connect(addr, timeoutMs); - return; - } - catch (SocketTimeoutException e) { - // ignored - } - catch (IOException e) { - Thread.sleep(timeoutMs); - } - } - throw new SocketTimeoutException("Timed out waiting for addr %s to be available (%d attempts made with a %d ms wait)".formatted(addr, attempts, timeoutMs)); - } - - private static String stringOfLines(String... lines) - { - StringBuilder out = new StringBuilder(); - for (String line : lines) { - out.append(line); - out.append("\r\n"); - } - return out.toString(); - } - - private void submitPolicy(String... policyLines) - throws IOException, InterruptedException - { - HttpClient httpClient = HttpClient.newHttpClient(); - HttpResponse policyResponse = - httpClient.send( - HttpRequest.newBuilder(opaServerUri.resolve("v1/policies/trino")) - .PUT(HttpRequest.BodyPublishers.ofString(stringOfLines(policyLines))) - .header("Content-Type", "text/plain").build(), - HttpResponse.BodyHandlers.ofString()); - assertThat(policyResponse.statusCode()).withFailMessage("Failed to submit policy: %s", policyResponse.body()).isEqualTo(200); - } - - private Session user(String user) - { - return testSessionBuilder().setIdentity(Identity.ofUser(user)).build(); - } - - private Set querySetOfStrings(Session session, String query) - { - return runner.execute(session, query).getMaterializedRows().stream().map(row -> row.getField(0).toString()).collect(toImmutableSet()); + this.runner = DistributedQueryRunnerHelper.withOpaConfig(opaConfig); + runner.getBaseQueryRunner().installPlugin(new BlackHolePlugin()); + runner.getBaseQueryRunner().createCatalog("catalog_one", "blackhole"); + runner.getBaseQueryRunner().createCatalog("catalog_two", "blackhole"); } private static Stream filterSchemaTests() diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java index 6bfa78c190d55..3f63e94cc29cf 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java @@ -52,7 +52,11 @@ public class TestOpaBatchAccessControlFiltering { private static final URI OPA_SERVER_URI = URI.create("http://my-uri/"); - private static final URI OPA_BATCH_SERVER_URI = URI.create("http://my-uri/batchAllow"); + private static final URI OPA_BATCH_SERVER_URI = URI.create("http://my-batch-uri/"); + private static final Map OPA_CONFIG = new TestHelpers.OpaConfigBuilder() + .withBasePolicy(OPA_SERVER_URI) + .withBatchPolicy(OPA_BATCH_SERVER_URI) + .buildConfig(); private final Identity requestingIdentity = Identity.ofUser("source-user"); private final SystemSecurityContext requestingSecurityContext = systemSecurityContextFromIdentity(requestingIdentity); @@ -63,7 +67,7 @@ public void testFilterViewQueryOwnedBy( List expectedItems) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Identity identityOne = Identity.ofUser("user-one"); Identity identityTwo = Identity.ofUser("user-two"); @@ -107,7 +111,7 @@ public void testFilterCatalogs( List expectedItems) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); List requestedCatalogs = ImmutableList.of("catalog_one", "catalog_two", "catalog_three"); @@ -147,7 +151,7 @@ public void testFilterSchemas( List expectedItems) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); List requestedSchemas = ImmutableList.of("schema_one", "schema_two", "schema_three"); Set result = authorizer.filterSchemas( @@ -190,7 +194,7 @@ public void testFilterTables( List expectedItems) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); List tables = ImmutableList.builder() .add(new SchemaTableName("schema_one", "table_one")) .add(new SchemaTableName("schema_one", "table_two")) @@ -264,7 +268,7 @@ public void testFilterColumns() }; return new MockResponse(responseContents, 200); })); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Map> result = authorizer.filterColumns( requestingSecurityContext, "my_catalog", @@ -302,7 +306,7 @@ public void testFilterFunctions( List expectedItems) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, response)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); List requestedFunctions = ImmutableList.builder() .add(new SchemaFunctionName("my_schema", "function_one")) .add(new SchemaFunctionName("my_schema", "function_two")) @@ -349,7 +353,7 @@ public void testFilterFunctions( public void testEmptyFilterColumns() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, request -> OK_RESPONSE); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); SchemaTableName tableOne = SchemaTableName.schemaTableName("my_schema", "table_one"); SchemaTableName tableTwo = SchemaTableName.schemaTableName("my_schema", "table_two"); @@ -372,7 +376,7 @@ public void testEmptyRequests( BiFunction callable) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, request -> OK_RESPONSE); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); Collection result = callable.apply(authorizer, requestingSecurityContext); assertThat(result).isEmpty(); @@ -388,7 +392,7 @@ public void testIllegalResponseThrows( String expectedErrorMessage) { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, failureResponse)); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); assertThatThrownBy(() -> callable.apply(authorizer, requestingSecurityContext)) .isInstanceOf(expectedException) @@ -400,7 +404,7 @@ public void testIllegalResponseThrows( public void testResponseOutOfBoundsThrows() { InstrumentedHttpClient mockClient = createMockHttpClient(OPA_BATCH_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, 200, "{\"result\": [0, 1, 2]}")); - OpaAccessControl authorizer = createOpaAuthorizer(OPA_SERVER_URI, OPA_BATCH_SERVER_URI, mockClient); + OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG, mockClient); assertThatThrownBy(() -> authorizer.filterCatalogs(requestingSecurityContext, ImmutableSet.of("catalog_one", "catalog_two"))) .isInstanceOf(OpaQueryException.QueryFailed.class); diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java index 37184bdefd3e8..766b043234125 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaConfig.java @@ -31,6 +31,8 @@ public void testDefaults() assertRecordedDefaults(recordDefaults(OpaConfig.class) .setOpaUri(null) .setOpaBatchUri(null) + .setOpaRowFiltersUri(null) + .setOpaColumnMaskingUri(null) .setLogRequests(false) .setLogResponses(false) .setAllowPermissioningOperations(false)); @@ -42,6 +44,8 @@ public void testExplicitPropertyMappings() Map properties = ImmutableMap.builder() .put("opa.policy.uri", "https://opa.example.com") .put("opa.policy.batched-uri", "https://opa-batch.example.com") + .put("opa.policy.row-filters-uri", "https://opa-row-filtering.example.com") + .put("opa.policy.column-masking-uri", "https://opa-column-masking.example.com") .put("opa.log-requests", "true") .put("opa.log-responses", "true") .put("opa.allow-permissioning-operations", "true") @@ -50,6 +54,8 @@ public void testExplicitPropertyMappings() OpaConfig expected = new OpaConfig() .setOpaUri(URI.create("https://opa.example.com")) .setOpaBatchUri(URI.create("https://opa-batch.example.com")) + .setOpaRowFiltersUri(URI.create("https://opa-row-filtering.example.com")) + .setOpaColumnMaskingUri(URI.create("https://opa-column-masking.example.com")) .setLogRequests(true) .setLogResponses(true) .setAllowPermissioningOperations(true); diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java index 63b7948f06cc8..b40bd9213125c 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java @@ -16,15 +16,23 @@ import io.airlift.json.JsonCodec; import io.airlift.json.JsonCodecFactory; import io.trino.plugin.opa.schema.OpaBatchQueryResult; +import io.trino.plugin.opa.schema.OpaColumnMaskQueryResult; import io.trino.plugin.opa.schema.OpaQueryResult; +import io.trino.plugin.opa.schema.OpaRowFiltersQueryResult; +import io.trino.plugin.opa.schema.OpaViewExpression; import org.junit.jupiter.api.Test; +import java.util.Optional; + import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestOpaResponseDecoding { private final JsonCodec responseCodec = new JsonCodecFactory().jsonCodec(OpaQueryResult.class); private final JsonCodec batchResponseCodec = new JsonCodecFactory().jsonCodec(OpaBatchQueryResult.class); + private final JsonCodec rowFilteringResponseCodec = new JsonCodecFactory().jsonCodec(OpaRowFiltersQueryResult.class); + private final JsonCodec columnMaskingResponseCodec = new JsonCodecFactory().jsonCodec(OpaColumnMaskQueryResult.class); @Test public void testCanDeserializeOpaSingleResponse() @@ -82,13 +90,19 @@ public void testUndefinedDecisionSingleResponseTreatedAsDeny() } @Test - public void testEmptyOrUndefinedResponses() + public void testIllegalResponseThrows() + { + testIllegalResponseDecodingThrows("{\"result\": \"foo\"}", responseCodec); + } + + @Test + public void testBatchEmptyOrUndefinedResponses() { - testEmptyOrUndefinedResponses("{}"); - testEmptyOrUndefinedResponses("{\"result\": []}"); + testBatchEmptyOrUndefinedResponses("{}"); + testBatchEmptyOrUndefinedResponses("{\"result\": []}"); } - private void testEmptyOrUndefinedResponses(String response) + private void testBatchEmptyOrUndefinedResponses(String response) { OpaBatchQueryResult result = this.batchResponseCodec.fromJson(response); assertThat(result.result()).isEmpty(); @@ -118,6 +132,16 @@ public void testBatchResponseWithItemsAndDecisionId() assertThat(result.decisionId()).isEqualTo("foobar"); } + @Test + public void testBatchResponseIllegalResponseThrows() + { + testIllegalResponseDecodingThrows(""" + { + "result": ["foo"], + "decision_id": "foobar" + }""", batchResponseCodec); + } + @Test public void testBatchResponseWithExtraFields() { @@ -131,4 +155,140 @@ public void testBatchResponseWithExtraFields() assertThat(result.result()).containsExactly(1, 2, 3); assertThat(result.decisionId()).isEqualTo("foobar"); } + + @Test + public void testRowFilteringEmptyOrUndefinedResponses() + { + testRowFilteringEmptyOrUndefinedResponses("{}"); + testRowFilteringEmptyOrUndefinedResponses("{\"result\": []}"); + } + + private void testRowFilteringEmptyOrUndefinedResponses(String response) + { + OpaRowFiltersQueryResult result = this.rowFilteringResponseCodec.fromJson(response); + assertThat(result.result()).isEmpty(); + assertThat(result.decisionId()).isNull(); + } + + @Test + public void testRowFilteringResponseWithItemsNoDecisionId() + { + OpaRowFiltersQueryResult result = this.rowFilteringResponseCodec.fromJson(""" + { + "result": [ + {"expression": "foo"}, + {"expression": "bar", "identity": "some_identity"} + ] + }"""); + assertThat(result.result()).containsExactlyInAnyOrder( + new OpaViewExpression("foo", Optional.empty()), + new OpaViewExpression("bar", Optional.of("some_identity"))); + assertThat(result.decisionId()).isNull(); + } + + @Test + public void testRowFilteringResponseWithItemsAndDecisionId() + { + OpaRowFiltersQueryResult result = this.rowFilteringResponseCodec.fromJson(""" + { + "result": [{"expression": "test_expression"}], + "decision_id": "some_id" + }"""); + assertThat(result.result()).containsExactly(new OpaViewExpression("test_expression", Optional.empty())); + assertThat(result.decisionId()).isEqualTo("some_id"); + } + + @Test + public void testRowFilteringResponseWithExtraFields() + { + OpaRowFiltersQueryResult result = this.rowFilteringResponseCodec.fromJson(""" + { + "result": [{"expression": "test_expression"}], + "decision_id": "foobar", + "someInfo": "foo", + "andAnObject": {} + }"""); + assertThat(result.result()).containsExactly(new OpaViewExpression("test_expression", Optional.empty())); + assertThat(result.decisionId()).isEqualTo("foobar"); + } + + @Test + public void testRowFilteringResponseIllegalResponseThrows() + { + testIllegalResponseDecodingThrows(""" + { + "result": ["foo"] + }""", rowFilteringResponseCodec); + } + + @Test + public void testColumnMaskingEmptyOrUndefinedResponse() + { + OpaColumnMaskQueryResult emptyResult = columnMaskingResponseCodec.fromJson("{}"); + assertThat(emptyResult.result()).isEmpty(); + assertThat(emptyResult.decisionId()).isNull(); + OpaColumnMaskQueryResult undefinedResult = columnMaskingResponseCodec.fromJson("{\"result\": null}"); + assertThat(undefinedResult.result()).isEmpty(); + assertThat(undefinedResult.decisionId()).isNull(); + } + + @Test + public void testColumnMaskingResponsesWithNoDecisionId() + { + OpaColumnMaskQueryResult result = this.columnMaskingResponseCodec.fromJson(""" + { + "result": {"expression": "test_expression"} + }"""); + assertThat(result.result()).contains(new OpaViewExpression("test_expression", Optional.empty())); + assertThat(result.decisionId()).isNull(); + } + + @Test + public void testColumnMaskingResponsesWithDecisionId() + { + OpaColumnMaskQueryResult resultWithExpression = this.columnMaskingResponseCodec.fromJson(""" + { + "result": {"expression": "test_expression"}, + "decision_id": "foobar" + }"""); + OpaColumnMaskQueryResult resultWithExpressionAndIdentity = this.columnMaskingResponseCodec.fromJson(""" + { + "result": {"expression": "test_expression", "identity": "some_identity"}, + "decision_id": "foobar" + }"""); + assertThat(resultWithExpression.result()).contains(new OpaViewExpression("test_expression", Optional.empty())); + assertThat(resultWithExpressionAndIdentity.result()).contains(new OpaViewExpression("test_expression", Optional.of("some_identity"))); + assertThat(resultWithExpression.decisionId()).isEqualTo("foobar"); + assertThat(resultWithExpressionAndIdentity.decisionId()).isEqualTo("foobar"); + } + + @Test + public void testColumnMaskingResponseWithExtraFields() + { + OpaColumnMaskQueryResult result = this.columnMaskingResponseCodec.fromJson(""" + { + "result": {"expression": "test_expression"}, + "decision_id": "foobar", + "someInfo": "foo", + "andAnObject": {} + }"""); + assertThat(result.result()).contains(new OpaViewExpression("test_expression", Optional.empty())); + assertThat(result.decisionId()).isEqualTo("foobar"); + } + + @Test + public void testColumnMaskingResponseIllegalResponseThrows() + { + testIllegalResponseDecodingThrows(""" + { + "result": {"foo": "bar"} + }""", columnMaskingResponseCodec); + } + + private void testIllegalResponseDecodingThrows(String rawResponse, JsonCodec codec) + { + assertThatThrownBy(() -> codec.fromJson(rawResponse)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Invalid JSON"); + } } From 9dd9265a127ec046ce0e136503ebf941972cd308 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Thu, 21 Dec 2023 14:42:24 +0100 Subject: [PATCH 10/11] Fix checkstyle --- .../java/io/trino/plugin/opa/TestHelpers.java | 37 +++++++++++++------ .../plugin/opa/TestOpaAccessControl.java | 14 +++++-- ...stOpaAccessControlDataFilteringSystem.java | 1 - .../TestOpaBatchAccessControlFiltering.java | 2 +- .../plugin/opa/TestOpaResponseDecoding.java | 12 +++--- 5 files changed, 43 insertions(+), 23 deletions(-) diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java index ea3b36a5710cb..e82149c99a77d 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java @@ -105,27 +105,34 @@ public static Stream allErrorCasesArgumentProvider() Stream.of(Arguments.of(Named.of("No access response", NO_ACCESS_RESPONSE), AccessDeniedException.class, "Access Denied"))); } - public static SystemSecurityContext systemSecurityContextFromIdentity(Identity identity) { + public static SystemSecurityContext systemSecurityContextFromIdentity(Identity identity) + { return new SystemSecurityContext(identity, new QueryIdGenerator().createNextQueryId(), Instant.now()); } - public abstract static class MethodWrapper { + public abstract static class MethodWrapper + { public abstract boolean isAccessAllowed(OpaAccessControl opaAccessControl); } - public static class ThrowingMethodWrapper extends MethodWrapper { + public static class ThrowingMethodWrapper + extends MethodWrapper + { private final Consumer callable; - public ThrowingMethodWrapper(Consumer callable) { + public ThrowingMethodWrapper(Consumer callable) + { this.callable = callable; } @Override - public boolean isAccessAllowed(OpaAccessControl opaAccessControl) { + public boolean isAccessAllowed(OpaAccessControl opaAccessControl) + { try { this.callable.accept(opaAccessControl); return true; - } catch (AccessDeniedException e) { + } + catch (AccessDeniedException e) { if (!e.getMessage().contains("Access Denied")) { throw new AssertionError("Expected AccessDenied exception to contain 'Access Denied' in the message"); } @@ -134,15 +141,19 @@ public boolean isAccessAllowed(OpaAccessControl opaAccessControl) { } } - public static class ReturningMethodWrapper extends MethodWrapper { + public static class ReturningMethodWrapper + extends MethodWrapper + { private final Function callable; - public ReturningMethodWrapper(Function callable) { + public ReturningMethodWrapper(Function callable) + { this.callable = callable; } @Override - public boolean isAccessAllowed(OpaAccessControl opaAccessControl) { + public boolean isAccessAllowed(OpaAccessControl opaAccessControl) + { return this.callable.apply(opaAccessControl); } } @@ -194,13 +205,15 @@ public Map buildConfig() convertPropertyToString(attribute.getGetter().invoke(config)).ifPresent( propertyValue -> opaConfigBuilder.put(attribute.getInjectionPoint().getProperty(), propertyValue)); } - } catch (InvocationTargetException|IllegalAccessException e) { + } + catch (InvocationTargetException | IllegalAccessException e) { throw new AssertionError("Failed to build config map", e); } return opaConfigBuilder.buildOrThrow(); } - private static Optional convertPropertyToString(Object value) { + private static Optional convertPropertyToString(Object value) + { if (value instanceof Optional optionalValue) { return optionalValue.map(Object::toString); } @@ -209,7 +222,7 @@ private static Optional convertPropertyToString(Object value) { } static final class TestingSystemAccessControlContext - implements SystemAccessControlFactory.SystemAccessControlContext + implements SystemAccessControlFactory.SystemAccessControlContext { private final String trinoVersion; diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java index 45bf09b8e41fe..2d0cbcbe85810 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControl.java @@ -78,7 +78,7 @@ public class TestOpaAccessControl @Test public void testResponseHasExtraFields() { - InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, 200,""" + InstrumentedHttpClient mockClient = createMockHttpClient(OPA_SERVER_URI, buildValidatingRequestHandler(requestingIdentity, 200, """ { "result": true, "decision_id": "foo", @@ -1173,7 +1173,11 @@ private void testGetRowFilters(String responseContent, List e @Test public void testGetRowFiltersDoesNothingIfNotConfigured() { - InstrumentedHttpClient httpClient = createMockHttpClient(OPA_SERVER_ROW_FILTERING_URI, request -> {throw new AssertionError("Should not have been called");}); + InstrumentedHttpClient httpClient = createMockHttpClient( + OPA_SERVER_ROW_FILTERING_URI, + request -> { + throw new AssertionError("Should not have been called"); + }); OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, httpClient); CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); @@ -1271,7 +1275,11 @@ private void testGetColumnMask(String responseContent, Optional {throw new AssertionError("Should not have been called");}); + InstrumentedHttpClient httpClient = createMockHttpClient( + OPA_SERVER_COLUMN_MASK_URI, + request -> { + throw new AssertionError("Should not have been called"); + }); OpaAccessControl authorizer = createOpaAuthorizer(OPA_CONFIG_WITH_ONLY_ALLOW, httpClient); CatalogSchemaTableName tableName = new CatalogSchemaTableName("some_catalog", "some_schema", "some_table"); diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java index 1cced9f230a8c..d825233b6f164 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaAccessControlDataFilteringSystem.java @@ -101,7 +101,6 @@ public class TestOpaAccessControlDataFilteringSystem private DistributedQueryRunnerHelper runner; - @AfterEach public void teardown() { diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java index 3f63e94cc29cf..b2514593261fc 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaBatchAccessControlFiltering.java @@ -261,7 +261,7 @@ public void testFilterColumns() requestingIdentity, parsedRequest -> { String tableName = parsedRequest.at("/input/action/filterResources/0/table/tableName").asText(); - String responseContents = switch(tableName) { + String responseContents = switch (tableName) { case "table_one" -> "{\"result\": [0, 1]}"; case "table_two" -> "{\"result\": [1]}"; default -> "{\"result\": []}"; diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java index b40bd9213125c..ad08d98878131 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestOpaResponseDecoding.java @@ -216,9 +216,9 @@ public void testRowFilteringResponseWithExtraFields() public void testRowFilteringResponseIllegalResponseThrows() { testIllegalResponseDecodingThrows(""" - { - "result": ["foo"] - }""", rowFilteringResponseCodec); + { + "result": ["foo"] + }""", rowFilteringResponseCodec); } @Test @@ -280,9 +280,9 @@ public void testColumnMaskingResponseWithExtraFields() public void testColumnMaskingResponseIllegalResponseThrows() { testIllegalResponseDecodingThrows(""" - { - "result": {"foo": "bar"} - }""", columnMaskingResponseCodec); + { + "result": {"foo": "bar"} + }""", columnMaskingResponseCodec); } private void testIllegalResponseDecodingThrows(String rawResponse, JsonCodec codec) From 0db1eb7e33b9e0d48059bc9cb9f730eab527ffc3 Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Fri, 22 Dec 2023 11:49:37 +0100 Subject: [PATCH 11/11] Code review: add JavaDoc to TrinoColumn and enable decision logging for OPA --- .../io/trino/plugin/opa/schema/TrinoColumn.java | 17 +++++++++++++++++ .../java/io/trino/plugin/opa/OpaContainer.java | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java index 4f21cb04a20b6..1070f02558dc7 100644 --- a/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java +++ b/plugin/trino-opa/src/main/java/io/trino/plugin/opa/schema/TrinoColumn.java @@ -19,6 +19,23 @@ import static java.util.Objects.requireNonNull; +/** + * This class is used to represent information about a column for the purposes of column masking. + * It is (perhaps counterintuitively) only used for column masking and not for operations like + * FilterColumns. This is for 3 reasons: + * - API stability between the batch & non-batch modes: sending an array of TrinoColumn objects would be wasteful for + * the batch authorizer mode (as it would repeat the catalog, schema and table names once per column). As such, this + * object is not used for FilterColumns even if batch mode is disabled + * - This object contains in-depth information about the column (e.g. its type), and it may be modified to include + * additional fields in the future. This level of information is not provided to operations like FilterColumns + * - Backwards compatibility + * + * @param catalogName The name of the catalog this column's table belongs to + * @param schemaName The name of the schema this column's table belongs to + * @param tableName The name of the table this column is in + * @param columnName Column name + * @param columnType String representation of the column type + */ public record TrinoColumn( @NotNull String catalogName, @NotNull String schemaName, diff --git a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java index ddb8c4aaf6fce..673fafcaf4f14 100644 --- a/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java +++ b/plugin/trino-opa/src/test/java/io/trino/plugin/opa/OpaContainer.java @@ -37,7 +37,7 @@ public class OpaContainer public OpaContainer() { this.container = new GenericContainer<>(DockerImageName.parse("openpolicyagent/opa:latest-rootless")) - .withCommand("run", "--server", "--addr", ":%d".formatted(OPA_PORT)) + .withCommand("run", "--server", "--addr", ":%d".formatted(OPA_PORT), "--set", "decision_logs.console=true") .withExposedPorts(OPA_PORT) .waitingFor(Wait.forListeningPort()); }