diff --git a/storage/adls_gen1_mount.go b/storage/adls_gen1_mount.go index 8bb988e66c..f2e49d1f4b 100644 --- a/storage/adls_gen1_mount.go +++ b/storage/adls_gen1_mount.go @@ -20,7 +20,7 @@ type AzureADLSGen1Mount struct { } // Source ... -func (m AzureADLSGen1Mount) Source() string { +func (m AzureADLSGen1Mount) Source(_ *common.DatabricksClient) string { return fmt.Sprintf("adl://%s.azuredatalakestore.net%s", m.StorageResource, m.Directory) } diff --git a/storage/adls_gen2_mount.go b/storage/adls_gen2_mount.go index 0efd350ee9..dd912885d2 100644 --- a/storage/adls_gen2_mount.go +++ b/storage/adls_gen2_mount.go @@ -2,9 +2,11 @@ package storage import ( "fmt" + "strings" "github.com/databricks/terraform-provider-databricks/common" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) // AzureADLSGen2Mount describes the object for a azure datalake gen 2 storage mount @@ -19,10 +21,23 @@ type AzureADLSGen2Mount struct { InitializeFileSystem bool `json:"initialize_file_system"` } +func getAzureDomain(client *common.DatabricksClient) string { + domains := map[string]string{ + "PUBLIC": "core.windows.net", + "USGOVERNMENT": "core.usgovcloudapi.net", + "CHINA": "core.chinacloudapi.cn", + } + azureEnvironment := client.Config.Environment().AzureEnvironment.Name + domain, ok := domains[strings.ToUpper(azureEnvironment)] + if !ok { + panic(fmt.Sprintf("Unknown Azure environment: '%s'", azureEnvironment)) + } + return domain +} + // Source returns ABFSS URI backing the mount -func (m AzureADLSGen2Mount) Source() string { - return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net%s", - m.ContainerName, m.StorageAccountName, m.Directory) +func (m AzureADLSGen2Mount) Source(client *common.DatabricksClient) string { + return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) } func (m AzureADLSGen2Mount) Name() string { @@ -106,5 +121,12 @@ func ResourceAzureAdlsGen2Mount() common.Resource { Required: true, ForceNew: true, }, + "environment": { + Type: schema.TypeString, + Optional: true, + ForceNew: true, + ValidateFunc: validation.StringInSlice([]string{"PUBLIC", "USGOVERNMENT", "CHINA"}, false), + Default: "PUBLIC", + }, })) } diff --git a/storage/adls_gen2_mount_test.go b/storage/adls_gen2_mount_test.go index 9c4b65c19a..5c1cc35102 100644 --- a/storage/adls_gen2_mount_test.go +++ b/storage/adls_gen2_mount_test.go @@ -10,11 +10,10 @@ import ( "github.com/databricks/terraform-provider-databricks/qa" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestResourceAdlsGen2Mount_Create(t *testing.T) { - d, err := qa.ResourceFixture{ + qa.ResourceFixture{ Fixtures: []qa.HTTPFixture{ { Method: "GET", @@ -51,8 +50,9 @@ func TestResourceAdlsGen2Mount_Create(t *testing.T) { "initialize_file_system": true, }, Create: true, - }.Apply(t) - require.NoError(t, err) - assert.Equal(t, "this_mount", d.Id()) - assert.Equal(t, "abfss://e@test-adls-gen2.dfs.core.windows.net", d.Get("source")) + Azure: true, + }.ApplyAndExpectData(t, map[string]any{ + "id": "this_mount", + "source": "abfss://e@test-adls-gen2.dfs.core.windows.net", + }) } diff --git a/storage/aws_s3_mount.go b/storage/aws_s3_mount.go index d4a3c5b70a..2c02f3bde1 100644 --- a/storage/aws_s3_mount.go +++ b/storage/aws_s3_mount.go @@ -16,7 +16,7 @@ type AWSIamMount struct { } // Source ... -func (m AWSIamMount) Source() string { +func (m AWSIamMount) Source(_ *common.DatabricksClient) string { return fmt.Sprintf("s3a://%s", m.S3BucketName) } diff --git a/storage/azure_blob_mount.go b/storage/azure_blob_mount.go index c15f3c11f5..1b4832be2e 100644 --- a/storage/azure_blob_mount.go +++ b/storage/azure_blob_mount.go @@ -19,9 +19,9 @@ type AzureBlobMount struct { } // Source ... -func (m AzureBlobMount) Source() string { - return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.core.windows.net%[3]s", - m.ContainerName, m.StorageAccountName, m.Directory) +func (m AzureBlobMount) Source(client *common.DatabricksClient) string { + return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s", + m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) } func (m AzureBlobMount) Name() string { diff --git a/storage/azure_blob_mount_test.go b/storage/azure_blob_mount_test.go index bbd6f7d8e5..f463d9956b 100644 --- a/storage/azure_blob_mount_test.go +++ b/storage/azure_blob_mount_test.go @@ -14,7 +14,7 @@ import ( ) func TestResourceAzureBlobMountCreate(t *testing.T) { - d, err := qa.ResourceFixture{ + qa.ResourceFixture{ Fixtures: []qa.HTTPFixture{ { Method: "GET", @@ -50,11 +50,12 @@ func TestResourceAzureBlobMountCreate(t *testing.T) { "token_secret_key": "g", "token_secret_scope": "h", }, + Azure: true, Create: true, - }.Apply(t) - require.NoError(t, err) - assert.Equal(t, "e", d.Id()) - assert.Equal(t, "wasbs://c@f.blob.core.windows.net/d", d.Get("source")) + }.ApplyAndExpectData(t, map[string]any{ + "id": "e", + "source": "wasbs://c@f.blob.core.windows.net/d", + }) } func TestResourceAzureBlobMountCreate_Error(t *testing.T) { @@ -86,6 +87,7 @@ func TestResourceAzureBlobMountCreate_Error(t *testing.T) { "token_secret_scope": "h", }, Create: true, + Azure: true, }.Apply(t) require.EqualError(t, err, "Some error") assert.Equal(t, "e", d.Id()) @@ -124,8 +126,9 @@ func TestResourceAzureBlobMountRead(t *testing.T) { "token_secret_key": "g", "token_secret_scope": "h", }, - ID: "e", - Read: true, + ID: "e", + Read: true, + Azure: true, }.Apply(t) require.NoError(t, err) assert.Equal(t, "e", d.Id()) @@ -165,6 +168,7 @@ func TestResourceAzureBlobMountRead_NotFound(t *testing.T) { ID: "e", Read: true, Removed: true, + Azure: true, }.ApplyNoError(t) } @@ -198,8 +202,9 @@ func TestResourceAzureBlobMountRead_Error(t *testing.T) { "token_secret_key": "g", "token_secret_scope": "h", }, - ID: "e", - Read: true, + ID: "e", + Azure: true, + Read: true, }.Apply(t) require.EqualError(t, err, "Some error") assert.Equal(t, "e", d.Id()) @@ -239,6 +244,7 @@ func TestResourceAzureBlobMountDelete(t *testing.T) { }, ID: "e", Delete: true, + Azure: true, }.Apply(t) require.NoError(t, err) assert.Equal(t, "e", d.Id()) diff --git a/storage/generic_mounts.go b/storage/generic_mounts.go index f13fdf21c0..72e8d4dedb 100644 --- a/storage/generic_mounts.go +++ b/storage/generic_mounts.go @@ -42,9 +42,9 @@ func (m GenericMount) getBlock() Mount { } // Source returns URI backing the mount -func (m GenericMount) Source() string { +func (m GenericMount) Source(client *common.DatabricksClient) string { if block := m.getBlock(); block != nil { - return block.Source() + return block.Source(client) } return m.URI } @@ -96,7 +96,7 @@ func parseStorageContainerId(rid string) (string, string, error) { return match[3], match[4], nil } -func getContainerDefaults(d *schema.ResourceData, allowed_schemas []string, suffix string) (string, string, error) { +func getContainerDefaults(d *schema.ResourceData) (string, string, error) { rid := d.Get("resource_id").(string) if rid != "" { acc, cont, err := parseStorageContainerId(rid) @@ -134,9 +134,8 @@ type AzureADLSGen2MountGeneric struct { } // Source returns ABFSS URI backing the mount -func (m *AzureADLSGen2MountGeneric) Source() string { - return fmt.Sprintf("abfss://%s@%s.dfs.core.windows.net%s", - m.ContainerName, m.StorageAccountName, m.Directory) +func (m *AzureADLSGen2MountGeneric) Source(client *common.DatabricksClient) string { + return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) } func (m *AzureADLSGen2MountGeneric) Name() string { @@ -145,7 +144,7 @@ func (m *AzureADLSGen2MountGeneric) Name() string { func (m *AzureADLSGen2MountGeneric) ValidateAndApplyDefaults(d *schema.ResourceData, client *common.DatabricksClient) error { if m.ContainerName == "" || m.StorageAccountName == "" { - acc, cont, err := getContainerDefaults(d, []string{"abfs", "abfss"}, "dfs.core.windows.net") + acc, cont, err := getContainerDefaults(d) if err != nil { return err } @@ -194,7 +193,7 @@ type AzureADLSGen1MountGeneric struct { } // Source ... -func (m *AzureADLSGen1MountGeneric) Source() string { +func (m *AzureADLSGen1MountGeneric) Source(_ *common.DatabricksClient) string { return fmt.Sprintf("adl://%s.azuredatalakestore.net%s", m.StorageResource, m.Directory) } @@ -237,10 +236,9 @@ func (m *AzureADLSGen1MountGeneric) Config(client *common.DatabricksClient) map[ aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint() return map[string]string{ m.PrefixType + ".oauth2.access.token.provider.type": "ClientCredential", - - m.PrefixType + ".oauth2.client.id": m.ClientID, - m.PrefixType + ".oauth2.credential": fmt.Sprintf("{{secrets/%s/%s}}", m.SecretScope, m.SecretKey), - m.PrefixType + ".oauth2.refresh.url": fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, m.TenantID), + m.PrefixType + ".oauth2.client.id": m.ClientID, + m.PrefixType + ".oauth2.credential": fmt.Sprintf("{{secrets/%s/%s}}", m.SecretScope, m.SecretKey), + m.PrefixType + ".oauth2.refresh.url": fmt.Sprintf("%s%s/oauth2/token", aadEndpoint, m.TenantID), } } @@ -257,9 +255,9 @@ type AzureBlobMountGeneric struct { } // Source ... -func (m *AzureBlobMountGeneric) Source() string { - return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.core.windows.net%[3]s", - m.ContainerName, m.StorageAccountName, m.Directory) +func (m *AzureBlobMountGeneric) Source(client *common.DatabricksClient) string { + return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s", + m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) } func (m *AzureBlobMountGeneric) Name() string { @@ -268,7 +266,7 @@ func (m *AzureBlobMountGeneric) Name() string { func (m *AzureBlobMountGeneric) ValidateAndApplyDefaults(d *schema.ResourceData, client *common.DatabricksClient) error { if m.ContainerName == "" || m.StorageAccountName == "" { - acc, cont, err := getContainerDefaults(d, []string{"wasb", "wasbs"}, "blob.core.windows.net") + acc, cont, err := getContainerDefaults(d) if err != nil { return err } diff --git a/storage/gs.go b/storage/gs.go index 2c47b6b2b3..547809ee69 100644 --- a/storage/gs.go +++ b/storage/gs.go @@ -19,7 +19,7 @@ type GSMount struct { } // Source ... -func (m GSMount) Source() string { +func (m GSMount) Source(_ *common.DatabricksClient) string { return fmt.Sprintf("gs://%s", m.BucketName) } diff --git a/storage/mounts.go b/storage/mounts.go index 1dcf66dd60..c5e48a3568 100644 --- a/storage/mounts.go +++ b/storage/mounts.go @@ -20,7 +20,7 @@ import ( // Mount exposes generic url & extra config map options type Mount interface { - Source() string + Source(client *common.DatabricksClient) string Config(client *common.DatabricksClient) map[string]string Name() string @@ -96,7 +96,7 @@ func (mp MountPoint) Mount(mo Mount, client *common.DatabricksClient) (source st raise e mount_source = safe_mount("/mnt/%s", "%v", %s, "%s") dbutils.notebook.exit(mount_source) - `, mp.Name, mo.Source(), extraConfigs, mp.EncryptionType) // lgtm[go/unsafe-quoting] + `, mp.Name, mo.Source(client), extraConfigs, mp.EncryptionType) // lgtm[go/unsafe-quoting] result := mp.Exec.Execute(mp.ClusterID, "python", command) return result.Text(), result.Err() } @@ -235,7 +235,7 @@ func mountCreate(tpl any, r common.Resource) func(context.Context, *schema.Resou if err != nil { return err } - log.Printf("[INFO] Mounting %s at /mnt/%s", mountConfig.Source(), d.Id()) + log.Printf("[INFO] Mounting %s at /mnt/%s", mountConfig.Source(client), d.Id()) source, err := mountPoint.Mount(mountConfig, client) if err != nil { return err diff --git a/storage/mounts_test.go b/storage/mounts_test.go index 3060c33a1b..8005950c3a 100644 --- a/storage/mounts_test.go +++ b/storage/mounts_test.go @@ -74,8 +74,8 @@ func testMountFuncHelper(t *testing.T, mountFunc func(mp MountPoint, mount Mount type mockMount struct{} -func (t mockMount) Source() string { return "fake-mount" } -func (t mockMount) Name() string { return "fake-mount" } +func (t mockMount) Source(_ *common.DatabricksClient) string { return "fake-mount" } +func (t mockMount) Name() string { return "fake-mount" } func (t mockMount) Config(client *common.DatabricksClient) map[string]string { return map[string]string{"fake-key": "fake-value"} } @@ -84,6 +84,14 @@ func (m mockMount) ValidateAndApplyDefaults(d *schema.ResourceData, client *comm } func TestMountPoint_Mount(t *testing.T) { + client := common.DatabricksClient{ + DatabricksClient: &client.DatabricksClient{ + Config: &config.Config{ + Host: ".", + Token: ".", + }, + }, + } mount := mockMount{} expectedMountSource := "fake-mount" expectedMountConfig := `{"fake-key":"fake-value"}` @@ -108,14 +116,6 @@ func TestMountPoint_Mount(t *testing.T) { dbutils.notebook.exit(mount_source) `, mountName, expectedMountSource, expectedMountConfig) testMountFuncHelper(t, func(mp MountPoint, mount Mount) (s string, e error) { - client := common.DatabricksClient{ - DatabricksClient: &client.DatabricksClient{ - Config: &config.Config{ - Host: ".", - Token: ".", - }, - }, - } return mp.Mount(mount, &client) }, mount, mountName, expectedCommand) } diff --git a/storage/resource_mount_test.go b/storage/resource_mount_test.go index 5a19e2d000..5239ee9cf6 100644 --- a/storage/resource_mount_test.go +++ b/storage/resource_mount_test.go @@ -726,6 +726,7 @@ func TestResourceAdlsGen2MountGeneric_Create(t *testing.T) { "client_secret_key": "d", "initialize_file_system": true, }}}, + Azure: true, Create: true, }.Apply(t) require.NoError(t, err) @@ -770,6 +771,7 @@ func TestResourceAdlsGen2MountGeneric_Create_ResourceID(t *testing.T) { "initialize_file_system": true, }}}, Create: true, + Azure: true, }.Apply(t) require.NoError(t, err) assert.Equal(t, "e", d.Id()) @@ -816,6 +818,7 @@ func TestResourceAdlsGen2MountGeneric_Create_NoTenantID_SPN(t *testing.T) { "initialize_file_system": true, }}}, Create: true, + Azure: true, }.Apply(t) require.NoError(t, err) assert.Equal(t, "this_mount", d.Id()) @@ -951,6 +954,7 @@ func TestResourceAzureBlobMountCreateGeneric(t *testing.T) { }, }}, Create: true, + Azure: true, }.Apply(t) require.NoError(t, err) assert.Equal(t, "e", d.Id()) @@ -996,6 +1000,7 @@ func TestResourceAzureBlobMountCreateGeneric_SAS(t *testing.T) { "directory": "/d", }, }}, + Azure: true, Create: true, }.Apply(t) require.NoError(t, err) @@ -1040,6 +1045,7 @@ func TestResourceAzureBlobMountCreateGeneric_Resource_ID(t *testing.T) { "directory": "/d", }, }}, + Azure: true, Create: true, }.Apply(t) require.NoError(t, err) @@ -1060,6 +1066,7 @@ func TestResourceAzureBlobMountCreateGeneric_Resource_ID_Error(t *testing.T) { "directory": "/d", }, }}, + Azure: true, Create: true, }.Apply(t) qa.AssertErrorStartsWith(t, err, "parsing failed for abc. Invalid container resource Id format") @@ -1094,6 +1101,7 @@ func TestResourceAzureBlobMountCreateGeneric_Error(t *testing.T) { "token_secret_key": "g", "token_secret_scope": "h", }}}, + Azure: true, Create: true, }.Apply(t) require.EqualError(t, err, "Some error") @@ -1128,6 +1136,7 @@ func TestResourceAzureBlobMountCreateGeneric_Error_NoResourceID(t *testing.T) { "token_secret_key": "g", "token_secret_scope": "h", }}}, + Azure: true, Create: true, }.Apply(t) require.EqualError(t, err, "container_name or storage_account_name are empty, and resource_id or uri aren't specified") @@ -1175,8 +1184,9 @@ func TestResourceAzureBlobMountGeneric_Read(t *testing.T) { "token_secret_scope": "h", }}, }, - ID: "e", - Read: true, + Azure: true, + ID: "e", + Read: true, }.Apply(t) require.NoError(t, err) assert.Equal(t, "e", d.Id()) @@ -1222,6 +1232,7 @@ func TestResourceAzureBlobMountGenericRead_NotFound(t *testing.T) { "token_secret_scope": "h", }}, }, + Azure: true, ID: "e", Read: true, Removed: true, @@ -1267,8 +1278,9 @@ func TestResourceAzureBlobMountGenericRead_Error(t *testing.T) { "token_secret_scope": "h", }}, }, - ID: "e", - Read: true, + Azure: true, + ID: "e", + Read: true, }.Apply(t) require.EqualError(t, err, "Some error") assert.Equal(t, "e", d.Id()) @@ -1315,6 +1327,7 @@ func TestResourceAzureBlobMountGenericDelete(t *testing.T) { "token_secret_scope": "h", }}, }, + Azure: true, ID: "e", Delete: true, }.Apply(t) diff --git a/storage/s3.go b/storage/s3.go index e6f18457c9..9b62fdf3e2 100644 --- a/storage/s3.go +++ b/storage/s3.go @@ -18,7 +18,7 @@ type S3IamMount struct { } // Source ... -func (m S3IamMount) Source() string { +func (m S3IamMount) Source(_ *common.DatabricksClient) string { return fmt.Sprintf("s3a://%s", m.BucketName) }