Skip to content

Commit

Permalink
feat: Implement iceberg retry logic (#657)
Browse files Browse the repository at this point in the history
Co-authored-by: nicor88 <[email protected]>
  • Loading branch information
svdimchenko and nicor88 committed May 28, 2024
1 parent 97430f9 commit 76f6d5f
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 36 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ You can either:
A dbt profile can be configured to run against AWS Athena using the following configuration:

| Option | Description | Required? | Example |
| --------------------- | ---------------------------------------------------------------------------------------- | --------- | ------------------------------------------ |
|-----------------------|------------------------------------------------------------------------------------------|-----------|--------------------------------------------|
| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` |
| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` |
| s3_data_naming | How to generate table paths in `s3_data_dir` | Optional | `schema_table_unique` |
Expand All @@ -134,8 +134,9 @@ A dbt profile can be configured to run against AWS Athena using the following co
| aws_profile_name | Profile to use from your AWS shared credentials file | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` |
| num_boto3_retries | Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) | Optional | `5` |
| num_iceberg_retries   | Number of times to retry Iceberg commit queries that fail with ICEBERG_COMMIT_ERROR     | Optional  | `0`                                        |
| spark_work_group | Identifier of Athena Spark workgroup for running Python models | Optional | `my-spark-workgroup` |
| seed_s3_upload_args | Dictionary containing boto3 ExtraArgs when uploading to S3 | Optional | `{"ACL": "bucket-owner-full-control"}` |
| lf_tags_database | Default LF tags for new database if it's created by dbt | Optional | `tag_key: tag_value` |

Expand Down
87 changes: 53 additions & 34 deletions dbt/adapters/athena/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
from pyathena.result_set import AthenaResultSet
from pyathena.util import RetryConfig
from tenacity import (
Retrying,
retry,
retry_if_exception,
stop_after_attempt,
wait_random_exponential,
)
from typing_extensions import Self

from dbt.adapters.athena.config import get_boto3_config
from dbt.adapters.athena.constants import LOGGER
Expand Down Expand Up @@ -64,8 +65,9 @@ class AthenaCredentials(Credentials):
_ALIASES = {"catalog": "database"}
num_retries: int = 5
num_boto3_retries: Optional[int] = None
num_iceberg_retries: int = 3
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "schema_table_unique"
s3_data_naming: str = "schema_table_unique"
spark_work_group: Optional[str] = None
s3_tmp_table_dir: Optional[str] = None
# Unfortunately we can not just use dict, must be Dict because we'll get the following error:
Expand Down Expand Up @@ -147,7 +149,7 @@ def __poll(self, query_id: str) -> AthenaQueryExecution:
LOGGER.debug(f"Query state is: {query_execution.state}. Sleeping for {self._poll_interval}...")
time.sleep(self._poll_interval)

def execute( # type: ignore
def execute(
self,
operation: str,
parameters: Optional[Dict[str, Any]] = None,
Expand All @@ -157,35 +159,9 @@ def execute( # type: ignore
cache_size: int = 0,
cache_expiration_time: int = 0,
catch_partitions_limit: bool = False,
**kwargs,
):
def inner() -> AthenaCursor:
query_id = self._execute(
operation,
parameters=parameters,
work_group=work_group,
s3_staging_dir=s3_staging_dir,
cache_size=cache_size,
cache_expiration_time=cache_expiration_time,
)

LOGGER.debug(f"Athena query ID {query_id}")

query_execution = self._executor.submit(self._collect_result_set, query_id).result()
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
self.result_set = self._result_set_class(
self._connection,
self._converter,
query_execution,
self.arraysize,
self._retry_config,
)

else:
raise OperationalError(query_execution.state_change_reason)
return self

retry = Retrying(
**kwargs: Dict[str, Any],
) -> Self:
@retry(
# No need to retry if TOO_MANY_OPEN_PARTITIONS occurs.
# Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry,
# because not all files are removed immediately after first try to create table
Expand All @@ -200,7 +176,47 @@ def inner() -> AthenaCursor:
),
reraise=True,
)
return retry(inner)
def inner() -> AthenaCursor:
num_iceberg_retries = self.connection.cursor_kwargs.get("num_iceberg_retries") + 1

@retry(
# Nested retry is needed to handle ICEBERG_COMMIT_ERROR for parallel inserts
retry=retry_if_exception(lambda e: "ICEBERG_COMMIT_ERROR" in str(e)),
stop=stop_after_attempt(num_iceberg_retries),
wait=wait_random_exponential(
multiplier=num_iceberg_retries,
max=self._retry_config.max_delay,
exp_base=self._retry_config.exponential_base,
),
reraise=True,
)
def execute_with_iceberg_retries() -> AthenaCursor:
query_id = self._execute(
operation,
parameters=parameters,
work_group=work_group,
s3_staging_dir=s3_staging_dir,
cache_size=cache_size,
cache_expiration_time=cache_expiration_time,
)

LOGGER.debug(f"Athena query ID {query_id}")

query_execution = self._executor.submit(self._collect_result_set, query_id).result()
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
self.result_set = self._result_set_class(
self._connection,
self._converter,
query_execution,
self.arraysize,
self._retry_config,
)
return self
raise OperationalError(query_execution.state_change_reason)

return execute_with_iceberg_retries() # type: ignore

return inner() # type: ignore


class AthenaConnectionManager(SQLConnectionManager):
Expand Down Expand Up @@ -243,7 +259,10 @@ def open(cls, connection: Connection) -> Connection:
schema_name=creds.schema,
work_group=creds.work_group,
cursor_class=AthenaCursor,
cursor_kwargs={"debug_query_state": creds.debug_query_state},
cursor_kwargs={
"debug_query_state": creds.debug_query_state,
"num_iceberg_retries": creds.num_iceberg_retries,
},
formatter=AthenaParameterFormatter(),
poll_interval=creds.poll_interval,
session=get_boto3_session(connection),
Expand Down
135 changes: 135 additions & 0 deletions tests/functional/adapter/test_retries_iceberg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Test parallel insert into iceberg table."""
import copy
import os

import pytest

from dbt.artifacts.schemas.results import RunStatus
from dbt.tests.util import check_relations_equal, run_dbt, run_dbt_and_capture

# Number of source models that insert concurrently; doubles as the dbt thread
# count so every insert really runs in parallel.
PARALLELISM = 10

# Base connection profile for the functional tests below.  Each test class
# deep-copies it and sets its own ``num_iceberg_retries`` value.
base_dbt_profile = dict(
    type="athena",
    s3_staging_dir=os.getenv("DBT_TEST_ATHENA_S3_STAGING_DIR"),
    s3_tmp_table_dir=os.getenv("DBT_TEST_ATHENA_S3_TMP_TABLE_DIR"),
    schema=os.getenv("DBT_TEST_ATHENA_SCHEMA"),
    database=os.getenv("DBT_TEST_ATHENA_DATABASE"),
    region_name=os.getenv("DBT_TEST_ATHENA_REGION_NAME"),
    threads=PARALLELISM,
    poll_interval=float(os.getenv("DBT_TEST_ATHENA_POLL_INTERVAL", "1.0")),
    # Query retries are disabled so only the dedicated iceberg retry logic
    # (configured per test class) can rescue a failing commit.
    num_retries=0,
    work_group=os.getenv("DBT_TEST_ATHENA_WORK_GROUP"),
    aws_profile_name=os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None,
)

models__target = """
{{
config(
table_type='iceberg',
materialized='table'
)
}}
select * from (
values
(1, -1)
) as t (id, status)
limit 0
"""

models__source = {
f"model_{i}.sql": f"""
{{{{
config(
table_type='iceberg',
materialized='table',
tags=['src'],
pre_hook='insert into target values ({i}, {i})'
)
}}}}
select 1 as col
"""
for i in range(PARALLELISM)
}

seeds__expected_target_init = "id,status"
seeds__expected_target_post = "id,status\n" + "\n".join([f"{i},{i}" for i in range(PARALLELISM)])


class TestIcebergRetriesDisabled:
    """With iceberg retries disabled, parallel inserts into one table must fail."""

    @pytest.fixture(scope="class")
    def dbt_profile_target(self):
        target = copy.deepcopy(base_dbt_profile)
        target.update(num_iceberg_retries=0)
        return target

    @pytest.fixture(scope="class")
    def models(self):
        return {"target.sql": models__target, **models__source}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "expected_target_init.csv": seeds__expected_target_init,
            "expected_target_post.csv": seeds__expected_target_post,
        }

    def test__retries_iceberg(self, project):
        """Seed should match the model after run"""

        init_seed = "expected_target_init"
        run_dbt(["seed", "--select", init_seed, "--full-refresh"])

        target = "target"
        build_result = run_dbt(["run", "--select", target]).results[0]
        assert build_result.status == RunStatus.Success
        check_relations_equal(project.adapter, [target, init_seed])

        post_seed = "expected_target_post"
        run_dbt(["seed", "--select", post_seed, "--full-refresh"])

        # All source models insert concurrently; without retries at least one
        # commit must lose the race and surface ICEBERG_COMMIT_ERROR.
        run, log = run_dbt_and_capture(["run", "--select", "tag:src"], expect_pass=False)
        statuses = [result.status for result in run.results]
        assert RunStatus.Error in statuses
        assert "ICEBERG_COMMIT_ERROR" in log


class TestIcebergRetriesEnabled:
    """With iceberg retries enabled, every parallel insert should eventually commit."""

    @pytest.fixture(scope="class")
    def dbt_profile_target(self):
        return {**copy.deepcopy(base_dbt_profile), "num_iceberg_retries": 1}

    @pytest.fixture(scope="class")
    def models(self):
        return {"target.sql": models__target, **models__source}

    @pytest.fixture(scope="class")
    def seeds(self):
        return {
            "expected_target_init.csv": seeds__expected_target_init,
            "expected_target_post.csv": seeds__expected_target_post,
        }

    def test__retries_iceberg(self, project):
        """Seed should match the model after run"""

        init_seed = "expected_target_init"
        run_dbt(["seed", "--select", init_seed, "--full-refresh"])

        target = "target"
        build_result = run_dbt(["run", "--select", target]).results[0]
        assert build_result.status == RunStatus.Success
        check_relations_equal(project.adapter, [target, init_seed])

        post_seed = "expected_target_post"
        run_dbt(["seed", "--select", post_seed, "--full-refresh"])

        # With retries enabled the concurrent inserts should all succeed and
        # leave exactly one row per source model in the target table.
        run = run_dbt(["run", "--select", "tag:src"])
        assert all(result.status == RunStatus.Success for result in run.results)
        check_relations_equal(project.adapter, [target, post_seed])

0 comments on commit 76f6d5f

Please sign in to comment.