Skip to content

Commit

Permalink
Adds plumbing for Snowflake OAuth Context (apache#23)
Browse files Browse the repository at this point in the history
This does the following:

- Adds a broker to manage `PINNACLE_SERVICE` tokens
- Adds a broker to manage `PINNACLE_PRINCIPAL` tokens
- Adds a broker to interact with the `oauth/token-info` endpoint in GS
- Extends the CallCtx plumbing
- Adds a HTTPUtility to make calls to GS with

Notably missing:

- End-to-end tests (I wrote this all on a plane)
- Error handling

There are extenstive tests
  • Loading branch information
sfc-gh-tjones committed May 2, 2024
1 parent 3cfe087 commit 87ba4ea
Show file tree
Hide file tree
Showing 18 changed files with 1,019 additions and 57 deletions.
1 change: 1 addition & 0 deletions iceberg-rest-server/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ openApiGenerate {
LoadTableResult: "org.apache.iceberg.rest.responses.LoadTableResponse" ,
LoadViewResult: "org.apache.iceberg.rest.responses.LoadTableResponse" ,
OAuthTokenResponse: "org.apache.iceberg.rest.responses.OAuthTokenResponse" ,
OAuthErrorResponse: "org.apache.iceberg.rest.responses.OAuthErrorResponse",
RenameTableRequest: "org.apache.iceberg.rest.requests.RenameTableRequest" ,
ReportMetricsRequest: "org.apache.iceberg.rest.requests.ReportMetricsRequest" ,
UpdateNamespacePropertiesRequest: "org.apache.iceberg.rest.requests.UpdateNamespacePropertiesRequest" ,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package org.apache.iceberg.pinnacle.http;

import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.client.ServiceUnavailableRetryStrategy;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.conn.ssl.DefaultHostnameVerifier;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
import org.apache.http.protocol.HttpContext;
import org.apache.http.ssl.SSLContexts;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.security.Security;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.TimeUnit;

/** Utility class that can be used to make REST HTTP calls to Snowflake with */
public class HTTPUtil {

private static final Logger LOGGER = LoggerFactory.getLogger(HTTPUtil.class);

private static final String FIRST_FAULT_TIMESTAMP = "FIRST_FAULT_TIMESTAMP";
private static final Duration TOTAL_RETRY_DURATION = Duration.of(120, ChronoUnit.SECONDS);
private static final Duration RETRY_INTERVAL = Duration.of(3, ChronoUnit.SECONDS);

private static ServiceUnavailableRetryStrategy getServiceUnavailableRetryStrategy() {
return new ServiceUnavailableRetryStrategy() {
final int REQUEST_TIMEOUT = 408;
final int TOO_MANY_REQUESTS = 429;
final int SERVER_ERRORS = 500;

@Override
public boolean retryRequest(
final HttpResponse response, final int executionCount, final HttpContext context) {
Object firstFault = context.getAttribute(FIRST_FAULT_TIMESTAMP);
long totalRetryDurationSoFarInSeconds = 0;
if (firstFault == null) {
context.setAttribute(FIRST_FAULT_TIMESTAMP, Instant.now());
} else {
Instant firstFaultInstant = (Instant) firstFault;
Instant now = Instant.now();
totalRetryDurationSoFarInSeconds = Duration.between(firstFaultInstant, now).getSeconds();

if (totalRetryDurationSoFarInSeconds > TOTAL_RETRY_DURATION.getSeconds()) {
LOGGER.info(
String.format(
"Reached the max retry time of %d seconds, not retrying anymore",
TOTAL_RETRY_DURATION.getSeconds()));
return false;
}
}

int statusCode = response.getStatusLine().getStatusCode();
boolean needNextRetry =
(statusCode == REQUEST_TIMEOUT
|| statusCode == TOO_MANY_REQUESTS
|| statusCode >= SERVER_ERRORS);
if (needNextRetry) {
long interval = getRetryInterval();
LOGGER.info("In retryRequest for service unavailability with statusCode:{}", statusCode);
LOGGER.info(
"Sleep time in millisecond: {}, retryCount: {}, total retry duration: {}s / {}s",
interval,
executionCount,
totalRetryDurationSoFarInSeconds,
TOTAL_RETRY_DURATION.getSeconds());
}
return needNextRetry;
}

@Override
public long getRetryInterval() {
return RETRY_INTERVAL.toMillis();
}
};
}

private static volatile CloseableHttpClient httpClient;

private static PoolingHttpClientConnectionManager connectionManager;

private static final int DEFAULT_CONNECTION_TIMEOUT_MINUTES = 1;
private static final int DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES = 5;

/**
* After how many seconds of inactivity should be idle connections evicted from the connection
* pool.
*/
private static final int DEFAULT_EVICT_IDLE_AFTER_SECONDS = 60;

// Default is 2, but scaling it up to 100 to match with default_max_connections
private static final int DEFAULT_MAX_CONNECTIONS_PER_ROUTE = 100;

// 100 is close to max partition number we have seen for a kafka topic ingesting into snowflake.
private static final int DEFAULT_MAX_CONNECTIONS = 100;

// Interval in which we check if there are connections which needs to be closed.
private static final long IDLE_HTTP_CONNECTION_MONITOR_THREAD_INTERVAL_MS =
TimeUnit.SECONDS.toMillis(5);

// Only connections that are currently owned, not checked out, are subject to idle timeouts.
private static final int DEFAULT_IDLE_CONNECTION_TIMEOUT_SECONDS = 30;

public static void initHttpClient() {
Security.setProperty("ocsp.enable", "true");
SSLContext sslContext = SSLContexts.createDefault();

SSLConnectionSocketFactory f =
new SSLConnectionSocketFactory(
sslContext, new String[] {"TLSv1.2"}, null, new DefaultHostnameVerifier());
// Set connectionTimeout which is the timeout until a connection with the server is established
// Set connectionRequestTimeout which is the time to wait for getting a connection from the
// connection pool
// Set socketTimeout which is the max time gap between two consecutive data packets
RequestConfig requestConfig =
RequestConfig.custom()
.setConnectTimeout(
(int)
TimeUnit.MILLISECONDS.convert(
DEFAULT_CONNECTION_TIMEOUT_MINUTES, TimeUnit.MINUTES))
.setConnectionRequestTimeout(
(int)
TimeUnit.MILLISECONDS.convert(
DEFAULT_CONNECTION_TIMEOUT_MINUTES, TimeUnit.MINUTES))
.setSocketTimeout(
(int)
TimeUnit.MILLISECONDS.convert(
DEFAULT_HTTP_CLIENT_SOCKET_TIMEOUT_MINUTES, TimeUnit.MINUTES))
.build();

// Below pooling client connection manager uses time_to_live value as -1 which means it will not
// refresh a persisted connection
connectionManager = new PoolingHttpClientConnectionManager();
connectionManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONNECTIONS_PER_ROUTE);
connectionManager.setMaxTotal(DEFAULT_MAX_CONNECTIONS);

// Use an anonymous class to implement the interface ServiceUnavailableRetryStrategy() The max
// retry time is 3. The interval time is backoff.
HttpClientBuilder clientBuilder =
HttpClientBuilder.create()
.setConnectionManager(connectionManager)
.evictIdleConnections(DEFAULT_EVICT_IDLE_AFTER_SECONDS, TimeUnit.SECONDS)
.setSSLSocketFactory(f)
.setServiceUnavailableRetryStrategy(getServiceUnavailableRetryStrategy())
// TODO add retry handler .setRetryHandler(getHttpRequestRetryHandler())
.setDefaultRequestConfig(requestConfig);
httpClient = clientBuilder.build();
}

/**
* @return Instance of CloseableHttpClient
*/
public static CloseableHttpClient getHttpClient() {
if (httpClient == null) {
synchronized (HTTPUtil.class) {
if (httpClient == null) {
initHttpClient();
}
}
}
return httpClient;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.apache.iceberg.pinnacle.oauth;

import com.nimbusds.jose.util.Base64;
import org.apache.http.entity.StringEntity;

/** Simple utility class to assist with OAuth operations*/
public class OAuthUtils {

public static final String AUTHORIZATION_HEADER = "Authorization";

/**
* @param clientId
* @param clientSecret
* @return basic Authorization Header of the form `base64_encode(client_id:client_secret)
*/
public static String getBasicAuthHeader(String clientId, String clientSecret) {
return Base64.encode(clientId + ":" + clientSecret).toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package org.apache.iceberg.pinnacle.oauth;

import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.SecurityContext;
import org.apache.iceberg.rest.CallContext;
import org.apache.iceberg.rest.api.IcebergRestOAuth2ApiService;
import org.apache.iceberg.rest.responses.OAuthTokenResponse;
import org.apache.iceberg.rest.snowflake.SnowflakeRealmContext;
import org.apache.iceberg.rest.types.TokenType;

/**
* Snowflake-specific OAuth2 Service. This class essentially acts as a broker between an External
* Iceberg Client and Snowflake. Specifically, this handles `/v1/oauth/tokens` requests made from
* Iceberg Clients and translates those into calls against Snowflake to fetch credentials derived
* from a `PINNACLE_PRINCIPAL` integration.
*/
public class SnowflakeOAuth2Service implements IcebergRestOAuth2ApiService {

// I need to figure out how to bootstrap config with the REST Application - this might get pushed
// into the class itself depending on what I find.
private static final SnowflakePinnacleServiceTokenBroker pinnacleServiceTokenBroker =
new SnowflakePinnacleServiceTokenBroker(
"TODO-CLIENT-ID", "TODO-CLIENT-SECRET", "TODO-CLIENT_SECRET-2");

public static SnowflakePinnacleServiceTokenBroker getPinnacleServiceTokenBroker() {
return pinnacleServiceTokenBroker;
}

/**
* Initializes an instance of the SnowflakeOAuth2Service. This is generally expected to be used
* for the duration of the Dropwizard application as initialized in {@link
* org.apache.iceberg.rest.IcebergRestApplication}
*/
public SnowflakeOAuth2Service() {}

/**
* Handles a `/v1/oauth/tokens` request made from an External Iceberg Client We make a REST
* request to Snowflake's `/oauth/token-request` endpoint with the grant type set to
* `pinnacle_principal` and the Client ID/Secret in the `Authorization` header.
*
* <p>A note regarding the spec: it appears that client ID/Secret can come in via the request
* payload OR it can come in via the Authorization header. We should look at the Spark clients and
* see where they generally put Client ID/Secret but we'll probably have to check both and pick
* whatever is not null. For now we'll just look at the request payload. As per the docs:
*
* <pre>
* This can be sent in the request body, but OAuth2 recommends sending it in
* a Basic Authorization header.
* </pre>
*
* @param grantType This should either be `client_credentials` for obtaining a token or
* `urn:ietf:params:oauth:grant-type:token-exchange` for exchanging a token. For now we only
* support `client_credentials`
* @param scope the requested scope of the Iceberg Client. TODO define later
* @param clientId (Nullable) the Client ID that should map to a `PINNACLE_PRINCIPAL` integration
* in Snowflake
* @param clientSecret (Nullable) the Client Secret that should map to a `PINNACLE_PRINCIPAL`
* integration in Snowflake
* @param requestedTokenType
* @param subjectToken
* @param subjectTokenType
* @param actorToken
* @param actorTokenType
* @param securityContext
* @return Either an `OAuthTokenResponse` or `OAuthErrorResponse` as defined in the REST spec
*/
@Override
public Response getToken(
String grantType,
String scope,
String clientId,
String clientSecret,
TokenType requestedTokenType,
String subjectToken,
TokenType subjectTokenType,
String actorToken,
TokenType actorTokenType,
SecurityContext securityContext) {
TokenRequestValidator validator = new TokenRequestValidator();
if (!validator.validateForClientCredentialsFlow(clientId, clientSecret, grantType)) {
// TODO this needs to be `OAuthErrorResponse` as defined in the spec but I can't
// seem to build with it on the plane?
return Response.status(Response.Status.BAD_REQUEST).build();
}

// This needs error handling as well
SnowflakePinnaclePrincipalTokenBroker broker =
new SnowflakePinnaclePrincipalTokenBroker(clientId, clientSecret);
SnowflakeTokenResponse tokenResponse =
broker.getToken((SnowflakeRealmContext) CallContext.getCurrentContext().getRealmContext());
return Response.ok(
OAuthTokenResponse.builder()
.withToken(tokenResponse.getAccessToken())
.withTokenType("bearer") // TODO there's gotta be a constant somewhere
.setExpirationInSeconds(tokenResponse.getExpiresIn())
.build())
.build();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package org.apache.iceberg.pinnacle.oauth;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Responsible for acting as a broker to obtain `PINNACLE_PRINCIPAL` level tokens. */
public class SnowflakePinnaclePrincipalTokenBroker extends SnowflakeTokenBroker {

private static final Logger LOGGER =
LoggerFactory.getLogger(SnowflakePinnaclePrincipalTokenBroker.class);

private static final String PINNACLE_PRINCIPAL_CLIENT_CREDENTIALS_GRANT_TYPE =
"pinnacle_principal_client_credentials";

private static final ObjectMapper mapper = new ObjectMapper();

private final String clientId;
private final String clientSecret;

/**
* @param clientId Client ID corresponding to a `PINNACLE_PRINCIPAL` integration
* @param clientSecret Client Secret corresponding to a `PINNACLE_PRINCIPAL` integration
*/
public SnowflakePinnaclePrincipalTokenBroker(final String clientId, final String clientSecret) {
this.clientId = clientId;
this.clientSecret = clientSecret;
}

@Override
String getAuthHeader() {
return OAuthUtils.getBasicAuthHeader(clientId, clientSecret);
}

/**
* In the future this will likely have the Client ID/Secret in it and we'll put a PINNACLE_SERVICE
* token in the Authorization Header. There is a JIRA for this somewhere.
*/
@Override
String getPayload() {
SnowflakeTokenRequestPayload payload = new SnowflakeTokenRequestPayload();
payload.setGrantType(PINNACLE_PRINCIPAL_CLIENT_CREDENTIALS_GRANT_TYPE);
try {
return mapper.writeValueAsString(payload);
} catch (JsonProcessingException e) {
LOGGER.error("Unable to serialize payload", e);
throw new RuntimeException(e);
}
}
}
Loading

0 comments on commit 87ba4ea

Please sign in to comment.