diff --git a/iceberg-rest-server/build.gradle b/iceberg-rest-server/build.gradle index 2d13bfae3..b5cc08d22 100644 --- a/iceberg-rest-server/build.gradle +++ b/iceberg-rest-server/build.gradle @@ -99,6 +99,7 @@ openApiGenerate { LoadTableResult: "org.apache.iceberg.rest.responses.LoadTableResponse" , LoadViewResult: "org.apache.iceberg.rest.responses.LoadTableResponse" , OAuthTokenResponse: "org.apache.iceberg.rest.responses.OAuthTokenResponse" , + OAuthErrorResponse: "org.apache.iceberg.rest.responses.OAuthErrorResponse", RenameTableRequest: "org.apache.iceberg.rest.requests.RenameTableRequest" , ReportMetricsRequest: "org.apache.iceberg.rest.requests.ReportMetricsRequest" , UpdateNamespacePropertiesRequest: "org.apache.iceberg.rest.requests.UpdateNamespacePropertiesRequest" , diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/http/HTTPUtil.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/http/HTTPUtil.java new file mode 100644 index 000000000..8c123842f --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/http/HTTPUtil.java @@ -0,0 +1,170 @@ +package org.apache.iceberg.pinnacle.http; + +import org.apache.http.HttpHost; +import org.apache.http.HttpResponse; +import org.apache.http.client.ServiceUnavailableRetryStrategy; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.conn.ssl.DefaultHostnameVerifier; +import org.apache.http.conn.ssl.SSLConnectionSocketFactory; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; +import org.apache.http.protocol.HttpContext; +import org.apache.http.ssl.SSLContexts; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLContext; +import java.security.Security; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.TimeUnit; + +/** Utility class that can be used to make REST HTTP calls to Snowflake with */ +public class HTTPUtil { + + private static final Logger LOGGER = LoggerFactory.getLogger(HTTPUtil.class); + + private static final String FIRST_FAULT_TIMESTAMP = "FIRST_FAULT_TIMESTAMP"; + private static final Duration TOTAL_RETRY_DURATION = Duration.of(120, ChronoUnit.SECONDS); + private static final Duration RETRY_INTERVAL = Duration.of(3, ChronoUnit.SECONDS); + + private static ServiceUnavailableRetryStrategy getServiceUnavailableRetryStrategy() { + return new ServiceUnavailableRetryStrategy() { + final int REQUEST_TIMEOUT = 408; + final int TOO_MANY_REQUESTS = 429; + final int SERVER_ERRORS = 500; + + @Override + public boolean retryRequest( + final HttpResponse response, final int executionCount, final HttpContext context) { + Object firstFault = context.getAttribute(FIRST_FAULT_TIMESTAMP); + long totalRetryDurationSoFarInSeconds = 0; + if (firstFault == null) { + context.setAttribute(FIRST_FAULT_TIMESTAMP, Instant.now()); + } else { + Instant firstFaultInstant = (Instant) firstFault; + Instant now = Instant.now(); + totalRetryDurationSoFarInSeconds = Duration.between(firstFaultInstant, now).getSeconds(); + + if (totalRetryDurationSoFarInSeconds > TOTAL_RETRY_DURATION.getSeconds()) { + LOGGER.info( + String.format( + "Reached the max retry time of %d seconds, not retrying anymore", + TOTAL_RETRY_DURATION.getSeconds())); + return false; + } + } + + int statusCode = response.getStatusLine().getStatusCode(); + boolean needNextRetry = + (statusCode == REQUEST_TIMEOUT + || statusCode == TOO_MANY_REQUESTS + || statusCode >= SERVER_ERRORS); + if (needNextRetry) { + long interval = getRetryInterval(); + LOGGER.info("In retryRequest for service unavailability with statusCode:{}", statusCode); + LOGGER.info( + "Sleep time in millisecond: {}, retryCount: {}, total retry duration: {}s / {}s", + interval, + executionCount, + totalRetryDurationSoFarInSeconds, + TOTAL_RETRY_DURATION.getSeconds()); + } + return needNextRetry; + } + + @Override + public long getRetryInterval() { + return RETRY_INTERVAL.toMillis(); + } + }; + } + + private static volatile CloseableHttpClient httpClient; + + private static PoolingHttpClientConnectionManager connectionManager; + + private static final int DEFAULT_CONNECTION_TIMEOUT_MINUTES = 1; + private static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES = 5; + + /** + * After how many seconds of inactivity should be idle connections evicted from the connection + * pool. + */ + private static final int DEFAULT_EVICT_IDLE_AFTER_SECONDS = 60; + + // Default is 2, but scaling it up to 100 to match with default_max_connections + private static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = 100; + + // 100 is close to max partition number we have seen for a kafka topic ingesting into snowflake. + private static final int DEFAULT_MAX_CONNECTIONS = 100; + + // Interval in which we check if there are connections which needs to be closed. + private static final long IDLE_HTTP_CONNECTION_MONITOR_THREAD_INTERVAL_MS = + TimeUnit.SECONDS.toMillis(5); + + // Only connections that are currently owned, not checked out, are subject to idle timeouts. + private static final int DEFAULT_IDLE_CONNECTION_TIMEOUT_SECONDS = 30; + + public static void initHttpClient() { + Security.setProperty("ocsp.enable", "true"); + SSLContext sslContext = SSLContexts.createDefault(); + + SSLConnectionSocketFactory f = + new SSLConnectionSocketFactory( + sslContext, new String[] {"TLSv1.2"}, null, new DefaultHostnameVerifier()); + // Set connectionTimeout which is the timeout until a connection with the server is established + // Set connectionRequestTimeout which is the time to wait for getting a connection from the + // connection pool + // Set socketTimeout which is the max time gap between two consecutive data packets + RequestConfig requestConfig = + RequestConfig.custom() + .setConnectTimeout( + (int) + TimeUnit.MILLISECONDS.convert( + DEFAULT_CONNECTION_TIMEOUT_MINUTES, TimeUnit.MINUTES)) + .setConnectionRequestTimeout( + (int) + TimeUnit.MILLISECONDS.convert( + DEFAULT_CONNECTION_TIMEOUT_MINUTES, TimeUnit.MINUTES)) + .setSocketTimeout( + (int) + TimeUnit.MILLISECONDS.convert( + DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES, TimeUnit.MINUTES)) + .build(); + + // Below pooling client connection manager uses time_to_live value as -1 which means it will not + // refresh a persisted connection + connectionManager = new PoolingHttpClientConnectionManager(); + connectionManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONNECTIONS_PER_ROUTE); + connectionManager.setMaxTotal(DEFAULT_MAX_CONNECTIONS); + + // Use an anonymous class to implement the interface ServiceUnavailableRetryStrategy() The max + // retry time is 3. The interval time is backoff. + HttpClientBuilder clientBuilder = + HttpClientBuilder.create() + .setConnectionManager(connectionManager) + .evictIdleConnections(DEFAULT_EVICT_IDLE_AFTER_SECONDS, TimeUnit.SECONDS) + .setSSLSocketFactory(f) + .setServiceUnavailableRetryStrategy(getServiceUnavailableRetryStrategy()) + // TODO add retry handler .setRetryHandler(getHttpRequestRetryHandler()) + .setDefaultRequestConfig(requestConfig); + httpClient = clientBuilder.build(); + } + + /** + * @return Instance of CloseableHttpClient + */ + public static CloseableHttpClient getHttpClient() { + if (httpClient == null) { + synchronized (HTTPUtil.class) { + if (httpClient == null) { + initHttpClient(); + } + } + } + return httpClient; + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/OAuthUtils.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/OAuthUtils.java new file mode 100644 index 000000000..d9cc54c73 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/OAuthUtils.java @@ -0,0 +1,19 @@ +package org.apache.iceberg.pinnacle.oauth; + +import com.nimbusds.jose.util.Base64; +import org.apache.http.entity.StringEntity; + +/** Simple utility class to assist with OAuth operations*/ +public class OAuthUtils { + + public static final String AUTHORIZATION_HEADER = "Authorization"; + + /** + * @param clientId + * @param clientSecret + * @return basic Authorization Header of the form `base64_encode(client_id:client_secret) + */ + public static String getBasicAuthHeader(String clientId, String clientSecret) { + return Base64.encode(clientId + ":" + clientSecret).toString(); + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeOAuth2Service.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeOAuth2Service.java new file mode 100644 index 000000000..6a5294362 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeOAuth2Service.java @@ -0,0 +1,99 @@ +package org.apache.iceberg.pinnacle.oauth; + +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.SecurityContext; +import org.apache.iceberg.rest.CallContext; +import org.apache.iceberg.rest.api.IcebergRestOAuth2ApiService; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.apache.iceberg.rest.snowflake.SnowflakeRealmContext; +import org.apache.iceberg.rest.types.TokenType; + +/** + * Snowflake-specific OAuth2 Service. This class essentially acts as a broker between an External + * Iceberg Client and Snowflake. Specifically, this handles `/v1/oauth/tokens` requests made from + * Iceberg Clients and translates those into calls against Snowflake to fetch credentials derived + * from a `PINNACLE_PRINCIPAL` integration. + */ +public class SnowflakeOAuth2Service implements IcebergRestOAuth2ApiService { + + // I need to figure out how to bootstrap config with the REST Application - this might get pushed + // into the class itself depending on what I find. + private static final SnowflakePinnacleServiceTokenBroker pinnacleServiceTokenBroker = + new SnowflakePinnacleServiceTokenBroker( + "TODO-CLIENT-ID", "TODO-CLIENT-SECRET", "TODO-CLIENT_SECRET-2"); + + public static SnowflakePinnacleServiceTokenBroker getPinnacleServiceTokenBroker() { + return pinnacleServiceTokenBroker; + } + + /** + * Initializes an instance of the SnowflakeOAuth2Service. This is generally expected to be used + * for the duration of the Dropwizard application as initialized in {@link + * org.apache.iceberg.rest.IcebergRestApplication} + */ + public SnowflakeOAuth2Service() {} + + /** + * Handles a `/v1/oauth/tokens` request made from an External Iceberg Client We make a REST + * request to Snowflake's `/oauth/token-request` endpoint with the grant type set to + * `pinnacle_principal` and the Client ID/Secret in the `Authorization` header. + * + *

A note regarding the spec: it appears that client ID/Secret can come in via the request + * payload OR it can come in via the Authorization header. We should look at the Spark clients and + * see where they generally put Client ID/Secret but we'll probably have to check both and pick + * whatever is not null. For now we'll just look at the request payload. As per the docs: + * + *

+   *   This can be sent in the request body, but OAuth2 recommends sending it in
+   *   a Basic Authorization header.
+   * 
+ * + * @param grantType This should either be `client_credentials` for obtaining a token or + * `urn:ietf:params:oauth:grant-type:token-exchange` for exchanging a token. For now we only + * support `client_credentials` + * @param scope the requested scope of the Iceberg Client. TODO define later + * @param clientId (Nullable) the Client ID that should map to a `PINNACLE_PRINCIPAL` integration + * in Snowflake + * @param clientSecret (Nullable) the Client Secret that should map to a `PINNACLE_PRINCIPAL` + * integration in Snowflake + * @param requestedTokenType + * @param subjectToken + * @param subjectTokenType + * @param actorToken + * @param actorTokenType + * @param securityContext + * @return Either an `OAuthTokenResponse` or `OAuthErrorResponse` as defined in the REST spec + */ + @Override + public Response getToken( + String grantType, + String scope, + String clientId, + String clientSecret, + TokenType requestedTokenType, + String subjectToken, + TokenType subjectTokenType, + String actorToken, + TokenType actorTokenType, + SecurityContext securityContext) { + TokenRequestValidator validator = new TokenRequestValidator(); + if (!validator.validateForClientCredentialsFlow(clientId, clientSecret, grantType)) { + // TODO this needs to be `OAuthErrorResponse` as defined in the spec but I can't + // seem to build with it on the plane? + return Response.status(Response.Status.BAD_REQUEST).build(); + } + + // This needs error handling as well + SnowflakePinnaclePrincipalTokenBroker broker = + new SnowflakePinnaclePrincipalTokenBroker(clientId, clientSecret); + SnowflakeTokenResponse tokenResponse = + broker.getToken((SnowflakeRealmContext) CallContext.getCurrentContext().getRealmContext()); + return Response.ok( + OAuthTokenResponse.builder() + .withToken(tokenResponse.getAccessToken()) + .withTokenType("bearer") // TODO there's gotta be a constant somewhere + .setExpirationInSeconds(tokenResponse.getExpiresIn()) + .build()) + .build(); + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnaclePrincipalTokenBroker.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnaclePrincipalTokenBroker.java new file mode 100644 index 000000000..830cedbb8 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnaclePrincipalTokenBroker.java @@ -0,0 +1,51 @@ +package org.apache.iceberg.pinnacle.oauth; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Responsible for acting as a broker to obtain `PINNACLE_PRINCIPAL` level tokens. */ +public class SnowflakePinnaclePrincipalTokenBroker extends SnowflakeTokenBroker { + + private static final Logger LOGGER = + LoggerFactory.getLogger(SnowflakePinnaclePrincipalTokenBroker.class); + + private static final String PINNACLE_PRINCIPAL_CLIENT_CREDENTIALS_GRANT_TYPE = + "pinnacle_principal_client_credentials"; + + private static final ObjectMapper mapper = new ObjectMapper(); + + private final String clientId; + private final String clientSecret; + + /** + * @param clientId Client ID corresponding to a `PINNACLE_PRINCIPAL` integration + * @param clientSecret Client Secret corresponding to a `PINNACLE_PRINCIPAL` integration + */ + public SnowflakePinnaclePrincipalTokenBroker(final String clientId, final String clientSecret) { + this.clientId = clientId; + this.clientSecret = clientSecret; + } + + @Override + String getAuthHeader() { + return OAuthUtils.getBasicAuthHeader(clientId, clientSecret); + } + + /** + * In the future this will likely have the Client ID/Secret in it and we'll put a PINNACLE_SERVICE + * token in the Authorization Header. There is a JIRA for this somewhere. + */ + @Override + String getPayload() { + SnowflakeTokenRequestPayload payload = new SnowflakeTokenRequestPayload(); + payload.setGrantType(PINNACLE_PRINCIPAL_CLIENT_CREDENTIALS_GRANT_TYPE); + try { + return mapper.writeValueAsString(payload); + } catch (JsonProcessingException e) { + LOGGER.error("Unable to serialize payload", e); + throw new RuntimeException(e); + } + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnacleServiceTokenBroker.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnacleServiceTokenBroker.java new file mode 100644 index 000000000..b95f70316 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnacleServiceTokenBroker.java @@ -0,0 +1,101 @@ +package org.apache.iceberg.pinnacle.oauth; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.iceberg.rest.CallContext; +import org.apache.iceberg.rest.snowflake.SnowflakeRealmContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** Responsible for obtaining Pinnacle Service tokens on behalf of an account */ +public class SnowflakePinnacleServiceTokenBroker extends SnowflakeTokenBroker { + + private static final Logger LOGGER = + LoggerFactory.getLogger(SnowflakePinnacleServiceTokenBroker.class); + + private static final ObjectMapper mapper = new ObjectMapper(); + + private static final String PINNACLE_SERVICE_CLIENT_CREDENTIALS_GRANT_TYPE = + "pinnacle_service_client_credentials"; + + // This is fixed for a PINNACLE_SERVICE integration in a DEPLOYMENT, not a shard. + // That is, multiple shards can share the same client id/secret for a PINNACLE_SERVICE + // integration. Each will have its own integration ID but all will share the same + // ID/Secret, similar to the Tableau system integrations. + private final String clientId; + + // The Client Secret associated with the Pinnacle Service integration. This can be + // rotated during the rotation of the Pinnacle REST application's lifetime. + private String clientSecret; + + // The Second Client Secret associated with the Pinnacle Service integration. This can be + // rotated during the rotation of the Pinnacle REST application's lifetime. There are two of + // these to support Client Secret rotation without any down time. + private String clientSecret2; + + // This contains a mapping of account name to the latest Pinnacle Service token that + // can be used to make requests to the `oauth/token-info` and `oauth/token-request` endpoints. + // This really needs to be a Guava loading cache with an expiry but I'm on a plane and I don't + // have the JAR locally soooooooo + final ConcurrentHashMap accountToServiceToken; + + /** + * Default constructor + * + * @param pinnacleServiceClientId the Client ID corresponding to a Snowflake PINNACLE_SERVICE + * integration + * @param pinnacleServiceClientSecret the Client Secret corresponding to a Snowflake + * PINNACLE_SERVICE integration + * @param pinnacleServiceClientSecret2 the second Client Secret corresponding to a Snowflake + * PINNACLE_SERVICE integration + */ + public SnowflakePinnacleServiceTokenBroker( + final String pinnacleServiceClientId, + final String pinnacleServiceClientSecret, + final String pinnacleServiceClientSecret2) { + this.clientId = pinnacleServiceClientId; + this.clientSecret = pinnacleServiceClientSecret; + this.clientSecret2 = pinnacleServiceClientSecret2; + this.accountToServiceToken = new ConcurrentHashMap<>(); + } + + /** + * Gets or computes a Pinnacle Service token for an account. + * + *

TODO: this is pretty shitty for now, basically assuming everything worked end-to-end its + * going to populate the map with a token that lasts for an hour and then things are going to fall + * apart until you restart the JVM. The general idea that I have here is that we'll periodically + * refresh the tokens based on expiry. I think there is some Guava Map/Cache loader that'll help + * here but as I can't check the documentation on the plane I'm going to leave it as-is. + * + * @param accountName the name of the account, ex: `testaccount` + */ + public SnowflakeTokenResponse getOrComputeTokenForAccount(String accountName) { + return accountToServiceToken.computeIfAbsent( + accountName, + (key) -> + getToken((SnowflakeRealmContext) CallContext.getCurrentContext().getRealmContext())); + } + + @Override + String getPayload() { + SnowflakeTokenRequestPayload payload = new SnowflakeTokenRequestPayload(); + payload.setGrantType(PINNACLE_SERVICE_CLIENT_CREDENTIALS_GRANT_TYPE); + try { + return mapper.writeValueAsString(payload); + } catch (JsonProcessingException e) { + LOGGER.error("Unable to serialize payload", e); + throw new RuntimeException(e); + } + } + + /** Authorization Header as the `PINNACLE_SYSTEM` Client ID/Secret */ + @Override + String getAuthHeader() { + return OAuthUtils.getBasicAuthHeader(clientId, clientSecret); + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnacleTokenInfoExchangeBroker.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnacleTokenInfoExchangeBroker.java new file mode 100644 index 000000000..ccfc31add --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnacleTokenInfoExchangeBroker.java @@ -0,0 +1,117 @@ +package org.apache.iceberg.pinnacle.oauth; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.http.client.ClientProtocolException; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.iceberg.pinnacle.http.HTTPUtil; +import org.apache.iceberg.rest.CallContext; +import org.apache.iceberg.rest.RealmContext; +import org.apache.iceberg.rest.snowflake.SnowflakeRealmContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; + +import static org.apache.iceberg.pinnacle.oauth.OAuthUtils.AUTHORIZATION_HEADER; + +/** + * Interacts with the new token info/validator endpoint to validate a user-supplied token and get + * back an intermediary token. This endpoint is sufficiently weirder than the + * `/v1/oauth/token-request` endpoint so it gets its own classe rather than extending {@link + * SnowflakeTokenBroker} + */ +public class SnowflakePinnacleTokenInfoExchangeBroker { + private static final Logger LOGGER = + LoggerFactory.getLogger(SnowflakePinnacleTokenInfoExchangeBroker.class); + private static final String OAUTH_TOKEN_INFO_ENDPOINT = "v1/oauth/token-info"; + private static final ObjectMapper mapper = new ObjectMapper(); + + /** + * Returns the authorization header to be used in this request which is the `PINNACLE_SERVE` + * Client ID/Secret + */ + private String getAuthHeader(SnowflakeRealmContext realmContext) { + return SnowflakeOAuth2Service.getPinnacleServiceTokenBroker() + .getToken(realmContext) + .getAccessToken(); + } + + /** + * Returns the payload that is used in this request which is the `PINNACLE_PRINCIPAL` token. This + * may be extended in the future to include additional information + * + * @return + */ + private String getPayload(String pinnaclePrincipalOAuthToken) { + SnowflakeTokenInfoRequest payload = new SnowflakeTokenInfoRequest(); + payload.setToken(pinnaclePrincipalOAuthToken); + try { + return mapper.writeValueAsString(payload); + } catch (JsonProcessingException e) { + LOGGER.error("Unable to serialize payload", e); + throw new RuntimeException(e); + } + } + + /** + * Takes the given token in a request to the Pinnacle REST Application and validates it with + * respect to a `PINNACLE_PRINCIPAL` integration. In addition to confirming whether the token is + * valid or not Snowflake will respond with an "intermediary" token that will be used for the + * duration of this REST Request. The caller of this should validate the response and either return + * an error to the Client OR set the intermediary token in the CallContext. + * + * @param pinnaclePrincipalOAuthToken + * @return + */ + public SnowflakeTokenInfoExchangeResponse validateExchangeToken( + String pinnaclePrincipalOAuthToken) { + SnowflakeRealmContext realmContext = + (SnowflakeRealmContext) CallContext.getCurrentContext().getRealmContext(); + + URI snowflakeURI; + try { + snowflakeURI = + new URIBuilder() + .setScheme(realmContext.getHttpScheme()) + .setHost(realmContext.getAccountUrl()) + .setPath(OAUTH_TOKEN_INFO_ENDPOINT) + .build(); + } catch (URISyntaxException e) { + // TODO: better error handling but if this happens we're kind of hosed + LOGGER.error( + "Cannot generate Snowflake URI. Scheme:{}, URL:{}", + realmContext.getHttpScheme(), + realmContext.getAccountUrl()); + throw new RuntimeException("Cannot generate a Snowflake URL"); + } + + HttpPost httpPost = new HttpPost(snowflakeURI); + httpPost.addHeader(AUTHORIZATION_HEADER, getAuthHeader(realmContext)); + httpPost.setEntity( + new StringEntity(getPayload(pinnaclePrincipalOAuthToken), ContentType.APPLICATION_JSON)); + CloseableHttpClient client = HTTPUtil.getHttpClient(); + try (CloseableHttpResponse response = client.execute(httpPost)) { + // TODO: better error handling for deserialization, right now I'm on a plane and can't google + // this. + LOGGER.info("Attempting to unmarshall insert response - {}", response); + return mapper.readValue( + response.getEntity().getContent(), SnowflakeTokenInfoExchangeResponse.class); + } catch (ClientProtocolException e) { + // TODO: better error handling + LOGGER.error("Unexpected client protocol exception", e); + throw new RuntimeException(e); + } catch (IOException e) { + // TODO: better error handling + LOGGER.error("Unexpected IO exception", e); + throw new RuntimeException(e); + } + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenBroker.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenBroker.java new file mode 100644 index 000000000..08cd06619 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenBroker.java @@ -0,0 +1,77 @@ +package org.apache.iceberg.pinnacle.oauth; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.http.client.ClientProtocolException; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.iceberg.pinnacle.http.HTTPUtil; +import org.apache.iceberg.rest.CallContext; +import org.apache.iceberg.rest.RealmContext; +import org.apache.iceberg.rest.snowflake.SnowflakeRealmContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; + +import static org.apache.iceberg.pinnacle.oauth.OAuthUtils.AUTHORIZATION_HEADER; + +public abstract class SnowflakeTokenBroker { + + private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeTokenBroker.class); + private static final String OAUTH_TOKEN_ENDPOINT = "v1/oauth/token-request"; + private static final ObjectMapper mapper = new ObjectMapper(); + + abstract String getPayload(); + + abstract String getAuthHeader(); + + /** + * Calls Snowflake's `/v1/oauth/token-request` endpoint with the Client ID/Secret in the header + * corresponding to the `PINNACLE_PRINCIPAL` application. + * + * @return token scoped to an account's PINNACLE_ + */ + public SnowflakeTokenResponse getToken(SnowflakeRealmContext realmContext) { + URI snowflakeURI; + try { + snowflakeURI = + new URIBuilder() + .setScheme(realmContext.getHttpScheme()) + .setHost(realmContext.getAccountUrl()) + .setPath(OAUTH_TOKEN_ENDPOINT) + .build(); + } catch (URISyntaxException e) { + // TODO: better error handling but if this happens we're kind of hosed + LOGGER.error( + "Cannot generate Snowflake URI. Scheme:{}, URL:{}", + realmContext.getHttpScheme(), + realmContext.getAccountUrl()); + throw new RuntimeException("Cannot generate a Snowflake URL"); + } + + HttpPost httpPost = new HttpPost(snowflakeURI); + httpPost.addHeader(AUTHORIZATION_HEADER, getAuthHeader()); + httpPost.setEntity(new StringEntity(getPayload(), ContentType.APPLICATION_JSON)); + CloseableHttpClient client = HTTPUtil.getHttpClient(); + try (CloseableHttpResponse response = client.execute(httpPost)) { + // TODO: better error handling for deserialization, right now I'm on a plane and can't google + // this. + LOGGER.info("Attempting to unmarshall insert response - {}", response); + return mapper.readValue(response.getEntity().getContent(), SnowflakeTokenResponse.class); + } catch (ClientProtocolException e) { + // TODO: better error handling + LOGGER.error("Unexpected client protocol exception", e); + throw new RuntimeException(e); + } catch (IOException e) { + // TODO: better error handling + LOGGER.error("Unexpected IO exception", e); + throw new RuntimeException(e); + } + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenInfoExchangeResponse.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenInfoExchangeResponse.java new file mode 100644 index 000000000..c5f9c16b3 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenInfoExchangeResponse.java @@ -0,0 +1,103 @@ +package org.apache.iceberg.pinnacle.oauth; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class SnowflakeTokenInfoExchangeResponse { + + private boolean active; + + @JsonProperty("active") + public boolean isActive() { + return active; + } + + @JsonProperty("active") + public void setActive(boolean active) { + this.active = active; + } + + private String scope; + + @JsonProperty("scope") + public String getScope() { + return scope; + } + + @JsonProperty("scope") + public void setScope(String scope) { + this.scope = scope; + } + + private String clientId; + + @JsonProperty("client_id") + public String getClientId() { + return clientId; + } + + @JsonProperty("client_id") + public void setClientId(String clientId) { + this.clientId = clientId; + } + + private String tokenType; + + @JsonProperty("token_type") + public String getTokenType() { + return tokenType; + } + + @JsonProperty("token_type") + public void setTokenType(String tokenType) { + this.tokenType = tokenType; + } + + private Long exp; + + @JsonProperty("exp") + public Long getExp() { + return exp; + } + + @JsonProperty("exp") + public void setExp(Long exp) { + this.exp = exp; + } + + private String sub; + + @JsonProperty("sub") + public String getSub() { + return sub; + } + + @JsonProperty("sub") + public void setSub(String sub) { + this.sub = sub; + } + + private String aud; + + @JsonProperty("aud") + public String getAud() { + return aud; + } + + @JsonProperty("aud") + public void setAud(String aud) { + this.aud = aud; + } + + @JsonProperty("iss") + private String iss; + + @JsonProperty("iss") + public String getIss() { + return iss; + } + + @JsonProperty("iss") + public void setIss(String iss) { + this.iss = iss; + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenInfoRequest.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenInfoRequest.java new file mode 100644 index 000000000..1824fd2e3 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenInfoRequest.java @@ -0,0 +1,20 @@ +package org.apache.iceberg.pinnacle.oauth; + +import org.codehaus.jackson.annotate.JsonProperty; + +/** Encapsulates the request data to be used in a request to the `/v1/oauth/token-info` endpoint */ +public class SnowflakeTokenInfoRequest { + + private String token; + + public SnowflakeTokenInfoRequest() {} + + @JsonProperty("token") + public String getToken() { + return token; + } + + public void setToken(String token) { + this.token = token; + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenRequestPayload.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenRequestPayload.java new file mode 100644 index 000000000..06c863d80 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenRequestPayload.java @@ -0,0 +1,21 @@ +package org.apache.iceberg.pinnacle.oauth; + +import org.codehaus.jackson.annotate.JsonProperty; + +/** Basically a Pojo */ +public class SnowflakeTokenRequestPayload { + + /** Token Request Payload. */ + public SnowflakeTokenRequestPayload() {} + + private String grantType; + + @JsonProperty("grant_type") + public String getGrantType() { + return grantType; + } + + public void setGrantType(String grantType) { + this.grantType = grantType; + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenResponse.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenResponse.java new file mode 100644 index 000000000..a14b01980 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeTokenResponse.java @@ -0,0 +1,35 @@ +package org.apache.iceberg.pinnacle.oauth; + +import org.codehaus.jackson.annotate.JsonProperty; + +public class SnowflakeTokenResponse { + + private String accessToken; + + private int expiresIn; + + /** Public constructor needed for deserialization */ + public SnowflakeTokenResponse() { + + } + + @JsonProperty("access_token") + public String getAccessToken() { + return accessToken; + } + + @JsonProperty("access_token") + public void setAccessToken(String accessToken) { + this.accessToken = accessToken; + } + + @JsonProperty("expires_in") + public int getExpiresIn() { + return expiresIn; + } + + @JsonProperty("expires_in") + public void setExpiresIn(int expiresIn) { + this.expiresIn = expiresIn; + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/TokenRequestValidator.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/TokenRequestValidator.java new file mode 100644 index 000000000..aa354b677 --- /dev/null +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/TokenRequestValidator.java @@ -0,0 +1,28 @@ +package org.apache.iceberg.pinnacle.oauth; + +import java.util.Set; +import java.util.logging.Logger; + +public class TokenRequestValidator { + + static final Logger LOGGER = Logger.getLogger(TokenRequestValidator.class.getName()); + + static final Set ALLOWED_GRANT_TYPES = Set.of("client_credentials"); + + /** Default constructor */ + public TokenRequestValidator() {} + + public boolean validateForClientCredentialsFlow( + final String clientId, final String clientSecret, final String grantType) { + if (clientId == null || clientId.isEmpty() || clientSecret == null || clientSecret.isEmpty()) { + // TODO: Figure out how to get the authorization header from `securityContext` + LOGGER.info("Missing Client ID or Client Secret in Request Body"); + return false; + } + if (!ALLOWED_GRANT_TYPES.contains(grantType)) { + LOGGER.info("Invalid grant type: " + grantType); + return false; + } + return true; + } +} diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeCallContext.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeCallContext.java index 74276188f..e83369c73 100644 --- a/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeCallContext.java +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeCallContext.java @@ -5,32 +5,62 @@ public class SnowflakeCallContext implements CallContext { - private final RealmContext realmContext; - - /** - * Default constructor - * @param realmContext - */ - SnowflakeCallContext(RealmContext realmContext) { - this.realmContext = realmContext; - } + private final RealmContext realmContext; - @Override - public RealmContext getRealmContext() { - return realmContext; - } + // This token is obtained by validating and exchanging a token associated with a + // `PINNACLE_PRINCIPAL` integration for an intermediary one as fetched via + // SnowflakePinnacleTokenInfoExchangeBroker. Depending on the request (namely `/v1/oauth/tokens`) + // this may not be set and it also may not be needed. + private String pinnaclePrincipalIntermediaryToken; - /** - * This will return the identifier of the Pinnacle Principal - * @return - */ - @Override - public String getUser() { - return ""; - } + /** + * Default constructor + * + * @param realmContext + */ + SnowflakeCallContext(RealmContext realmContext) { + this.realmContext = realmContext; + } + + @Override + public RealmContext getRealmContext() { + return realmContext; + } + + /** + * This will return the identifier of the Pinnacle Principal + * + * @return + */ + @Override + public String getUser() { + return ""; + } + + @Override + public String getRole() { + return ""; + } + + /** + * @return The Pinnacle Principal Intermediary Token if set, or Null if no such token has been + * set. If you expected this to return a non-null value then you probably needed to invoke + * SnowflakePinnacleTokenInfoExchangeBroker at some point + */ + public String getPinnaclePrincipalIntermediaryToken() { + return pinnaclePrincipalIntermediaryToken; + } - @Override - public String getRole() { - return ""; + /** + * Sets the intermediary token to be used for the duration of this Call Context. As these tokens + * are valid for an hour and we currently do not expect any call to last an hour this API only + * supports setting the token once. + * + * @param pinnaclePrincipalIntermediaryToken + */ + public void setPinnaclePrincipalIntermediaryToken(String pinnaclePrincipalIntermediaryToken) { + if (this.pinnaclePrincipalIntermediaryToken == null) { + this.pinnaclePrincipalIntermediaryToken = pinnaclePrincipalIntermediaryToken; } + } } diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolver.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolver.java index a28b258eb..e444ca7aa 100644 --- a/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolver.java +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolver.java @@ -69,14 +69,19 @@ public RealmContext resolveRealmContext( // "https://pinnacle.account.snowflakecomputing.com" // so get the host and strip "pinnacle" from it String accountUrl; + String httpScheme; + int httpPort; try { - String host = new URI(requestUrl).getHost(); + URI uri = new URI(requestUrl); + String host = uri.getHost(); + httpScheme = uri.getScheme(); + httpPort = uri.getPort(); accountUrl = host.replace("pinnacle.", ""); } catch (URISyntaxException e) { // TODO Add better / Pinnacle REST Service generic error handling LOGGER.info("Error parsing request URL: " + requestUrl); throw new RuntimeException("Unable to parse the provided account"); } - return new SnowflakeRealmContext(accountUrl, getAccountNameFromURL(accountUrl)); + return new SnowflakeRealmContext(accountUrl, getAccountNameFromURL(accountUrl), httpScheme, httpPort); } } diff --git a/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeRealmContext.java b/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeRealmContext.java index 30decbdf3..2ffffdbac 100644 --- a/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeRealmContext.java +++ b/iceberg-rest-server/src/main/java/org/apache/iceberg/rest/snowflake/SnowflakeRealmContext.java @@ -5,31 +5,52 @@ /** The Snowflake "Realm" Context, i.e. the account that is making the request */ public class SnowflakeRealmContext implements RealmContext { - // Base Account URL - ex "myaccount.snowflakecomputing.com" - private final String accountUrl; - - // The name of the account - ex "myaccount" - private final String accountName; - - SnowflakeRealmContext(final String accountUrl, final String accountName) { - this.accountUrl = accountUrl; - this.accountName = accountName; - } - - public String getAccountUrl() { - return accountUrl; - } - - public String getAccountName() { - return accountName; + // Base Account URL - ex "myaccount.snowflakecomputing.com" + private final String accountUrl; + + // The name of the account - ex "myaccount" + private final String accountName; + + // The http scheme - should be https but may be http for local testing + private final String httpScheme; + + // http port - if no port is specified this is -1; + private final int httpPort; + + SnowflakeRealmContext( + final String accountUrl, + final String accountName, + final String httpScheme, + final int httpPort) { + this.accountUrl = accountUrl; + this.accountName = accountName; + this.httpScheme = httpScheme; + this.httpPort = httpPort; + } + + public String getAccountUrl() { + return accountUrl; + } + + public String getAccountName() { + return accountName; + } + + /** + * The Realm Identifier for Snowflake is simply the name of the account + * + * @return + */ + @Override + public String getRealmIdentifier() { + return accountName; + } + + public String getHttpScheme() { + return httpScheme; } - /** - * The Realm Identifier for Snowflake is simply the name of the account - * @return - */ - @Override - public String getRealmIdentifier() { - return accountName; + public int getHttpPort() { + return httpPort; } } diff --git a/iceberg-rest-server/src/test/java/org/apache/iceberg/pinnacle/oauth/TokenRequestValidatorTest.java b/iceberg-rest-server/src/test/java/org/apache/iceberg/pinnacle/oauth/TokenRequestValidatorTest.java new file mode 100644 index 000000000..7487f296b --- /dev/null +++ b/iceberg-rest-server/src/test/java/org/apache/iceberg/pinnacle/oauth/TokenRequestValidatorTest.java @@ -0,0 +1,41 @@ +package org.apache.iceberg.pinnacle.oauth; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + + +public class TokenRequestValidatorTest { + @Test + public void testValidateForClientCredentialsFlowNullClientId() { + Assertions.assertFalse( + new TokenRequestValidator().validateForClientCredentialsFlow(null, "notnull", "notnull")); + Assertions.assertFalse( + new TokenRequestValidator().validateForClientCredentialsFlow("", "notnull", "notnull")); + } + + @Test + public void testValidateForClientCredentialsFlowNullClientSecret() { + Assertions.assertFalse( + new TokenRequestValidator().validateForClientCredentialsFlow("client-id", null, "notnull")); + Assertions.assertFalse( + new TokenRequestValidator().validateForClientCredentialsFlow("client-id", "", "notnull")); + } + + @Test + public void testValidateForClientCredentialsFlowInvalidGrantType() { + Assertions.assertFalse( + new TokenRequestValidator() + .validateForClientCredentialsFlow( + "client-id", "client-secret", "not-client-credentials")); + Assertions.assertFalse( + new TokenRequestValidator() + .validateForClientCredentialsFlow("client-id", "client-secret", "grant")); + } + + @Test + public void testValidateForClientCredentialsFlowAllValid() { + Assertions.assertTrue( + new TokenRequestValidator() + .validateForClientCredentialsFlow("client-id", "client-secret", "client_credentials")); + } +} diff --git a/iceberg-rest-server/src/test/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolverTest.java b/iceberg-rest-server/src/test/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolverTest.java index 1874ec542..2242794bf 100644 --- a/iceberg-rest-server/src/test/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolverTest.java +++ b/iceberg-rest-server/src/test/java/org/apache/iceberg/rest/snowflake/SnowflakeContextResolverTest.java @@ -9,7 +9,6 @@ import java.util.HashMap; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.*; public class SnowflakeContextResolverTest { @@ -32,11 +31,36 @@ void resolveCallContext() { new HashMap<>(), new HashMap<>()); assertThat(context.getRealmContext()) - .returns("TESTACCOUNT", RealmContext::getRealmIdentifier) - .isInstanceOf(SnowflakeRealmContext.class) - .asInstanceOf(InstanceOfAssertFactories.type(SnowflakeRealmContext.class)) - .returns("TESTACCOUNT", SnowflakeRealmContext::getAccountName) - .returns("testaccount.snowflakecomputing.com", SnowflakeRealmContext::getAccountUrl); + .returns("TESTACCOUNT", RealmContext::getRealmIdentifier) + .isInstanceOf(SnowflakeRealmContext.class) + .asInstanceOf(InstanceOfAssertFactories.type(SnowflakeRealmContext.class)) + .returns("TESTACCOUNT", SnowflakeRealmContext::getAccountName) + .returns("testaccount.snowflakecomputing.com", SnowflakeRealmContext::getAccountUrl); + Assertions.assertNull(((SnowflakeCallContext) context).getPinnaclePrincipalIntermediaryToken()); + } + + @Test + public void testCallContextPinnaclePrincipalIntermediaryToken() { + RealmContext realmContext = + new SnowflakeContextResolver() + .resolveRealmContext( + "https://pinnacle.testaccount.snowflakecomputing.com:8181", + "POST", + "api/catalog/v1/oauth/tokens", + new HashMap<>(), + new HashMap<>()); + CallContext context = + new SnowflakeContextResolver() + .resolveCallContext( + realmContext, + "POST", + "api/catalog/v1/oauth/tokens", + new HashMap<>(), + new HashMap<>()); + ((SnowflakeCallContext) context).setPinnaclePrincipalIntermediaryToken("intermediaryToken"); + Assertions.assertEquals("intermediaryToken", ((SnowflakeCallContext) context).getPinnaclePrincipalIntermediaryToken()); + ((SnowflakeCallContext) context).setPinnaclePrincipalIntermediaryToken("attemptOverride"); + Assertions.assertEquals("intermediaryToken", ((SnowflakeCallContext) context).getPinnaclePrincipalIntermediaryToken()); } @Test @@ -54,7 +78,6 @@ void resolveRealmContextValidRequestURL() { Assertions.assertEquals( "testaccount.snowflakecomputing.com", ((SnowflakeRealmContext) realmContext).getAccountUrl()); - } @Test