Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: skip_workgroup_check setting to reduce AWS throttling #713

Merged
merged 7 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ A dbt profile can be configured to run against AWS Athena using the following co
| aws_secret_access_key | Secret access key of the user performing requests | Optional | `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY` |
| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| skip_workgroup_check | Indicates if the WorkGroup check (additional AWS call) can be skipped | Optional | `true` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| num_boto3_retries | Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) | Optional | `5` |
| num_iceberg_retries | Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR | Optional | `3` |
Expand Down
2 changes: 2 additions & 0 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class AthenaCredentials(Credentials):
region_name: str
endpoint_url: Optional[str] = None
work_group: Optional[str] = None
skip_workgroup_check: bool = False
aws_profile_name: Optional[str] = None
aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None
Expand Down Expand Up @@ -91,6 +92,7 @@ def _connection_keys(self) -> Tuple[str, ...]:
return (
"s3_staging_dir",
"work_group",
"skip_workgroup_check",
"region_name",
"database",
"schema",
Expand Down
4 changes: 3 additions & 1 deletion dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class AthenaConfig(AdapterConfig):

Args:
work_group: Identifier of Athena workgroup.
skip_workgroup_check: Indicates if the WorkGroup check (additional AWS call) can be skipped.
s3_staging_dir: S3 location to store Athena query results and metadata.
external_location: If set, the full S3 path in which the table will be saved.
partitioned_by: An array list of columns by which the table will be partitioned.
Expand All @@ -102,6 +103,7 @@ class AthenaConfig(AdapterConfig):
"""

work_group: Optional[str] = None
skip_workgroup_check: bool = False
s3_staging_dir: Optional[str] = None
external_location: Optional[str] = None
partitioned_by: Optional[str] = None
Expand Down Expand Up @@ -240,7 +242,7 @@ def is_work_group_output_location_enforced(self) -> bool:
conn = self.connections.get_thread_connection()
creds = conn.credentials

if creds.work_group:
if creds.work_group and not creds.skip_workgroup_check:
work_group = self._get_work_group(creds.work_group)
output_location = (
work_group.get("WorkGroup", {})
Expand Down
63 changes: 41 additions & 22 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,7 @@

class TestAthenaAdapter:
def setup_method(self, _):
project_cfg = {
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"config-version": 2,
}
profile_cfg = {
"outputs": {
"test": {
"type": "athena",
"s3_staging_dir": S3_STAGING_DIR,
"region_name": AWS_REGION,
"database": DATA_CATALOG_NAME,
"work_group": ATHENA_WORKGROUP,
"schema": DATABASE_NAME,
}
},
"target": "test",
}

self.config = config_from_parts_or_dicts(project_cfg, profile_cfg)
self.config = TestAthenaAdapter._config_from_settings()
self._adapter = None
self.used_schemas = frozenset(
{
Expand All @@ -79,6 +58,35 @@ def adapter(self):
inject_adapter(self._adapter, AthenaPlugin)
return self._adapter

@staticmethod
def _config_from_settings(settings={}):
project_cfg = {
"name": "X",
"version": "0.1",
"profile": "test",
"project-root": "/tmp/dbt/does-not-exist",
"config-version": 2,
}

profile_cfg = {
"outputs": {
"test": {
**{
"type": "athena",
"s3_staging_dir": S3_STAGING_DIR,
"region_name": AWS_REGION,
"database": DATA_CATALOG_NAME,
"work_group": ATHENA_WORKGROUP,
"schema": DATABASE_NAME,
},
**settings,
}
},
"target": "test",
}

return config_from_parts_or_dicts(project_cfg, profile_cfg)

@mock.patch("dbt.adapters.athena.connections.AthenaConnection")
def test_acquire_connection_validations(self, connection_cls):
try:
Expand Down Expand Up @@ -931,6 +939,17 @@ def test_get_work_group_output_location(self, mock_aws_service):
work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert work_group_location_enforced

def test_get_work_group_output_location_if_workgroup_check_is_skipepd(self):
settings = {
"skip_workgroup_check": True,
}

self.config = TestAthenaAdapter._config_from_settings(settings)
self.adapter.acquire_connection("dummy")

work_group_location_enforced = self.adapter.is_work_group_output_location_enforced()
assert not work_group_location_enforced

@mock_aws
def test_get_work_group_output_location_no_location(self, mock_aws_service):
self.adapter.acquire_connection("dummy")
Expand Down
Loading