Skip to content

Commit

Permalink
Add validation to check that next URI host and port does not change d…
Browse files Browse the repository at this point in the history
…uring query execution
  • Loading branch information
dianatatar authored and tdcmeehan committed Dec 23, 2023
1 parent a735bcb commit 0f89633
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ public ClientSession getClientSession()
clientRequestTimeout,
disableCompression,
ImmutableMap.of(),
ImmutableMap.of());
ImmutableMap.of(),
false);
}

private static URI parseServer(String server)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ public class ClientOptions
@Option(name = "--disable-compression", title = "disable response compression", description = "Disable compression of query results")
public boolean disableCompression;

@Option(name = "--validate-nexturi-source", title = "validate nextUri source", description = "Validate nextUri server host and port does not change during query execution")
public boolean validateNextUriSource;

public enum OutputFormat
{
ALIGNED,
Expand Down Expand Up @@ -176,7 +179,8 @@ public ClientSession toClientSession()
clientRequestTimeout,
disableCompression,
emptyMap(),
emptyMap());
emptyMap(),
validateNextUriSource);
}

public static URI parseServer(String server)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ protected ClientSession createMockClientSession()
new Duration(2, MINUTES),
true,
ImmutableMap.of(),
ImmutableMap.of());
ImmutableMap.of(),
false);
}

protected QueryResults createMockQueryResults()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,12 @@ public void testDuplicateExtraCredentialKey()
ClientOptions options = console.clientOptions;
options.toClientSession();
}

@Test
public void testValidateNextUriSource()
{
Console console = singleCommand(Console.class).parse("--validate-nexturi-source");
assertTrue(console.clientOptions.validateNextUriSource);
assertTrue(console.clientOptions.toClientSession().validateNextUriSource());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class ClientSession
private final Duration clientRequestTimeout;
private final boolean compressionDisabled;
private final Map<String, String> sessionFunctions;
private final boolean validateNextUriSource;

public static Builder builder(ClientSession clientSession)
{
Expand Down Expand Up @@ -87,7 +88,8 @@ public ClientSession(
Duration clientRequestTimeout,
boolean compressionDisabled,
Map<String, String> sessionFunctions,
Map<String, String> customHeaders)
Map<String, String> customHeaders,
boolean validateNextUriSource)
{
this.server = requireNonNull(server, "server is null");
this.user = user;
Expand All @@ -109,6 +111,7 @@ public ClientSession(
this.clientRequestTimeout = clientRequestTimeout;
this.compressionDisabled = compressionDisabled;
this.sessionFunctions = ImmutableMap.copyOf(requireNonNull(sessionFunctions, "sessionFunctions is null"));
this.validateNextUriSource = validateNextUriSource;

for (String clientTag : clientTags) {
checkArgument(!clientTag.contains(","), "client tag cannot contain ','");
Expand Down Expand Up @@ -255,6 +258,11 @@ public Map<String, String> getSessionFunctions()
return sessionFunctions;
}

public boolean validateNextUriSource()
{
return validateNextUriSource;
}

@Override
public String toString()
{
Expand Down Expand Up @@ -296,6 +304,7 @@ public static final class Builder
private Duration clientRequestTimeout;
private boolean compressionDisabled;
private Map<String, String> sessionFunctions;
private boolean validateNextUriSource;

private Builder(ClientSession clientSession)
{
Expand All @@ -320,6 +329,7 @@ private Builder(ClientSession clientSession)
clientRequestTimeout = clientSession.getClientRequestTimeout();
compressionDisabled = clientSession.isCompressionDisabled();
sessionFunctions = clientSession.getSessionFunctions();
validateNextUriSource = clientSession.validateNextUriSource();
}

public Builder withCatalog(String catalog)
Expand Down Expand Up @@ -388,6 +398,12 @@ public Builder withSessionFunctions(Map<String, String> sessionFunctions)
return this;
}

public Builder withValidateNextUriSource(final boolean validateNextUriSource)
{
this.validateNextUriSource = validateNextUriSource;
return this;
}

public ClientSession build()
{
return new ClientSession(
Expand All @@ -410,7 +426,8 @@ public ClientSession build()
clientRequestTimeout,
compressionDisabled,
sessionFunctions,
customHeaders);
customHeaders,
validateNextUriSource);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class StatementClientV1
private final boolean compressionDisabled;
private final Map<String, String> addedSessionFunctions = new ConcurrentHashMap<>();
private final Set<String> removedSessionFunctions = newConcurrentHashSet();
private final boolean validateNextUriSource;

private final AtomicReference<State> state = new AtomicReference<>(State.RUNNING);

Expand All @@ -129,6 +130,7 @@ public StatementClientV1(OkHttpClient httpClient, ClientSession session, String
this.requestTimeoutNanos = session.getClientRequestTimeout();
this.user = session.getUser();
this.compressionDisabled = session.isCompressionDisabled();
this.validateNextUriSource = session.validateNextUriSource();

Request request = buildQueryRequest(session, query);

Expand Down Expand Up @@ -365,6 +367,7 @@ public boolean advance()
state.compareAndSet(State.RUNNING, State.FINISHED);
return false;
}
validateNextUriSource(nextUri, currentStatusInfo().getInfoUri());

Request request = prepareRequest(HttpUrl.get(nextUri)).build();

Expand Down Expand Up @@ -422,6 +425,20 @@ public boolean advance()
}
}

private void validateNextUriSource(final URI nextUri, final URI infoUri)
{
if (!validateNextUriSource) {
return;
}

if (nextUri.getHost().equals(infoUri.getHost())
&& nextUri.getPort() == infoUri.getPort()) {
return;
}
state.compareAndSet(State.RUNNING, State.CLIENT_ERROR);
throw new RuntimeException(format("Next URI host and port %s are different than current %s", nextUri.getHost(), infoUri.getHost()));
}

private void processResponse(Headers headers, QueryResults results)
{
setCatalog.set(headers.get(PRESTO_SET_CATALOG));
Expand Down
1 change: 1 addition & 0 deletions presto-docs/src/main/sphinx/installation/jdbc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,5 @@ Name Description
customHeaders is a list of key-value pairs. Example:
``testHeaderKey:testHeaderValue`` will inject the header ``testHeaderKey``
with value ``testHeaderValue``. Values should be percent encoded.
``validateNextUriSource`` Validates that host and port in next URI does not change during query execution.
================================= =======================================================================
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ final class ConnectionProperties
public static final ConnectionProperty<Map<String, String>> SESSION_PROPERTIES = new SessionProperties();
public static final ConnectionProperty<List<Protocol>> HTTP_PROTOCOLS = new HttpProtocols();
public static final ConnectionProperty<List<QueryInterceptor>> QUERY_INTERCEPTORS = new QueryInterceptors();
public static final ConnectionProperty<Boolean> VALIDATE_NEXTURI_SOURCE = new ValidateNextUriSource();

private static final Set<ConnectionProperty<?>> ALL_PROPERTIES = ImmutableSet.<ConnectionProperty<?>>builder()
.add(USER)
Expand Down Expand Up @@ -89,6 +90,7 @@ final class ConnectionProperties
.add(SESSION_PROPERTIES)
.add(HTTP_PROTOCOLS)
.add(QUERY_INTERCEPTORS)
.add(VALIDATE_NEXTURI_SOURCE)
.build();

private static final Map<String, ConnectionProperty<?>> KEY_LOOKUP = unmodifiableMap(ALL_PROPERTIES.stream()
Expand Down Expand Up @@ -367,4 +369,13 @@ public QueryInterceptors()
super("queryInterceptors", NOT_REQUIRED, ALLOWED, CLASS_LIST_CONVERTER);
}
}

private static class ValidateNextUriSource
extends AbstractConnectionProperty<Boolean>
{
public ValidateNextUriSource()
{
super("validateNextUriSource", Optional.of("false"), NOT_REQUIRED, ALLOWED, BOOLEAN_CONVERTER);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ public class PrestoConnection
private final QueryExecutor queryExecutor;
private final WarningsManager warningsManager = new WarningsManager();
private final List<QueryInterceptor> queryInterceptorInstances;
private final boolean validateNextUriSource;

PrestoConnection(PrestoDriverUri uri, QueryExecutor queryExecutor)
throws SQLException
Expand All @@ -116,6 +117,7 @@ public class PrestoConnection
this.sessionProperties = new ConcurrentHashMap<>(uri.getSessionProperties());
this.connectionProperties = uri.getProperties();
this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null");
this.validateNextUriSource = uri.validateNextUriSource();
uri.getClientTags().ifPresent(tags -> clientInfo.put("ClientTags", tags));

timeZoneId.set(uri.getTimeZoneId());
Expand Down Expand Up @@ -792,7 +794,8 @@ else if (applicationName != null) {
timeout,
compressionDisabled,
ImmutableMap.of(),
customHeaders);
customHeaders,
validateNextUriSource);

return queryExecutor.startQuery(session, sql);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PATH;
import static com.facebook.presto.jdbc.ConnectionProperties.TIMEZONE_ID;
import static com.facebook.presto.jdbc.ConnectionProperties.USER;
import static com.facebook.presto.jdbc.ConnectionProperties.VALIDATE_NEXTURI_SOURCE;
import static com.google.common.base.Strings.isNullOrEmpty;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -210,6 +211,12 @@ public Optional<List<Protocol>> getProtocols()
return HTTP_PROTOCOLS.getValue(properties);
}

public boolean validateNextUriSource()
throws SQLException
{
return VALIDATE_NEXTURI_SOURCE.getValue(properties).orElse(false);
}

public void setupClient(OkHttpClient.Builder builder)
throws SQLException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@
import static com.facebook.presto.jdbc.ConnectionProperties.SOCKS_PROXY;
import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PASSWORD;
import static com.facebook.presto.jdbc.ConnectionProperties.SSL_TRUST_STORE_PATH;
import static com.facebook.presto.jdbc.ConnectionProperties.VALIDATE_NEXTURI_SOURCE;
import static java.lang.String.format;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
Expand Down Expand Up @@ -319,6 +321,21 @@ public void testUriWithQueryInterceptors()
assertEquals(properties.getProperty(QUERY_INTERCEPTORS.getKey()), queryInterceptor);
}

@Test
public void testValidateNextUriSource()
throws SQLException
{
PrestoDriverUri defaultParams = createDriverUri("presto://localhost:8080/blackhole");
assertFalse(defaultParams.validateNextUriSource());
assertEquals(defaultParams.getProperties().getProperty(VALIDATE_NEXTURI_SOURCE.getKey()), "false");

PrestoDriverUri parameters = createDriverUri("presto://localhost:8080/blackhole?validateNextUriSource=true");
assertTrue(parameters.validateNextUriSource());
assertEquals(parameters.getProperties().getProperty(VALIDATE_NEXTURI_SOURCE.getKey()), "true");

assertInvalid("presto://localhost:8080/blackhole?validateNextUriSource=ANOTHERVALUE", "Connection property 'validateNextUriSource' value is invalid: ANOTHERVALUE");
}

public static class TestForUriQueryInterceptor
implements QueryInterceptor
{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ private static ClientSession toClientSession(Session session, URI server, Durati
clientRequestTimeout,
true,
serializedSessionFunctions,
ImmutableMap.of());
ImmutableMap.of(),
false);
}

public List<QualifiedObjectName> listTables(Session session, String catalog, String schema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ private static QueryId startQuery(String sql, DistributedQueryRunner queryRunner
new Duration(2, MINUTES),
true,
ImmutableMap.of(),
ImmutableMap.of());
ImmutableMap.of(),
false);

// start query
StatementClient client = newStatementClient(httpClient, clientSession, sql);
Expand Down

0 comments on commit 0f89633

Please sign in to comment.