Skip to content

Commit

Permalink
chore: Use domain name from JDBC URL for Postgres and Mysql, Part of #…
Browse files Browse the repository at this point in the history
…2043.

When connecting to a Cloud SQL database using a domain name like `db.example.com`, 
you may now set the domain name in the JDBC URL. For example, you can use a URL like
"jdbc:mysql://db.example.com/my-schema?socketFactory=socketFactory=com.google.cloud.sql.mysql.SocketFactory"

The socket factory will detect "db.example.com" and look up the TXT record to resolve
the instance name. 

See #2043 for the whole feature definition.
  • Loading branch information
hessjcg committed Sep 27, 2024
1 parent 092af52 commit 3bd264a
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 19 deletions.
9 changes: 5 additions & 4 deletions core/src/main/java/com/google/cloud/sql/ConnectorConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ConnectorConfig {
private final String adminRootUrl;
private final String adminServicePath;
private final Supplier<GoogleCredentials> googleCredentialsSupplier;
private final Function<String,String> instanceNameResolver;
private final Function<String, String> instanceNameResolver;
private final GoogleCredentials googleCredentials;
private final String googleCredentialsPath;
private final String adminQuotaProject;
Expand All @@ -53,7 +53,7 @@ private ConnectorConfig(
String adminQuotaProject,
String universeDomain,
RefreshStrategy refreshStrategy,
Function<String,String> instanceNameResolver) {
Function<String, String> instanceNameResolver) {
this.targetPrincipal = targetPrincipal;
this.delegates = delegates;
this.adminRootUrl = adminRootUrl;
Expand Down Expand Up @@ -162,7 +162,7 @@ public static class Builder {
private String adminQuotaProject;
private String universeDomain;
private RefreshStrategy refreshStrategy = RefreshStrategy.BACKGROUND;
private Function<String,String> instanceNameResolver;
private Function<String, String> instanceNameResolver;

public Builder withTargetPrincipal(String targetPrincipal) {
this.targetPrincipal = targetPrincipal;
Expand Down Expand Up @@ -214,7 +214,8 @@ public Builder withRefreshStrategy(RefreshStrategy refreshStrategy) {
this.refreshStrategy = refreshStrategy;
return this;
}
public Builder withInstanceNameResolver(Function<String,String> instanceNameResolver) {

public Builder withInstanceNameResolver(Function<String, String> instanceNameResolver) {
this.instanceNameResolver = instanceNameResolver;
return this;
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ public Builder withIpTypes(List<IpType> ipTypes) {
this.ipTypes = ipTypes;
return this;
}
/** Set domainName as. */

/** Set domainName. */
public Builder withDomainName(String domainName) {
this.domainName = domainName;
return this;
Expand Down
38 changes: 35 additions & 3 deletions core/src/main/java/com/google/cloud/sql/core/Connector.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.cloud.sql.ConnectorConfig;
import com.google.cloud.sql.CredentialFactory;
import com.google.cloud.sql.RefreshStrategy;
import com.google.common.base.Strings;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningScheduledExecutorService;
import java.io.File;
Expand Down Expand Up @@ -142,9 +141,11 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
}
}

ConnectionInfoCache getConnection(ConnectionConfig config) {
ConnectionInfoCache getConnection(final ConnectionConfig config) {
final ConnectionConfig updatedConfig = resolveConnectionName(config);

ConnectionInfoCache instance =
instances.computeIfAbsent(config, k -> createConnectionInfo(config));
instances.computeIfAbsent(updatedConfig, k -> createConnectionInfo(updatedConfig));

// If the client certificate has expired (as when the computer goes to
// sleep, and the refresh cycle cannot run), force a refresh immediately.
Expand All @@ -156,6 +157,37 @@ ConnectionInfoCache getConnection(ConnectionConfig config) {
return instance;
}

private ConnectionConfig resolveConnectionName(ConnectionConfig config) {
// If domainName is not set, return the original configuration unmodified.
if (config.getDomainName() == null || config.getDomainName().isEmpty()) {
return config;
}

// If both domainName and cloudSqlInstance are set, ignore the domain name. Return a new
// configuration with domainName set to null.
if (config.getCloudSqlInstance() != null && !config.getCloudSqlInstance().isEmpty()) {
return config.withDomainName(null);
}

// If only domainName is set, resolve the domain name.
try {
final String unresolvedName = config.getDomainName();
final Function<String, String> resolver =
config.getConnectorConfig().getInstanceNameResolver();
if (resolver != null) {
return config.withCloudSqlInstance(resolver.apply(unresolvedName));
} else {
throw new IllegalStateException(
"Can't resolve domain " + unresolvedName + ". ConnectorConfig.resolver is not set.");
}
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(
String.format(
"Cloud SQL connection name is invalid: \"%s\"", config.getCloudSqlInstance()),
e);
}
}

private ConnectionInfoCache createConnectionInfo(ConnectionConfig config) {
logger.debug(
String.format("[%s] Connection info added to cache.", config.getCloudSqlInstance()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,8 @@ public void testHashCode() {
wantGoogleCredentialsPath,
wantAdminQuotaProject,
null, // universeDomain
wantRefreshStrategy // refreshStrategy
wantRefreshStrategy, // refreshStrategy
null // instanceNameResolver
));
}
}
44 changes: 44 additions & 0 deletions core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,50 @@ public void create_successfulPrivateConnection() throws IOException, Interrupted
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
}

@Test
public void create_successfulPublicConnectionWithDomainName()
throws IOException, InterruptedException {
FakeSslServer sslServer = new FakeSslServer();
ConnectionConfig config =
new ConnectionConfig.Builder()
.withDomainName("db.example.com")
.withIpTypes("PRIMARY")
.withConnectorConfig(
new ConnectorConfig.Builder()
.withInstanceNameResolver((domainName) -> "myProject:myRegion:myInstance")
.build())
.build();

int port = sslServer.start(PUBLIC_IP);

Connector connector = newConnector(config.getConnectorConfig(), port);

Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS);

assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
}

@Test
public void create_throwsErrorForDomainNameWithNoResolver()
throws IOException, InterruptedException {
// The server TLS certificate matches myProject:myRegion:myInstance
FakeSslServer sslServer = new FakeSslServer();
ConnectionConfig config =
new ConnectionConfig.Builder()
.withDomainName("db.example.com")
.withIpTypes("PRIMARY")
.build();

int port = sslServer.start(PUBLIC_IP);

Connector connector = newConnector(config.getConnectorConfig(), port);
IllegalStateException ex =
assertThrows(
IllegalStateException.class, () -> connector.connect(config, TEST_MAX_REFRESH_MS));

assertThat(ex).hasMessageThat().contains("ConnectorConfig.resolver is not set");
}

@Test
public void create_successfulPublicConnection() throws IOException, InterruptedException {
FakeSslServer sslServer = new FakeSslServer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,22 @@ public class SocketFactory extends ConfigurableSocketFactory {
}

private Configuration conf;
private String host;

public SocketFactory() {}

@Override
public void setConfiguration(Configuration conf, String host) {
// Ignore the hostname
this.conf = conf;
this.host = host;
}

@Override
public Socket createSocket() throws IOException {
try {
return InternalConnectorRegistry.getInstance()
.connect(ConnectionConfig.fromConnectionProperties(conf.nonMappedOptions()));
.connect(ConnectionConfig.fromConnectionProperties(conf.nonMappedOptions(), host));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public <T extends Closeable> T connect(
T socket =
(T)
InternalConnectorRegistry.getInstance()
.connect(ConnectionConfig.fromConnectionProperties(props));
.connect(ConnectionConfig.fromConnectionProperties(props, host));
return socket;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.cloud.sql.mysql;

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;

import com.google.cloud.sql.ConnectorConfig;
import com.google.cloud.sql.ConnectorRegistry;
import com.google.common.collect.ImmutableList;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class JdbcMysqlJ8DomainNameIntegrationTests {

private static final String CONNECTION_NAME = System.getenv("MYSQL_CONNECTION_NAME");
private static final String DB_NAME = System.getenv("MYSQL_DB");
private static final String DB_USER = System.getenv("MYSQL_USER");
private static final String DB_PASSWORD = System.getenv("MYSQL_PASS");
private static final ImmutableList<String> requiredEnvVars =
ImmutableList.of("MYSQL_USER", "MYSQL_PASS", "MYSQL_DB", "MYSQL_CONNECTION_NAME");
@Rule public Timeout globalTimeout = new Timeout(80, TimeUnit.SECONDS);
private HikariDataSource connectionPool;

@BeforeClass
public static void checkEnvVars() {
// Check that required env vars are set
requiredEnvVars.forEach(
(varName) ->
assertWithMessage(
String.format(
"Environment variable '%s' must be set to perform these tests.", varName))
.that(System.getenv(varName))
.isNotEmpty());
}

@Before
public void setUpPool() throws SQLException {
// Set up URL parameters
String jdbcURL = String.format("jdbc:mysql://db.example.com/%s", DB_NAME);
Properties connProps = new Properties();
connProps.setProperty("user", DB_USER);
connProps.setProperty("password", DB_PASSWORD);
connProps.setProperty("socketFactory", "com.google.cloud.sql.mysql.SocketFactory");

// Register a resolver that resolves `db.example.com` to the connection name
connProps.setProperty("cloudSqlNamedConnector", "resolver-test");
ConnectorRegistry.register(
"resolver-test",
new ConnectorConfig.Builder()
.withInstanceNameResolver((n) -> "db.example.com".equals(n) ? CONNECTION_NAME : null)
.build());

// Initialize connection pool
HikariConfig config = new HikariConfig();
config.setJdbcUrl(jdbcURL);
config.setDataSourceProperties(connProps);
config.setConnectionTimeout(10000); // 10s

this.connectionPool = new HikariDataSource(config);
}

@Test
public void pooledConnectionTest() throws SQLException {

List<Timestamp> rows = new ArrayList<>();
try (Connection conn = connectionPool.getConnection()) {
try (PreparedStatement selectStmt = conn.prepareStatement("SELECT NOW() as TS")) {
ResultSet rs = selectStmt.executeQuery();
while (rs.next()) {
rows.add(rs.getTimestamp("TS"));
}
}
}
assertThat(rows.size()).isEqualTo(1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public static void checkEnvVars() {
@Before
public void setUpPool() throws SQLException {
// Set up URL parameters
String jdbcURL = String.format("jdbc:mysql:///%s", DB_NAME);
String jdbcURL = String.format("jdbc:mysql://db.example.com/%s", DB_NAME);
Properties connProps = new Properties();
connProps.setProperty("user", DB_USER);
connProps.setProperty("password", DB_PASSWORD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public class SocketFactory extends javax.net.SocketFactory {
private static final String DEPRECATED_SOCKET_ARG = "SocketFactoryArg";
private static final String POSTGRES_SUFFIX = "/.s.PGSQL.5432";

/** The connection property containing the hostname from the JDBC url. */
private static final String POSTGRES_HOST_PROP = "PGHOST";

private final Properties props;

static {
Expand Down Expand Up @@ -78,7 +81,9 @@ private static Properties createDefaultProperties(String instanceName) {
public Socket createSocket() throws IOException {
try {
return InternalConnectorRegistry.getInstance()
.connect(ConnectionConfig.fromConnectionProperties(props));
.connect(
ConnectionConfig.fromConnectionProperties(
props, props.getProperty(POSTGRES_HOST_PROP)));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
Expand Down
Loading

0 comments on commit 3bd264a

Please sign in to comment.