From 180a7068b8b839f5fa763499514aa3a608ce628f Mon Sep 17 00:00:00 2001 From: Akash Chetty Date: Mon, 22 Jul 2024 19:23:38 +0530 Subject: [PATCH] feat(redshift): add support for assuming an iam role for redshift (#132) --- .github/workflows/test.yaml | 1 + go.mod | 2 +- sqlconnect/internal/redshift/config.go | 4 + sqlconnect/internal/redshift/db.go | 3 + .../internal/redshift/driver/connector.go | 6 +- .../internal/redshift/driver/driver_test.go | 532 +++++++++--------- sqlconnect/internal/redshift/driver/dsn.go | 62 +- .../internal/redshift/driver/dsn_test.go | 12 + .../internal/redshift/integration_test.go | 51 +- 9 files changed, 397 insertions(+), 276 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9be1392..d58f0e1 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -55,6 +55,7 @@ jobs: env: REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.REDSHIFT_TEST_ENVIRONMENT_CREDENTIALS }} REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS }} + REDSHIFT_DATA_TEST_ENVIRONMENT_ROLE_ARN_CREDENTIALS: ${{ secrets.REDSHIFT_DATA_TEST_ENVIRONMENT_ROLE_ARN_CREDENTIALS }} SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.SNOWFLAKE_TEST_ENVIRONMENT_CREDENTIALS }} BIGQUERY_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.BIGQUERY_TEST_ENVIRONMENT_CREDENTIALS }} DATABRICKS_TEST_ENVIRONMENT_CREDENTIALS: ${{ secrets.DATABRICKS_TEST_ENVIRONMENT_CREDENTIALS }} diff --git a/go.mod b/go.mod index dc0d9d1..966eaba 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.27.27 github.com/aws/aws-sdk-go-v2/credentials v1.17.27 github.com/aws/aws-sdk-go-v2/service/redshiftdata v1.27.3 + github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 github.com/databricks/databricks-sql-go v1.5.7 github.com/dlclark/regexp2 v1.11.2 github.com/gliderlabs/ssh v0.3.7 @@ -66,7 +67,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect github.com/aws/smithy-go v1.20.3 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/containerd/continuity v0.4.3 // indirect diff --git a/sqlconnect/internal/redshift/config.go b/sqlconnect/internal/redshift/config.go index 4e66f66..526d774 100644 --- a/sqlconnect/internal/redshift/config.go +++ b/sqlconnect/internal/redshift/config.go @@ -30,6 +30,10 @@ type Config struct { SecretAccessKey string `json:"secretAccessKey"` SessionToken string `json:"sessionToken"` + RoleARN string `json:"roleARN"` + ExternalID string `json:"externalID"` + RoleARNExpiry time.Duration `json:"roleARNExpiry"` // default: 15m + Timeout time.Duration `json:"timeout"` // default: no timeout MinPolling time.Duration `json:"minPolling"` // default: 10ms MaxPolling time.Duration `json:"maxPolling"` // default: 5s diff --git a/sqlconnect/internal/redshift/db.go b/sqlconnect/internal/redshift/db.go index ef38e3a..9e477b5 100644 --- a/sqlconnect/internal/redshift/db.go +++ b/sqlconnect/internal/redshift/db.go @@ -123,6 +123,9 @@ func newRedshiftDataDB(credentialsJSON json.RawMessage) (*sql.DB, error) { SharedConfigProfile: config.SharedConfigProfile, SecretAccessKey: config.SecretAccessKey, SessionToken: config.SessionToken, + RoleARN: config.RoleARN, + ExternalID: config.ExternalID, + RoleARNExpiry: config.RoleARNExpiry, Timeout: config.Timeout, MinPolling: config.MinPolling, MaxPolling: config.MaxPolling, diff --git a/sqlconnect/internal/redshift/driver/connector.go b/sqlconnect/internal/redshift/driver/connector.go index c305833..af2fbb7 100644 --- a/sqlconnect/internal/redshift/driver/connector.go +++ b/sqlconnect/internal/redshift/driver/connector.go @@ -19,7 +19,11 @@ type redshiftDataConnector struct { } func (c *redshiftDataConnector) Connect(ctx context.Context) (driver.Conn, error) { - client, err := newRedshiftDataClient(ctx, c.cfg, c.cfg.LoadOpts()...) + loadOpts, err := c.cfg.LoadOpts(ctx) + if err != nil { + return nil, err + } + client, err := newRedshiftDataClient(ctx, c.cfg, loadOpts...) if err != nil { return nil, err } diff --git a/sqlconnect/internal/redshift/driver/driver_test.go b/sqlconnect/internal/redshift/driver/driver_test.go index 1a78898..1b4ac22 100644 --- a/sqlconnect/internal/redshift/driver/driver_test.go +++ b/sqlconnect/internal/redshift/driver/driver_test.go @@ -13,269 +13,289 @@ import ( "github.com/stretchr/testify/require" "github.com/rudderlabs/rudder-go-kit/testhelper/rand" + "github.com/rudderlabs/sqlconnect-go/sqlconnect/internal/redshift/driver" ) func TestRedshiftDriver(t *testing.T) { - configJSON, ok := os.LookupEnv("REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS") - if !ok { - if os.Getenv("FORCE_RUN_INTEGRATION_TESTS") == "true" { - t.Fatal("REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS environment variable not set") - } - t.Skip("skipping redshift sdk driver integration test due to lack of a test environment") + testCases := []struct { + name string + credentialsKey string + }{ + { + name: "with AccessKeyID and cfg.SecretAccessKey", + credentialsKey: "REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS", + }, + { + name: "with RoleARN", + credentialsKey: "REDSHIFT_DATA_TEST_ENVIRONMENT_ROLE_ARN_CREDENTIALS", + }, } - var cfg driver.RedshiftConfig - err := json.Unmarshal([]byte(configJSON), &cfg) - require.NoError(t, err, "it should be able to unmarshal the config") - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - connector := driver.NewRedshiftConnector(cfg) - db := sql.OpenDB(connector) - schema := GenerateTestSchema() - t.Cleanup(func() { - _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA IF EXISTS "%s" CASCADE`, schema)) - require.NoError(t, err, "it should be able to drop the schema") - }) - - t.Run("Open", func(t *testing.T) { - dsn := cfg.String() - db, err := sql.Open("redshift-data", dsn) - require.NoError(t, err, "it should be able to open the database") - require.NoError(t, db.Ping(), "it should be able to ping the database") - err = db.Close() - require.NoError(t, err, "it should be able to close the database") - }) - - t.Run("Driver", func(t *testing.T) { - driver := connector.Driver() - require.NotNil(t, driver, "it should be able to get the driver") - conn, err := driver.Open(cfg.String()) - require.NoError(t, err, "it should be able to open a connection") - err = conn.Close() - require.NoError(t, err, "it should be able to close the connection") - }) - - t.Run("Ping", func(t *testing.T) { - require.NoError(t, db.Ping(), "it should be able to ping the database") - require.NoError(t, db.PingContext(ctx), "it should be able to ping the database using a context") - }) - - t.Run("Exec", func(t *testing.T) { - _, err := db.Exec(fmt.Sprintf(`CREATE SCHEMA "%s"`, schema)) - require.NoError(t, err, "it should be able to create a schema") - }) - - t.Run("ExecContext", func(t *testing.T) { - _, err := db.ExecContext(ctx, fmt.Sprintf(`CREATE TABLE "%s"."test_table" ("C1" INT4, "C2" VARCHAR)`, schema)) - require.NoError(t, err, "it should be able to create a table") - }) - - t.Run("prepared statement", func(t *testing.T) { - t.Run("QueryRow", func(t *testing.T) { - stmt, err := db.Prepare(fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)) - require.NoError(t, err, "it should be able to prepare a statement") - defer func() { - require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") - }() - - var count int - err = stmt.QueryRow().Scan(&count) - require.NoError(t, err, "it should be able to execute a prepared statement") - }) - - t.Run("Exec", func(t *testing.T) { - stmt, err := db.Prepare(fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (?)`, schema)) - require.NoError(t, err, "it should be able to prepare a statement") - defer func() { - require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") - }() - result, err := stmt.Exec(1) - require.NoError(t, err, "it should be able to execute a prepared statement") - - _, err = result.LastInsertId() - require.Error(t, err) - require.ErrorIs(t, err, driver.ErrNotSupported) - - rowsAffected, err := result.RowsAffected() - require.NoError(t, err, "it should be able to get rows affected") - require.EqualValues(t, 1, rowsAffected, "rows affected should be 1") - }) - - t.Run("Query", func(t *testing.T) { - stmt, err := db.Prepare(fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = $1`, schema)) - require.NoError(t, err, "it should be able to prepare a statement") - defer func() { - require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") - }() - rows, err := stmt.Query(1) - require.NoError(t, err, "it should be able to execute a prepared statement") - defer func() { - require.NoError(t, rows.Close(), "it should be able to close the rows") - }() - require.True(t, rows.Next(), "it should be able to get a row") - var c1 int - err = rows.Scan(&c1) - require.NoError(t, err, "it should be able to scan the row") - require.EqualValues(t, 1, c1, "it should be able to get the correct value") - require.False(t, rows.Next(), "it shouldn't have next row") - - require.NoError(t, rows.Err()) - }) - t.Run("Query with named parameters", func(t *testing.T) { - stmt, err := db.PrepareContext(ctx, fmt.Sprintf(`SELECT C1, C2 FROM "%s"."test_table" WHERE C1 = :c1_value`, schema)) - require.NoError(t, err, "it should be able to prepare a statement") - defer func() { - require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") - }() - rows, err := stmt.QueryContext(ctx, sql.Named("c1_value", 1)) - require.NoError(t, err, "it should be able to execute a prepared statement") - defer func() { - require.NoError(t, rows.Close(), "it should be able to close the rows") - }() - - cols, err := rows.Columns() - require.NoError(t, err, "it should be able to get the columns") - require.EqualValues(t, []string{"c1", "c2"}, cols, "it should be able to get the correct columns") - - colTypes, err := rows.ColumnTypes() - require.NoError(t, err, "it should be able to get the column types") - require.Len(t, colTypes, 2, "it should be able to get the correct number of column types") - require.EqualValues(t, "INT4", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") - require.EqualValues(t, "VARCHAR", colTypes[1].DatabaseTypeName(), "it should be able to get the correct column type") - - require.True(t, rows.Next(), "it should be able to get a row") - var c1 int - var c2 any - err = rows.Scan(&c1, &c2) - require.NoError(t, err, "it should be able to scan the row") - require.EqualValues(t, 1, c1, "it should be able to get the correct value") - require.Nil(t, c2, "it should be able to get the correct value") - require.False(t, rows.Next(), "it shouldn't have next row") - - require.NoError(t, rows.Err()) - }) - }) - t.Run("query", func(t *testing.T) { - t.Run("QueryRow", func(t *testing.T) { - var count int - err := db.QueryRow(fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&count) - require.NoError(t, err, "it should be able to execute a prepared statement") - require.Equal(t, 1, count, "it should be able to get the correct value") - }) - t.Run("Exec", func(t *testing.T) { - result, err := db.Exec(fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES ($1)`, schema), 2) - require.NoError(t, err, "it should be able to execute a prepared statement") - rowsAffected, err := result.RowsAffected() - require.NoError(t, err, "it should be able to get rows affected") - require.EqualValues(t, 1, rowsAffected, "rows affected should be 1") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configJSON, ok := os.LookupEnv(tc.credentialsKey) + if !ok { + if os.Getenv("FORCE_RUN_INTEGRATION_TESTS") == "true" { + t.Fatalf("%s environment variable not set", tc.credentialsKey) + } + t.Skipf("skipping redshift sdk driver integration test (%s) due to lack of a test environment", tc.name) + } + var cfg driver.RedshiftConfig + err := json.Unmarshal([]byte(configJSON), &cfg) + require.NoError(t, err, "it should be able to unmarshal the config") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connector := driver.NewRedshiftConnector(cfg) + db := sql.OpenDB(connector) + schema := GenerateTestSchema() + t.Cleanup(func() { + _, err := db.Exec(fmt.Sprintf(`DROP SCHEMA IF EXISTS "%s" CASCADE`, schema)) + require.NoError(t, err, "it should be able to drop the schema") + }) + + t.Run("Open", func(t *testing.T) { + dsn := cfg.String() + db, err := sql.Open("redshift-data", dsn) + require.NoError(t, err, "it should be able to open the database") + require.NoError(t, db.Ping(), "it should be able to ping the database") + err = db.Close() + require.NoError(t, err, "it should be able to close the database") + }) + + t.Run("Driver", func(t *testing.T) { + driver := connector.Driver() + require.NotNil(t, driver, "it should be able to get the driver") + conn, err := driver.Open(cfg.String()) + require.NoError(t, err, "it should be able to open a connection") + err = conn.Close() + require.NoError(t, err, "it should be able to close the connection") + }) + + t.Run("Ping", func(t *testing.T) { + require.NoError(t, db.Ping(), "it should be able to ping the database") + require.NoError(t, db.PingContext(ctx), "it should be able to ping the database using a context") + }) + + t.Run("Exec", func(t *testing.T) { + _, err := db.Exec(fmt.Sprintf(`CREATE SCHEMA "%s"`, schema)) + require.NoError(t, err, "it should be able to create a schema") + }) + + t.Run("ExecContext", func(t *testing.T) { + _, err := db.ExecContext(ctx, fmt.Sprintf(`CREATE TABLE "%s"."test_table" ("C1" INT4, "C2" VARCHAR)`, schema)) + require.NoError(t, err, "it should be able to create a table") + }) + + t.Run("prepared statement", func(t *testing.T) { + t.Run("QueryRow", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + + var count int + err = stmt.QueryRow().Scan(&count) + require.NoError(t, err, "it should be able to execute a prepared statement") + }) + + t.Run("Exec", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (?)`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + result, err := stmt.Exec(1) + require.NoError(t, err, "it should be able to execute a prepared statement") + + _, err = result.LastInsertId() + require.Error(t, err) + require.ErrorIs(t, err, driver.ErrNotSupported) + + rowsAffected, err := result.RowsAffected() + require.NoError(t, err, "it should be able to get rows affected") + require.EqualValues(t, 1, rowsAffected, "rows affected should be 1") + }) + + t.Run("Query", func(t *testing.T) { + stmt, err := db.Prepare(fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = $1`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + rows, err := stmt.Query(1) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 1, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + t.Run("Query with named parameters", func(t *testing.T) { + stmt, err := db.PrepareContext(ctx, fmt.Sprintf(`SELECT C1, C2 FROM "%s"."test_table" WHERE C1 = :c1_value`, schema)) + require.NoError(t, err, "it should be able to prepare a statement") + defer func() { + require.NoError(t, stmt.Close(), "it should be able to close the prepared statement") + }() + rows, err := stmt.QueryContext(ctx, sql.Named("c1_value", 1)) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + + cols, err := rows.Columns() + require.NoError(t, err, "it should be able to get the columns") + require.EqualValues(t, []string{"c1", "c2"}, cols, "it should be able to get the correct columns") + + colTypes, err := rows.ColumnTypes() + require.NoError(t, err, "it should be able to get the column types") + require.Len(t, colTypes, 2, "it should be able to get the correct number of column types") + require.EqualValues(t, "INT4", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") + require.EqualValues(t, "VARCHAR", colTypes[1].DatabaseTypeName(), "it should be able to get the correct column type") + + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + var c2 any + err = rows.Scan(&c1, &c2) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 1, c1, "it should be able to get the correct value") + require.Nil(t, c2, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + }) + + t.Run("query", func(t *testing.T) { + t.Run("QueryRow", func(t *testing.T) { + var count int + err := db.QueryRow(fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&count) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, 1, count, "it should be able to get the correct value") + }) + + t.Run("Exec", func(t *testing.T) { + result, err := db.Exec(fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES ($1)`, schema), 2) + require.NoError(t, err, "it should be able to execute a prepared statement") + rowsAffected, err := result.RowsAffected() + require.NoError(t, err, "it should be able to get rows affected") + require.EqualValues(t, 1, rowsAffected, "rows affected should be 1") + }) + + t.Run("Query", func(t *testing.T) { + rows, err := db.Query(fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = ?`, schema), 2) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 2, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + + t.Run("Query with named parameters", func(t *testing.T) { + rows, err := db.QueryContext(ctx, fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = :c1_value`, schema), sql.Named("c1_value", 2)) + require.NoError(t, err, "it should be able to execute a prepared statement") + defer func() { + require.NoError(t, rows.Close(), "it should be able to close the rows") + }() + + cols, err := rows.Columns() + require.NoError(t, err, "it should be able to get the columns") + require.EqualValues(t, []string{"c1"}, cols, "it should be able to get the correct columns") + + colTypes, err := rows.ColumnTypes() + require.NoError(t, err, "it should be able to get the column types") + require.Len(t, colTypes, 1, "it should be able to get the correct number of column types") + require.EqualValues(t, "INT4", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") + + require.True(t, rows.Next(), "it should be able to get a row") + var c1 int + err = rows.Scan(&c1) + require.NoError(t, err, "it should be able to scan the row") + require.EqualValues(t, 2, c1, "it should be able to get the correct value") + require.False(t, rows.Next(), "it shouldn't have next row") + + require.NoError(t, rows.Err()) + }) + }) + + t.Run("transaction support", func(t *testing.T) { + t.Run("Begin and Commit", func(t *testing.T) { + var countBefore int + err := db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countBefore) + require.NoError(t, err, "it should be able to execute a prepared statement") + + tx, err := db.Begin() + require.NoError(t, err, "it should be able to begin a transaction") + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (3)`, schema)) + require.NoError(t, err, "it should be able to execute a prepared statement") + res, err := tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (4)`, schema)) + require.NoError(t, err, "it should be able to execute a prepared statement") + _, err = res.RowsAffected() + require.Error(t, err, "it should not be able to get rows affected before commit") + require.ErrorIs(t, err, driver.ErrBeforeCommit) + _, err = res.LastInsertId() + require.Error(t, err, "it should not be able to get last insert id before commit") + require.ErrorIs(t, err, driver.ErrBeforeCommit) + + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (?)`, schema), 5) + require.Error(t, err, "it should not be able to execute a prepared statement with parameters in a transaction") + require.ErrorIs(t, err, driver.ErrNotSupported) + + var countDuring int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countDuring) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore, countDuring, "it should not be able to see the changes from the transaction") + + err = tx.Commit() + require.NoError(t, err, "it should be able to commit the transaction") + + var countAfter int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countAfter) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore+2, countAfter, "it should be able to see the changes from the transaction") + }) + t.Run("BeginTx and Rollback", func(t *testing.T) { + var countBefore int + err := db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countBefore) + require.NoError(t, err, "it should be able to execute a prepared statement") + + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err, "it should be able to begin a transaction") + _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (5)`, schema)) + require.NoError(t, err, "it should be able to execute a prepared statement") + + var countDuring int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countDuring) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore, countDuring, "it should not be able to see the changes from the transaction") + + err = tx.Rollback() + require.NoError(t, err, "it should be able to rollback the transaction") + + var countAfter int + err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countAfter) + require.NoError(t, err, "it should be able to execute a prepared statement") + require.Equal(t, countBefore, countAfter, "changes from the transaction should be rolled back") + }) + }) }) - - t.Run("Query", func(t *testing.T) { - rows, err := db.Query(fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = ?`, schema), 2) - require.NoError(t, err, "it should be able to execute a prepared statement") - defer func() { - require.NoError(t, rows.Close(), "it should be able to close the rows") - }() - require.True(t, rows.Next(), "it should be able to get a row") - var c1 int - err = rows.Scan(&c1) - require.NoError(t, err, "it should be able to scan the row") - require.EqualValues(t, 2, c1, "it should be able to get the correct value") - require.False(t, rows.Next(), "it shouldn't have next row") - - require.NoError(t, rows.Err()) - }) - - t.Run("Query with named parameters", func(t *testing.T) { - rows, err := db.QueryContext(ctx, fmt.Sprintf(`SELECT C1 FROM "%s"."test_table" WHERE C1 = :c1_value`, schema), sql.Named("c1_value", 2)) - require.NoError(t, err, "it should be able to execute a prepared statement") - defer func() { - require.NoError(t, rows.Close(), "it should be able to close the rows") - }() - - cols, err := rows.Columns() - require.NoError(t, err, "it should be able to get the columns") - require.EqualValues(t, []string{"c1"}, cols, "it should be able to get the correct columns") - - colTypes, err := rows.ColumnTypes() - require.NoError(t, err, "it should be able to get the column types") - require.Len(t, colTypes, 1, "it should be able to get the correct number of column types") - require.EqualValues(t, "INT4", colTypes[0].DatabaseTypeName(), "it should be able to get the correct column type") - - require.True(t, rows.Next(), "it should be able to get a row") - var c1 int - err = rows.Scan(&c1) - require.NoError(t, err, "it should be able to scan the row") - require.EqualValues(t, 2, c1, "it should be able to get the correct value") - require.False(t, rows.Next(), "it shouldn't have next row") - - require.NoError(t, rows.Err()) - }) - }) - - t.Run("transaction support", func(t *testing.T) { - t.Run("Begin and Commit", func(t *testing.T) { - var countBefore int - err := db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countBefore) - require.NoError(t, err, "it should be able to execute a prepared statement") - - tx, err := db.Begin() - require.NoError(t, err, "it should be able to begin a transaction") - _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (3)`, schema)) - require.NoError(t, err, "it should be able to execute a prepared statement") - res, err := tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (4)`, schema)) - require.NoError(t, err, "it should be able to execute a prepared statement") - _, err = res.RowsAffected() - require.Error(t, err, "it should not be able to get rows affected before commit") - require.ErrorIs(t, err, driver.ErrBeforeCommit) - _, err = res.LastInsertId() - require.Error(t, err, "it should not be able to get last insert id before commit") - require.ErrorIs(t, err, driver.ErrBeforeCommit) - - _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (?)`, schema), 5) - require.Error(t, err, "it should not be able to execute a prepared statement with parameters in a transaction") - require.ErrorIs(t, err, driver.ErrNotSupported) - - var countDuring int - err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countDuring) - require.NoError(t, err, "it should be able to execute a prepared statement") - require.Equal(t, countBefore, countDuring, "it should not be able to see the changes from the transaction") - - err = tx.Commit() - require.NoError(t, err, "it should be able to commit the transaction") - - var countAfter int - err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countAfter) - require.NoError(t, err, "it should be able to execute a prepared statement") - require.Equal(t, countBefore+2, countAfter, "it should be able to see the changes from the transaction") - }) - t.Run("BeginTx and Rollback", func(t *testing.T) { - var countBefore int - err := db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countBefore) - require.NoError(t, err, "it should be able to execute a prepared statement") - - tx, err := db.BeginTx(ctx, nil) - require.NoError(t, err, "it should be able to begin a transaction") - _, err = tx.ExecContext(ctx, fmt.Sprintf(`INSERT INTO "%s"."test_table" (C1) VALUES (5)`, schema)) - require.NoError(t, err, "it should be able to execute a prepared statement") - - var countDuring int - err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countDuring) - require.NoError(t, err, "it should be able to execute a prepared statement") - require.Equal(t, countBefore, countDuring, "it should not be able to see the changes from the transaction") - - err = tx.Rollback() - require.NoError(t, err, "it should be able to rollback the transaction") - - var countAfter int - err = db.QueryRowContext(ctx, fmt.Sprintf(`SELECT COUNT(*) FROM "%s"."test_table"`, schema)).Scan(&countAfter) - require.NoError(t, err, "it should be able to execute a prepared statement") - require.Equal(t, countBefore, countAfter, "changes from the transaction should be rolled back") - }) - }) + } } func GenerateTestSchema() string { diff --git a/sqlconnect/internal/redshift/driver/dsn.go b/sqlconnect/internal/redshift/driver/dsn.go index 8e9eb8f..d6df868 100644 --- a/sqlconnect/internal/redshift/driver/dsn.go +++ b/sqlconnect/internal/redshift/driver/dsn.go @@ -1,6 +1,7 @@ package driver import ( + "context" "errors" "fmt" "net/url" @@ -8,9 +9,16 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/redshiftdata" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +const ( + roleSessionName = "rudderstack-aws-redshift-access" ) type RedshiftConfig struct { @@ -24,6 +32,9 @@ type RedshiftConfig struct { AccessKeyID string `json:"accessKeyId"` SecretAccessKey string `json:"secretAccessKey"` SessionToken string `json:"sessionToken"` + RoleARN string `json:"roleARN"` + RoleARNExpiry time.Duration `json:"roleARNExpiry"` // default: 15m + ExternalID string `json:"externalID"` Timeout time.Duration `json:"timeout"` // default: no timeout MinPolling time.Duration `json:"polling"` // default: 10ms MaxPolling time.Duration `json:"maxPolling"` // default: 5s @@ -41,8 +52,11 @@ func (cfg *RedshiftConfig) Sanitize() { } } -func (cfg *RedshiftConfig) LoadOpts() []func(*config.LoadOptions) error { +func (cfg *RedshiftConfig) LoadOpts(ctx context.Context) ([]func(*config.LoadOptions) error, error) { var opts []func(*config.LoadOptions) error + if cfg.Region != "" { + opts = append(opts, config.WithRegion(cfg.Region)) + } if cfg.SharedConfigProfile != "" { opts = append(opts, config.WithSharedConfigProfile(cfg.SharedConfigProfile)) } @@ -53,8 +67,22 @@ func (cfg *RedshiftConfig) LoadOpts() []func(*config.LoadOptions) error { cfg.SessionToken, ))) } + if cfg.RoleARN != "" { + stsCfg, err := config.LoadDefaultConfig(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("load default aws config: %w", err) + } + stsSvc := sts.NewFromConfig(stsCfg) + opts = append([]func(*config.LoadOptions) error{}, config.WithCredentialsProvider(stscreds.NewAssumeRoleProvider(stsSvc, cfg.RoleARN, func(o *stscreds.AssumeRoleOptions) { + if cfg.ExternalID != "" { + o.ExternalID = aws.String(cfg.ExternalID) + } + o.RoleSessionName = roleSessionName + o.Duration = cfg.RoleARNExpiry + }))) + } opts = append(opts, config.WithRetryMaxAttempts(cfg.GetRetryMaxAttempts())) - return opts + return opts, nil } func (cfg *RedshiftConfig) Opts() []func(*redshiftdata.Options) { @@ -121,6 +149,21 @@ func (cfg *RedshiftConfig) String() string { } else { params.Del("sessionToken") } + if cfg.RoleARN != "" { + params.Add("roleARN", cfg.RoleARN) + } else { + params.Del("roleARN") + } + if cfg.ExternalID != "" { + params.Add("externalID", cfg.ExternalID) + } else { + params.Del("externalID") + } + if cfg.RoleARNExpiry > 0 { + params.Add("roleARNExpiry", cfg.RoleARNExpiry.String()) + } else { + params.Del("roleARNExpiry") + } encodedParams := params.Encode() if encodedParams != "" { return base + "?" + encodedParams @@ -179,6 +222,21 @@ func (cfg *RedshiftConfig) setParams(params url.Values) error { cfg.SessionToken = params.Get("sessionToken") cfg.Params.Del("sessionToken") } + if params.Has("roleARN") { + cfg.RoleARN = params.Get("roleARN") + cfg.Params.Del("roleARN") + } + if params.Has("externalID") { + cfg.ExternalID = params.Get("externalID") + cfg.Params.Del("externalID") + } + if params.Has("roleARNExpiry") { + cfg.RoleARNExpiry, err = time.ParseDuration(params.Get("roleARNExpiry")) + if err != nil { + return fmt.Errorf("parse role arn expiry as duration: %w", err) + } + cfg.Params.Del("roleARNExpiry") + } if len(cfg.Params) == 0 { cfg.Params = nil } diff --git a/sqlconnect/internal/redshift/driver/dsn_test.go b/sqlconnect/internal/redshift/driver/dsn_test.go index 7b7503b..bb8f23b 100644 --- a/sqlconnect/internal/redshift/driver/dsn_test.go +++ b/sqlconnect/internal/redshift/driver/dsn_test.go @@ -102,6 +102,18 @@ func TestRedshiftDataConfig__String(t *testing.T) { }, expected: "admin@cluster(default)/dev?accessKeyId=accessKeyID®ion=us-east-1&secretAccessKey=secretAccessKey&sessionToken=sessionToken", }, + { + dsn: &RedshiftConfig{ + ClusterIdentifier: "default", + DbUser: "admin", + Database: "dev", + Region: "us-east-1", + RoleARN: "roleARN", + ExternalID: "externalID", + RoleARNExpiry: 15 * time.Minute, + }, + expected: "admin@cluster(default)/dev?externalID=externalID®ion=us-east-1&roleARN=roleARN&roleARNExpiry=15m0s", + }, } for _, c := range cases { diff --git a/sqlconnect/internal/redshift/integration_test.go b/sqlconnect/internal/redshift/integration_test.go index f9b3aa6..3762b16 100644 --- a/sqlconnect/internal/redshift/integration_test.go +++ b/sqlconnect/internal/redshift/integration_test.go @@ -40,23 +40,42 @@ func TestRedshiftDB(t *testing.T) { }) t.Run("redshift data driver", func(t *testing.T) { - configJSON, ok := os.LookupEnv("REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS") - if !ok { - if os.Getenv("FORCE_RUN_INTEGRATION_TESTS") == "true" { - t.Fatal("REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS environment variable not set") - } - t.Skip("skipping redshift data driver integration test due to lack of a test environment") - } - integrationtest.TestDatabaseScenarios( - t, - redshift.DatabaseType, - []byte(configJSON), - strings.ToLower, - integrationtest.Options{ - LegacySupport: true, - ExtraTests: ExtraTests, + testCases := []struct { + name string + credentialsKey string + }{ + { + name: "with AccessKeyID and cfg.SecretAccessKey", + credentialsKey: "REDSHIFT_DATA_TEST_ENVIRONMENT_CREDENTIALS", }, - ) + { + name: "with RoleARN", + credentialsKey: "REDSHIFT_DATA_TEST_ENVIRONMENT_ROLE_ARN_CREDENTIALS", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configJSON, ok := os.LookupEnv(tc.credentialsKey) + if !ok { + if os.Getenv("FORCE_RUN_INTEGRATION_TESTS") == "true" { + t.Fatalf("%s environment variable not set", tc.credentialsKey) + } + t.Skipf("skipping redshift data driver integration test (%s) due to lack of a test environment", tc.name) + } + + integrationtest.TestDatabaseScenarios( + t, + redshift.DatabaseType, + []byte(configJSON), + strings.ToLower, + integrationtest.Options{ + LegacySupport: true, + ExtraTests: ExtraTests, + }, + ) + }) + } }) }