Skip to content

Commit

Permalink
Adds a Snowflake-specific Call/Realm Context (apache#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tjones authored May 1, 2024
1 parent ac50220 commit 3cfe087
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
public interface CallContextResolver {
CallContext resolveCallContext(
RealmContext realmContext,
HTTPMethod method,
String method,
String path,
Map<String, String> queryParams,
Map<String, String> headers);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public class DefaultContextResolver implements RealmContextResolver, CallContext

@Override
public RealmContext resolveRealmContext(
HTTPMethod method,
String requestURL,
String method,
String path,
Map<String, String> queryParams,
Map<String, String> headers) {
Expand All @@ -57,7 +58,7 @@ public String getRealmIdentifier() {
@Override
public CallContext resolveCallContext(
final RealmContext realmContext,
HTTPMethod method,
String method,
String path,
Map<String, String> queryParams,
Map<String, String> headers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.iceberg.rest.api.impl.IcebergRestOAuth2ApiServiceImpl;
import org.apache.iceberg.rest.config.IcebergRestApplicationConfig;
import org.apache.iceberg.rest.responses.ConfigResponse;
import org.apache.iceberg.rest.snowflake.SnowflakeContextResolver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -38,8 +39,11 @@ public static void main(final String[] args) throws Exception {

@Override
public void run(IcebergRestApplicationConfig configuration, Environment environment) throws Exception {
environment.servlets().addFilter("realmContext", new ContextResolverFilter(
new DefaultContextResolver(), new DefaultContextResolver()))
environment
.servlets()
.addFilter(
"realmContext",
new ContextResolverFilter(new DefaultContextResolver(), new DefaultContextResolver()))
.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST), true, "/*");

RealmCatalogFactory catalogFactory;
Expand All @@ -63,7 +67,7 @@ public Response getConfig(String warehouse, SecurityContext securityContext) {
}
}));
environment.healthChecks().register("pinnacle", new PinnacleHealthCheck());
environment.jersey().register(new IcebergRestOAuth2Api(new IcebergRestOAuth2ApiServiceImpl())); // 501 default impl
environment.jersey().register(new IcebergRestOAuth2Api(new IcebergRestOAuth2ApiServiceImpl()));
environment.jersey().register(new IcebergExceptionMapper());
PinnacleServiceImpl pinnacleService = new PinnacleServiceImpl();
environment.jersey().register(new PinnacleCatalogApi(pinnacleService));
Expand Down Expand Up @@ -96,15 +100,16 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
Map<String, String> headers = headerNames.collect(Collectors.toMap(Function.identity(),
httpRequest::getHeader));
RealmContext currentRealmContext = realmContextResolver.resolveRealmContext(
RESTCatalogAdapter.HTTPMethod.valueOf(httpRequest.getMethod()),
httpRequest.getRequestURL().toString(),
httpRequest.getMethod(),
httpRequest.getRequestURI().substring(1),
request.getParameterMap().entrySet()
.stream()
.collect(Collectors.toMap(Map.Entry::getKey, (e) -> ((String[]) e.getValue())[0])),
headers);
CallContext currentCallContext = callContextResolver.resolveCallContext(
currentRealmContext,
RESTCatalogAdapter.HTTPMethod.valueOf(httpRequest.getMethod()),
httpRequest.getMethod(),
httpRequest.getRequestURI().substring(1),
request.getParameterMap().entrySet()
.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

public interface RealmContextResolver {
RealmContext resolveRealmContext(
HTTPMethod method,
String requestURL,
String method,
String path,
Map<String, String> queryParams,
Map<String, String> headers);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.apache.iceberg.rest.snowflake;

import org.apache.iceberg.rest.CallContext;
import org.apache.iceberg.rest.RealmContext;

public class SnowflakeCallContext implements CallContext {

private final RealmContext realmContext;

/**
* Default constructor
* @param realmContext
*/
SnowflakeCallContext(RealmContext realmContext) {
this.realmContext = realmContext;
}

@Override
public RealmContext getRealmContext() {
return realmContext;
}

/**
* This will return the identifier of the Pinnacle Principal
* @return
*/
@Override
public String getUser() {
return "";
}

@Override
public String getRole() {
return "";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package org.apache.iceberg.rest.snowflake;

import org.apache.iceberg.rest.*;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.logging.Logger;

public class SnowflakeContextResolver implements RealmContextResolver, CallContextResolver {

Logger LOGGER = Logger.getLogger(SnowflakeContextResolver.class.getName());

/**
* Extracts an account name from the request host. Taken from `ServletRequestUtil`
*
* @return account name if resolvable, null otherwise.
*/
private String getAccountNameFromURL(String hostName) {
if (hostName == null) {
return null;
}
String accountName = null;
// Extract the snowflake domain (the first part of the server name)
int SFDomainIdx = hostName.indexOf(".");
if (SFDomainIdx > 0) {
// The account name/alias (as well as possible a global identifier)
// exists within the snowflake domain, and can be derived from there
accountName = hostName.substring(0, SFDomainIdx).toUpperCase();

// There might be dashes in the server name from translated underscores,
// but since the external id cannot have dashes we know everything
// after the last dash in the snowflake domain will be an external id
int externalIdIdx = accountName.lastIndexOf("-");
if (externalIdIdx >= 0) {
// An external id exists within the snowflake domain
String[] hostNameArr = hostName.split("\\.");
if (hostNameArr.length > 1 && hostNameArr[1].equals("global")) {
// the server name of the format myname-abc123.global.sfc.com has
// been verified - abc123 is the external id, myname is the
// account name/alias, and the presence of global after the
// snowflake domain indicates this is a global identifier
accountName = accountName.substring(0, externalIdIdx);
}
}
}
return accountName;
}

@Override
public CallContext resolveCallContext(
RealmContext realmContext,
String method,
String path,
Map<String, String> queryParams,
Map<String, String> headers) {
return new SnowflakeCallContext(realmContext);
}

/** Resolves the "Realm" which in this case is a Snowflake account */
@Override
public RealmContext resolveRealmContext(
String requestUrl,
String method,
String path,
Map<String, String> queryParams,
Map<String, String> headers) {
// First resolve the Account - we expect our URLs to be of the form
// "https://pinnacle.account.snowflakecomputing.com"
// so get the host and strip "pinnacle" from it
String accountUrl;
try {
String host = new URI(requestUrl).getHost();
accountUrl = host.replace("pinnacle.", "");
} catch (URISyntaxException e) {
// TODO Add better / Pinnacle REST Service generic error handling
LOGGER.info("Error parsing request URL: " + requestUrl);
throw new RuntimeException("Unable to parse the provided account");
}
return new SnowflakeRealmContext(accountUrl, getAccountNameFromURL(accountUrl));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.apache.iceberg.rest.snowflake;

import org.apache.iceberg.rest.RealmContext;

/** The Snowflake "Realm" Context, i.e. the account that is making the request */
public class SnowflakeRealmContext implements RealmContext {

// Base Account URL - ex "myaccount.snowflakecomputing.com"
private final String accountUrl;

// The name of the account - ex "myaccount"
private final String accountName;

SnowflakeRealmContext(final String accountUrl, final String accountName) {
this.accountUrl = accountUrl;
this.accountName = accountName;
}

public String getAccountUrl() {
return accountUrl;
}

public String getAccountName() {
return accountName;
}

/**
* The Realm Identifier for Snowflake is simply the name of the account
* @return
*/
@Override
public String getRealmIdentifier() {
return accountName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.apache.iceberg.rest.snowflake;

import org.apache.iceberg.rest.CallContext;
import org.apache.iceberg.rest.RealmContext;
import org.assertj.core.api.InstanceOfAssertFactories;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.util.HashMap;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;

public class SnowflakeContextResolverTest {

@Test
void resolveCallContext() {
RealmContext realmContext =
new SnowflakeContextResolver()
.resolveRealmContext(
"https://pinnacle.testaccount.snowflakecomputing.com:8181",
"POST",
"api/catalog/v1/oauth/tokens",
new HashMap<>(),
new HashMap<>());
CallContext context =
new SnowflakeContextResolver()
.resolveCallContext(
realmContext,
"POST",
"api/catalog/v1/oauth/tokens",
new HashMap<>(),
new HashMap<>());
assertThat(context.getRealmContext())
.returns("TESTACCOUNT", RealmContext::getRealmIdentifier)
.isInstanceOf(SnowflakeRealmContext.class)
.asInstanceOf(InstanceOfAssertFactories.type(SnowflakeRealmContext.class))
.returns("TESTACCOUNT", SnowflakeRealmContext::getAccountName)
.returns("testaccount.snowflakecomputing.com", SnowflakeRealmContext::getAccountUrl);
}

@Test
void resolveRealmContextValidRequestURL() {
SnowflakeContextResolver resolver = new SnowflakeContextResolver();
RealmContext realmContext =
resolver.resolveRealmContext(
"https://pinnacle.testaccount.snowflakecomputing.com:8181",
"POST",
"api/catalog/v1/oauth/tokens",
new HashMap<>(),
new HashMap<>());
Assertions.assertEquals("TESTACCOUNT", realmContext.getRealmIdentifier());
Assertions.assertEquals("TESTACCOUNT", ((SnowflakeRealmContext) realmContext).getAccountName());
Assertions.assertEquals(
"testaccount.snowflakecomputing.com",
((SnowflakeRealmContext) realmContext).getAccountUrl());

}

@Test
void resolveRealmContextInvalidRequestURL() {
SnowflakeContextResolver resolver = new SnowflakeContextResolver();
try {
resolver.resolveRealmContext(
"deerdance", "POST", "api/catalog/v1/oauth/tokens", new HashMap<>(), new HashMap<>());
Assertions.fail("Did not expect this to generate a real context");
} catch (RuntimeException e) {
// pass
}
}
}

0 comments on commit 3cfe087

Please sign in to comment.