Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HUDI-8563] Populate catalogId in AWS Glue sync client #12314

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions hudi-aws/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@
<version>${aws.sdk.version}</version>
</dependency>

<!-- https://mvnrepository.com/artifact/software.amazon.awssdk/sts -->
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>sts</artifactId>
<version>${aws.sdk.version}</version>
</dependency>

<dependency>
<groupId>org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@
import software.amazon.awssdk.services.glue.model.Table;
import software.amazon.awssdk.services.glue.model.TableInput;
import software.amazon.awssdk.services.glue.model.UpdateTableRequest;
import org.apache.parquet.schema.MessageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityRequest;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;

import java.net.URI;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -147,12 +153,13 @@ public class AWSGlueCatalogSyncClient extends HoodieSyncClient {
private final int allPartitionsReadParallelism;
private final int changedPartitionsReadParallelism;
private final int changeParallelism;
private final String catalogId;

/**
 * Creates a Glue catalog sync client for the given sync configuration and table.
 *
 * <p>Builds the Glue async client from {@code config} via {@code buildAsyncClient} and delegates to the
 * package-private constructor, which also needs an {@link StsClient} to look up the caller's AWS account
 * as the fallback catalog id.
 *
 * @param config     sync configuration used to build the Glue client and read sync settings
 * @param metaClient meta client of the Hudi table being synced
 */
public AWSGlueCatalogSyncClient(HiveSyncConfig config, HoodieTableMetaClient metaClient) {
  // The delegate constructor takes (GlueAsyncClient, StsClient, HiveSyncConfig, HoodieTableMetaClient);
  // the former three-argument form no longer exists, so an STS client must be supplied here.
  // NOTE(review): StsClient.create() resolves region/credentials from the SDK default chains, which may
  // differ from the credentials configured for the Glue client; it is also never closed — confirm both.
  this(buildAsyncClient(config), StsClient.create(), config, metaClient);
}

AWSGlueCatalogSyncClient(GlueAsyncClient awsGlue, HiveSyncConfig config, HoodieTableMetaClient metaClient) {
AWSGlueCatalogSyncClient(GlueAsyncClient awsGlue, StsClient stsClient, HiveSyncConfig config, HoodieTableMetaClient metaClient) {
super(config, metaClient);
this.awsGlue = awsGlue;
this.databaseName = config.getStringOrDefault(META_SYNC_DATABASE_NAME);
Expand All @@ -161,6 +168,8 @@ public AWSGlueCatalogSyncClient(HiveSyncConfig config, HoodieTableMetaClient met
this.allPartitionsReadParallelism = config.getIntOrDefault(ALL_PARTITIONS_READ_PARALLELISM);
this.changedPartitionsReadParallelism = config.getIntOrDefault(CHANGED_PARTITIONS_READ_PARALLELISM);
this.changeParallelism = config.getIntOrDefault(PARTITION_CHANGE_PARALLELISM);
GetCallerIdentityResponse identityResponse = stsClient.getCallerIdentity(GetCallerIdentityRequest.builder().build());
this.catalogId = config.getStringOrDefault(GlueCatalogSyncClientConfig.GLUE_CATALOG_ID, identityResponse.account());
}

private static GlueAsyncClient buildAsyncClient(HiveSyncConfig config) {
Expand All @@ -183,6 +192,7 @@ private List<Partition> getPartitionsSegment(Segment segment, String tableName)
String nextToken = null;
do {
GetPartitionsResponse result = awsGlue.getPartitions(GetPartitionsRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.tableName(tableName)
.excludeColumnSchema(true)
Expand Down Expand Up @@ -498,6 +508,7 @@ && getTable(awsGlue, databaseName, tableName).partitionKeys().equals(partitionKe
.build();

UpdateTableRequest request = UpdateTableRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.tableInput(updatedTableInput)
.build();
Expand Down Expand Up @@ -531,6 +542,7 @@ public void updateTableSchema(String tableName, MessageType newSchema, SchemaDif
.build();

UpdateTableRequest request = UpdateTableRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.skipArchive(skipTableArchive)
.tableInput(updatedTableInput)
Expand Down Expand Up @@ -657,6 +669,7 @@ public void createTable(String tableName,
.build();

CreateTableRequest request = CreateTableRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.tableInput(tableInput)
.build();
Expand Down Expand Up @@ -808,6 +821,7 @@ public Map<String, String> getMetastoreSchema(String tableName) {
@Override
public boolean tableExists(String tableName) {
GetTableRequest request = GetTableRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.name(tableName)
.build();
Expand All @@ -827,7 +841,7 @@ public boolean tableExists(String tableName) {

@Override
public boolean databaseExists(String databaseName) {
GetDatabaseRequest request = GetDatabaseRequest.builder().name(databaseName).build();
GetDatabaseRequest request = GetDatabaseRequest.builder().catalogId(catalogId).name(databaseName).build();
try {
return Objects.nonNull(awsGlue.getDatabase(request).get().database());
} catch (ExecutionException e) {
Expand All @@ -848,6 +862,7 @@ public void createDatabase(String databaseName) {
return;
}
CreateDatabaseRequest request = CreateDatabaseRequest.builder()
.catalogId(catalogId)
.databaseInput(DatabaseInput.builder()
.name(databaseName)
.description("Automatically created by " + this.getClass().getName())
Expand Down Expand Up @@ -930,6 +945,7 @@ public void updateLastCommitTimeSynced(String tableName) {
@Override
public void dropTable(String tableName) {
DeleteTableRequest deleteTableRequest = DeleteTableRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.name(tableName)
.build();
Expand Down Expand Up @@ -1076,8 +1092,9 @@ private enum TableType {
MATERIALIZED_VIEW
}

private static Table getTable(GlueAsyncClient awsGlue, String databaseName, String tableName) throws HoodieGlueSyncException {
private Table getTable(GlueAsyncClient awsGlue, String databaseName, String tableName) throws HoodieGlueSyncException {
GetTableRequest request = GetTableRequest.builder()
.catalogId(catalogId)
.databaseName(databaseName)
.name(tableName)
.build();
Expand All @@ -1090,7 +1107,7 @@ private static Table getTable(GlueAsyncClient awsGlue, String databaseName, Stri
}
}

private static boolean updateTableParameters(GlueAsyncClient awsGlue, String databaseName, String tableName, Map<String, String> updatingParams, boolean skipTableArchive) {
private boolean updateTableParameters(GlueAsyncClient awsGlue, String databaseName, String tableName, Map<String, String> updatingParams, boolean skipTableArchive) {
if (isNullOrEmpty(updatingParams)) {
return false;
}
Expand All @@ -1117,6 +1134,7 @@ private static boolean updateTableParameters(GlueAsyncClient awsGlue, String dat
.build();

UpdateTableRequest request = UpdateTableRequest.builder().databaseName(databaseName)
.catalogId(catalogId)
.tableInput(updatedTableInput)
.skipArchive(skipTableArchive)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,11 @@ public class GlueCatalogSyncClientConfig extends HoodieConfig {
.markAdvanced()
.withDocumentation("Glue sync may fail if the Glue table exists with partitions differing from the Hoodie table or if schema evolution is not supported by Glue."
+ "Enabling this configuration will drop and create the table to match the Hoodie config");

/**
 * Id of the AWS Glue Data Catalog that sync requests should target.
 *
 * <p>Declared without a default value. {@code AWSGlueCatalogSyncClient} reads this key and, when it is not
 * set, falls back to the account id returned by the STS {@code GetCallerIdentity} call.
 */
public static final ConfigProperty<String> GLUE_CATALOG_ID = ConfigProperty
.key(GLUE_CLIENT_PROPERTY_PREFIX + "catalogId")
.noDefaultValue()
.sinceVersion("0.15.0")
.markAdvanced()
.withDocumentation("The catalogId needs to be populated for syncing hoodie tables in a different AWS account");
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,31 @@

import org.apache.hadoop.conf.Configuration;
import org.mockito.Mock;
import org.mockito.Mockito;
import software.amazon.awssdk.services.glue.GlueAsyncClient;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityRequest;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;

import java.util.Properties;

import static org.mockito.Mockito.when;

class MockAwsGlueCatalogSyncTool extends AwsGlueCatalogSyncTool {
private static final String CATALOG_ID = "DEFAULT_AWS_ACCOUNT_ID";

@Mock
private GlueAsyncClient mockAwsGlue;

private static StsClient mockSts = Mockito.mock(StsClient.class);

public MockAwsGlueCatalogSyncTool(Properties props, Configuration hadoopConf) {
super(props, hadoopConf, Option.empty());
}

@Override
protected void initSyncClient(HiveSyncConfig hiveSyncConfig, HoodieTableMetaClient metaClient) {
syncClient = new AWSGlueCatalogSyncClient(mockAwsGlue, hiveSyncConfig, metaClient);
when(mockSts.getCallerIdentity(GetCallerIdentityRequest.builder().build())).thenReturn(GetCallerIdentityResponse.builder().account(CATALOG_ID).build());
syncClient = new AWSGlueCatalogSyncClient(mockAwsGlue, mockSts, hiveSyncConfig, metaClient);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
package org.apache.hudi.aws.sync;

import org.apache.hudi.aws.testutils.GlueTestUtil;
import org.apache.hudi.common.config.TypedProperties;
import org.apache.hudi.config.GlueCatalogSyncClientConfig;
import org.apache.hudi.hive.HiveSyncConfig;
import org.apache.hudi.sync.common.model.FieldSchema;

import org.apache.parquet.schema.MessageType;
Expand All @@ -27,6 +30,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
Expand All @@ -43,12 +48,16 @@
import software.amazon.awssdk.services.glue.model.Table;
import software.amazon.awssdk.services.glue.model.UpdateTableRequest;
import software.amazon.awssdk.services.glue.model.UpdateTableResponse;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityRequest;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

Expand All @@ -60,19 +69,24 @@
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@ExtendWith(MockitoExtension.class)
class TestAWSGlueSyncClient {
private static final String CATALOG_ID = "DEFAULT_AWS_ACCOUNT_ID";

@Mock
private GlueAsyncClient mockAwsGlue;
@Mock
private StsClient mockSts;

private AWSGlueCatalogSyncClient awsGlueSyncClient;

@BeforeEach
void setUp() throws IOException {
GlueTestUtil.setUp();
awsGlueSyncClient = new AWSGlueCatalogSyncClient(mockAwsGlue, GlueTestUtil.getHiveSyncConfig(), GlueTestUtil.getMetaClient());
when(mockSts.getCallerIdentity(GetCallerIdentityRequest.builder().build())).thenReturn(GetCallerIdentityResponse.builder().account(CATALOG_ID).build());
awsGlueSyncClient = new AWSGlueCatalogSyncClient(mockAwsGlue, mockSts, GlueTestUtil.getHiveSyncConfig(), GlueTestUtil.getMetaClient());
}

@AfterEach
Expand Down Expand Up @@ -112,7 +126,7 @@ void testCreateOrReplaceTable_TableExists() throws ExecutionException, Interrupt
.table(table)
.build();

GetTableRequest getTableRequestForTable = GetTableRequest.builder().databaseName(databaseName).name(tableName).build();
GetTableRequest getTableRequestForTable = GetTableRequest.builder().catalogId(CATALOG_ID).databaseName(databaseName).name(tableName).build();
// Mock methods
CompletableFuture<GetTableResponse> tableResponseFuture = CompletableFuture.completedFuture(tableResponse);
CompletableFuture<GetTableResponse> mockTableNotFoundResponse = Mockito.mock(CompletableFuture.class);
Expand Down Expand Up @@ -227,6 +241,29 @@ void testGetTableLocation() {
// verify if table base path is correct
assertEquals(glueSyncProps.get(META_SYNC_BASE_PATH.key()), basePath, "table base path should match");
}

/**
 * Checks that {@code getTableLocation} issues its Glue {@code GetTableRequest} with the expected catalog id.
 *
 * <p>When {@code useConfiguredCatalogId} is true a random catalog id is written into the sync properties and
 * the stubbed Glue call only matches that id; otherwise the stub matches the account id returned by the
 * mocked STS client.
 *
 * @param useConfiguredCatalogId whether the catalog id comes from the sync properties or from STS
 */
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testGetTableLocationUsingCatalogId(boolean useConfiguredCatalogId) {
  // Decide which catalog id the Glue request is expected to carry.
  final String expectedCatalogId;
  if (useConfiguredCatalogId) {
    expectedCatalogId = UUID.randomUUID().toString();
  } else {
    expectedCatalogId = CATALOG_ID;
  }

  TypedProperties syncProps = GlueTestUtil.getHiveSyncConfig().getProps();
  if (useConfiguredCatalogId) {
    syncProps.setProperty(GlueCatalogSyncClientConfig.GLUE_CATALOG_ID.key(), expectedCatalogId);
  }

  // STS always reports the default account, regardless of what is configured.
  GetCallerIdentityResponse callerIdentity = GetCallerIdentityResponse.builder().account(CATALOG_ID).build();
  when(mockSts.getCallerIdentity(GetCallerIdentityRequest.builder().build())).thenReturn(callerIdentity);
  awsGlueSyncClient = new AWSGlueCatalogSyncClient(
      mockAwsGlue, mockSts, new HiveSyncConfig(syncProps), GlueTestUtil.getMetaClient());

  String database = "testdb"; // presumably the database GlueTestUtil configures — verify against GlueTestUtil
  String table = "testTable";
  Column nameColumn = Column.builder().name("name").type("string").comment("person's name").build();
  Column ageColumn = Column.builder().name("age").type("int").comment("person's age").build();
  List<Column> tableColumns = Arrays.asList(nameColumn, ageColumn);
  CompletableFuture<GetTableResponse> stubbedResponse =
      getTableWithDefaultProps(table, tableColumns, Collections.emptyList());

  // Stub the Glue getTable call only for a request that carries the expected catalog id.
  GetTableRequest expectedRequest = GetTableRequest.builder()
      .catalogId(expectedCatalogId)
      .databaseName(database)
      .name(table)
      .build();
  when(mockAwsGlue.getTable(expectedRequest)).thenReturn(stubbedResponse);

  String actualBasePath = awsGlueSyncClient.getTableLocation(table);
  assertEquals(glueSyncProps.get(META_SYNC_BASE_PATH.key()), actualBasePath, "table base path should match");
}

@Test
void testGetTableLocation_ThrowsException() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.services.glue.GlueAsyncClient;
import software.amazon.awssdk.services.glue.GlueAsyncClientBuilder;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityRequest;
import software.amazon.awssdk.services.sts.model.GetCallerIdentityResponse;

import java.io.IOException;

Expand All @@ -48,6 +51,7 @@

@ExtendWith(MockitoExtension.class)
class TestAwsGlueSyncTool {
private static final String CATALOG_ID = "DEFAULT_AWS_ACCOUNT_ID";

private AwsGlueCatalogSyncTool awsGlueCatalogSyncTool;

Expand Down Expand Up @@ -81,12 +85,16 @@ private void reinitGlueSyncTool() {

@Test
void validateInitThroughSyncTool() throws Exception {
try (MockedStatic<GlueAsyncClient> mockedStatic = mockStatic(GlueAsyncClient.class)) {
try (MockedStatic<GlueAsyncClient> mockedStatic = mockStatic(GlueAsyncClient.class);
MockedStatic<StsClient> mockedStsStatic = mockStatic(StsClient.class)) {
GlueAsyncClientBuilder builder = mock(GlueAsyncClientBuilder.class);
mockedStatic.when(GlueAsyncClient::builder).thenReturn(builder);
when(builder.credentialsProvider(any())).thenReturn(builder);
GlueAsyncClient mockClient = mock(GlueAsyncClient.class);
when(builder.build()).thenReturn(mockClient);
StsClient mockSts = mock(StsClient.class);
mockedStsStatic.when(StsClient::create).thenReturn(mockSts);
when(mockSts.getCallerIdentity(GetCallerIdentityRequest.builder().build())).thenReturn(GetCallerIdentityResponse.builder().account("").build());
HoodieSyncTool syncTool = SyncUtilHelpers.instantiateMetaSyncTool(
AwsGlueCatalogSyncTool.class.getName(),
new TypedProperties(),
Expand Down
Loading