forked from apache/polaris
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds plumbing for Snowflake OAuth Context (apache#23)
This does the following: - Adds a broker to manage `PINNACLE_SERVICE` tokens - Adds a broker to manage `PINNACLE_PRINCIPAL` tokens - Adds a broker to interact with the `oauth/token-info` endpoint in GS - Extends the CallCtx plumbing - Adds a HTTPUtility to make calls to GS with Notably missing: - End-to-end tests (I wrote this all on a plane) - Error handling There are extenstive tests
- Loading branch information
1 parent
3cfe087
commit 87ba4ea
Showing
18 changed files
with
1,019 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
170 changes: 170 additions & 0 deletions
170
iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/http/HTTPUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
package org.apache.iceberg.pinnacle.http; | ||
|
||
import org.apache.http.HttpHost; | ||
import org.apache.http.HttpResponse; | ||
import org.apache.http.client.ServiceUnavailableRetryStrategy; | ||
import org.apache.http.client.config.RequestConfig; | ||
import org.apache.http.conn.ssl.DefaultHostnameVerifier; | ||
import org.apache.http.conn.ssl.SSLConnectionSocketFactory; | ||
import org.apache.http.impl.client.CloseableHttpClient; | ||
import org.apache.http.impl.client.HttpClientBuilder; | ||
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; | ||
import org.apache.http.protocol.HttpContext; | ||
import org.apache.http.ssl.SSLContexts; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import javax.net.ssl.SSLContext; | ||
import java.security.Security; | ||
import java.time.Duration; | ||
import java.time.Instant; | ||
import java.time.temporal.ChronoUnit; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
/** Utility class that can be used to make REST HTTP calls to Snowflake with */ | ||
public class HTTPUtil { | ||
|
||
private static final Logger LOGGER = LoggerFactory.getLogger(HTTPUtil.class); | ||
|
||
private static final String FIRST_FAULT_TIMESTAMP = "FIRST_FAULT_TIMESTAMP"; | ||
private static final Duration TOTAL_RETRY_DURATION = Duration.of(120, ChronoUnit.SECONDS); | ||
private static final Duration RETRY_INTERVAL = Duration.of(3, ChronoUnit.SECONDS); | ||
|
||
private static ServiceUnavailableRetryStrategy getServiceUnavailableRetryStrategy() { | ||
return new ServiceUnavailableRetryStrategy() { | ||
final int REQUEST_TIMEOUT = 408; | ||
final int TOO_MANY_REQUESTS = 429; | ||
final int SERVER_ERRORS = 500; | ||
|
||
@Override | ||
public boolean retryRequest( | ||
final HttpResponse response, final int executionCount, final HttpContext context) { | ||
Object firstFault = context.getAttribute(FIRST_FAULT_TIMESTAMP); | ||
long totalRetryDurationSoFarInSeconds = 0; | ||
if (firstFault == null) { | ||
context.setAttribute(FIRST_FAULT_TIMESTAMP, Instant.now()); | ||
} else { | ||
Instant firstFaultInstant = (Instant) firstFault; | ||
Instant now = Instant.now(); | ||
totalRetryDurationSoFarInSeconds = Duration.between(firstFaultInstant, now).getSeconds(); | ||
|
||
if (totalRetryDurationSoFarInSeconds > TOTAL_RETRY_DURATION.getSeconds()) { | ||
LOGGER.info( | ||
String.format( | ||
"Reached the max retry time of %d seconds, not retrying anymore", | ||
TOTAL_RETRY_DURATION.getSeconds())); | ||
return false; | ||
} | ||
} | ||
|
||
int statusCode = response.getStatusLine().getStatusCode(); | ||
boolean needNextRetry = | ||
(statusCode == REQUEST_TIMEOUT | ||
|| statusCode == TOO_MANY_REQUESTS | ||
|| statusCode >= SERVER_ERRORS); | ||
if (needNextRetry) { | ||
long interval = getRetryInterval(); | ||
LOGGER.info("In retryRequest for service unavailability with statusCode:{}", statusCode); | ||
LOGGER.info( | ||
"Sleep time in millisecond: {}, retryCount: {}, total retry duration: {}s / {}s", | ||
interval, | ||
executionCount, | ||
totalRetryDurationSoFarInSeconds, | ||
TOTAL_RETRY_DURATION.getSeconds()); | ||
} | ||
return needNextRetry; | ||
} | ||
|
||
@Override | ||
public long getRetryInterval() { | ||
return RETRY_INTERVAL.toMillis(); | ||
} | ||
}; | ||
} | ||
|
||
private static volatile CloseableHttpClient httpClient; | ||
|
||
private static PoolingHttpClientConnectionManager connectionManager; | ||
|
||
private static final int DEFAULT_CONNECTION_TIMEOUT_MINUTES = 1; | ||
private static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES = 5; | ||
|
||
/** | ||
* After how many seconds of inactivity should be idle connections evicted from the connection | ||
* pool. | ||
*/ | ||
private static final int DEFAULT_EVICT_IDLE_AFTER_SECONDS = 60; | ||
|
||
// Default is 2, but scaling it up to 100 to match with default_max_connections | ||
private static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = 100; | ||
|
||
// 100 is close to max partition number we have seen for a kafka topic ingesting into snowflake. | ||
private static final int DEFAULT_MAX_CONNECTIONS = 100; | ||
|
||
// Interval in which we check if there are connections which needs to be closed. | ||
private static final long IDLE_HTTP_CONNECTION_MONITOR_THREAD_INTERVAL_MS = | ||
TimeUnit.SECONDS.toMillis(5); | ||
|
||
// Only connections that are currently owned, not checked out, are subject to idle timeouts. | ||
private static final int DEFAULT_IDLE_CONNECTION_TIMEOUT_SECONDS = 30; | ||
|
||
public static void initHttpClient() { | ||
Security.setProperty("ocsp.enable", "true"); | ||
SSLContext sslContext = SSLContexts.createDefault(); | ||
|
||
SSLConnectionSocketFactory f = | ||
new SSLConnectionSocketFactory( | ||
sslContext, new String[] {"TLSv1.2"}, null, new DefaultHostnameVerifier()); | ||
// Set connectionTimeout which is the timeout until a connection with the server is established | ||
// Set connectionRequestTimeout which is the time to wait for getting a connection from the | ||
// connection pool | ||
// Set socketTimeout which is the max time gap between two consecutive data packets | ||
RequestConfig requestConfig = | ||
RequestConfig.custom() | ||
.setConnectTimeout( | ||
(int) | ||
TimeUnit.MILLISECONDS.convert( | ||
DEFAULT_CONNECTION_TIMEOUT_MINUTES, TimeUnit.MINUTES)) | ||
.setConnectionRequestTimeout( | ||
(int) | ||
TimeUnit.MILLISECONDS.convert( | ||
DEFAULT_CONNECTION_TIMEOUT_MINUTES, TimeUnit.MINUTES)) | ||
.setSocketTimeout( | ||
(int) | ||
TimeUnit.MILLISECONDS.convert( | ||
DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES, TimeUnit.MINUTES)) | ||
.build(); | ||
|
||
// Below pooling client connection manager uses time_to_live value as -1 which means it will not | ||
// refresh a persisted connection | ||
connectionManager = new PoolingHttpClientConnectionManager(); | ||
connectionManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONNECTIONS_PER_ROUTE); | ||
connectionManager.setMaxTotal(DEFAULT_MAX_CONNECTIONS); | ||
|
||
// Use an anonymous class to implement the interface ServiceUnavailableRetryStrategy() The max | ||
// retry time is 3. The interval time is backoff. | ||
HttpClientBuilder clientBuilder = | ||
HttpClientBuilder.create() | ||
.setConnectionManager(connectionManager) | ||
.evictIdleConnections(DEFAULT_EVICT_IDLE_AFTER_SECONDS, TimeUnit.SECONDS) | ||
.setSSLSocketFactory(f) | ||
.setServiceUnavailableRetryStrategy(getServiceUnavailableRetryStrategy()) | ||
// TODO add retry handler .setRetryHandler(getHttpRequestRetryHandler()) | ||
.setDefaultRequestConfig(requestConfig); | ||
httpClient = clientBuilder.build(); | ||
} | ||
|
||
/** | ||
* @return Instance of CloseableHttpClient | ||
*/ | ||
public static CloseableHttpClient getHttpClient() { | ||
if (httpClient == null) { | ||
synchronized (HTTPUtil.class) { | ||
if (httpClient == null) { | ||
initHttpClient(); | ||
} | ||
} | ||
} | ||
return httpClient; | ||
} | ||
} |
19 changes: 19 additions & 0 deletions
19
iceberg-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/OAuthUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package org.apache.iceberg.pinnacle.oauth; | ||
|
||
import com.nimbusds.jose.util.Base64; | ||
import org.apache.http.entity.StringEntity; | ||
|
||
/** Simple utility class to assist with OAuth operations*/ | ||
public class OAuthUtils { | ||
|
||
public static final String AUTHORIZATION_HEADER = "Authorization"; | ||
|
||
/** | ||
* @param clientId | ||
* @param clientSecret | ||
* @return basic Authorization Header of the form `base64_encode(client_id:client_secret) | ||
*/ | ||
public static String getBasicAuthHeader(String clientId, String clientSecret) { | ||
return Base64.encode(clientId + ":" + clientSecret).toString(); | ||
} | ||
} |
99 changes: 99 additions & 0 deletions
99
...g-rest-server/src/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakeOAuth2Service.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
package org.apache.iceberg.pinnacle.oauth; | ||
|
||
import jakarta.ws.rs.core.Response; | ||
import jakarta.ws.rs.core.SecurityContext; | ||
import org.apache.iceberg.rest.CallContext; | ||
import org.apache.iceberg.rest.api.IcebergRestOAuth2ApiService; | ||
import org.apache.iceberg.rest.responses.OAuthTokenResponse; | ||
import org.apache.iceberg.rest.snowflake.SnowflakeRealmContext; | ||
import org.apache.iceberg.rest.types.TokenType; | ||
|
||
/** | ||
* Snowflake-specific OAuth2 Service. This class essentially acts as a broker between an External | ||
* Iceberg Client and Snowflake. Specifically, this handles `/v1/oauth/tokens` requests made from | ||
* Iceberg Clients and translates those into calls against Snowflake to fetch credentials derived | ||
* from a `PINNACLE_PRINCIPAL` integration. | ||
*/ | ||
public class SnowflakeOAuth2Service implements IcebergRestOAuth2ApiService { | ||
|
||
// I need to figure out how to bootstrap config with the REST Application - this might get pushed | ||
// into the class itself depending on what I find. | ||
private static final SnowflakePinnacleServiceTokenBroker pinnacleServiceTokenBroker = | ||
new SnowflakePinnacleServiceTokenBroker( | ||
"TODO-CLIENT-ID", "TODO-CLIENT-SECRET", "TODO-CLIENT_SECRET-2"); | ||
|
||
public static SnowflakePinnacleServiceTokenBroker getPinnacleServiceTokenBroker() { | ||
return pinnacleServiceTokenBroker; | ||
} | ||
|
||
/** | ||
* Initializes an instance of the SnowflakeOAuth2Service. This is generally expected to be used | ||
* for the duration of the Dropwizard application as initialized in {@link | ||
* org.apache.iceberg.rest.IcebergRestApplication} | ||
*/ | ||
public SnowflakeOAuth2Service() {} | ||
|
||
/** | ||
* Handles a `/v1/oauth/tokens` request made from an External Iceberg Client We make a REST | ||
* request to Snowflake's `/oauth/token-request` endpoint with the grant type set to | ||
* `pinnacle_principal` and the Client ID/Secret in the `Authorization` header. | ||
* | ||
* <p>A note regarding the spec: it appears that client ID/Secret can come in via the request | ||
* payload OR it can come in via the Authorization header. We should look at the Spark clients and | ||
* see where they generally put Client ID/Secret but we'll probably have to check both and pick | ||
* whatever is not null. For now we'll just look at the request payload. As per the docs: | ||
* | ||
* <pre> | ||
* This can be sent in the request body, but OAuth2 recommends sending it in | ||
* a Basic Authorization header. | ||
* </pre> | ||
* | ||
* @param grantType This should either be `client_credentials` for obtaining a token or | ||
* `urn:ietf:params:oauth:grant-type:token-exchange` for exchanging a token. For now we only | ||
* support `client_credentials` | ||
* @param scope the requested scope of the Iceberg Client. TODO define later | ||
* @param clientId (Nullable) the Client ID that should map to a `PINNACLE_PRINCIPAL` integration | ||
* in Snowflake | ||
* @param clientSecret (Nullable) the Client Secret that should map to a `PINNACLE_PRINCIPAL` | ||
* integration in Snowflake | ||
* @param requestedTokenType | ||
* @param subjectToken | ||
* @param subjectTokenType | ||
* @param actorToken | ||
* @param actorTokenType | ||
* @param securityContext | ||
* @return Either an `OAuthTokenResponse` or `OAuthErrorResponse` as defined in the REST spec | ||
*/ | ||
@Override | ||
public Response getToken( | ||
String grantType, | ||
String scope, | ||
String clientId, | ||
String clientSecret, | ||
TokenType requestedTokenType, | ||
String subjectToken, | ||
TokenType subjectTokenType, | ||
String actorToken, | ||
TokenType actorTokenType, | ||
SecurityContext securityContext) { | ||
TokenRequestValidator validator = new TokenRequestValidator(); | ||
if (!validator.validateForClientCredentialsFlow(clientId, clientSecret, grantType)) { | ||
// TODO this needs to be `OAuthErrorResponse` as defined in the spec but I can't | ||
// seem to build with it on the plane? | ||
return Response.status(Response.Status.BAD_REQUEST).build(); | ||
} | ||
|
||
// This needs error handling as well | ||
SnowflakePinnaclePrincipalTokenBroker broker = | ||
new SnowflakePinnaclePrincipalTokenBroker(clientId, clientSecret); | ||
SnowflakeTokenResponse tokenResponse = | ||
broker.getToken((SnowflakeRealmContext) CallContext.getCurrentContext().getRealmContext()); | ||
return Response.ok( | ||
OAuthTokenResponse.builder() | ||
.withToken(tokenResponse.getAccessToken()) | ||
.withTokenType("bearer") // TODO there's gotta be a constant somewhere | ||
.setExpirationInSeconds(tokenResponse.getExpiresIn()) | ||
.build()) | ||
.build(); | ||
} | ||
} |
51 changes: 51 additions & 0 deletions
51
...rc/main/java/org/apache/iceberg/pinnacle/oauth/SnowflakePinnaclePrincipalTokenBroker.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package org.apache.iceberg.pinnacle.oauth; | ||
|
||
import com.fasterxml.jackson.core.JsonProcessingException; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
/** Responsible for acting as a broker to obtain `PINNACLE_PRINCIPAL` level tokens. */ | ||
public class SnowflakePinnaclePrincipalTokenBroker extends SnowflakeTokenBroker { | ||
|
||
private static final Logger LOGGER = | ||
LoggerFactory.getLogger(SnowflakePinnaclePrincipalTokenBroker.class); | ||
|
||
private static final String PINNACLE_PRINCIPAL_CLIENT_CREDENTIALS_GRANT_TYPE = | ||
"pinnacle_principal_client_credentials"; | ||
|
||
private static final ObjectMapper mapper = new ObjectMapper(); | ||
|
||
private final String clientId; | ||
private final String clientSecret; | ||
|
||
/** | ||
* @param clientId Client ID corresponding to a `PINNACLE_PRINCIPAL` integration | ||
* @param clientSecret Client Secret corresponding to a `PINNACLE_PRINCIPAL` integration | ||
*/ | ||
public SnowflakePinnaclePrincipalTokenBroker(final String clientId, final String clientSecret) { | ||
this.clientId = clientId; | ||
this.clientSecret = clientSecret; | ||
} | ||
|
||
@Override | ||
String getAuthHeader() { | ||
return OAuthUtils.getBasicAuthHeader(clientId, clientSecret); | ||
} | ||
|
||
/** | ||
* In the future this will likely have the Client ID/Secret in it and we'll put a PINNACLE_SERVICE | ||
* token in the Authorization Header. There is a JIRA for this somewhere. | ||
*/ | ||
@Override | ||
String getPayload() { | ||
SnowflakeTokenRequestPayload payload = new SnowflakeTokenRequestPayload(); | ||
payload.setGrantType(PINNACLE_PRINCIPAL_CLIENT_CREDENTIALS_GRANT_TYPE); | ||
try { | ||
return mapper.writeValueAsString(payload); | ||
} catch (JsonProcessingException e) { | ||
LOGGER.error("Unable to serialize payload", e); | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} |
Oops, something went wrong.