diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index c6354a37..8cdee2f4 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1 +1 @@
-* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko @thenaturalist
+* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 0e2fef2a..ab75579a 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
+ python-version: ['3.8', '3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
diff --git a/README.md b/README.md
index ccc82e72..133290a6 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,19 @@
+
+
+
-
+
## Features
-* Supports dbt version `1.5.*`
+* Supports dbt version `1.6.*`
+* Supports Python versions from `3.8` to `3.11`
* Supports [seeds][seeds]
* Correctly detects views and their columns
* Supports [table materialization][table]
diff --git a/dbt/adapters/athena/__version__.py b/dbt/adapters/athena/__version__.py
index 0c46db4f..38ec8ede 100644
--- a/dbt/adapters/athena/__version__.py
+++ b/dbt/adapters/athena/__version__.py
@@ -1 +1 @@
-version = "1.5.1"
+version = "1.6.0"
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index 67e2e71e..8645f8bc 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -1,4 +1,6 @@
import hashlib
+import json
+import re
import time
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import contextmanager
@@ -132,6 +134,7 @@ def execute( # type: ignore
endpoint_url: Optional[str] = None,
cache_size: int = 0,
cache_expiration_time: int = 0,
+ catch_partitions_limit: bool = False,
**kwargs,
):
def inner() -> AthenaCursor:
@@ -158,7 +161,12 @@ def inner() -> AthenaCursor:
return self
retry = tenacity.Retrying(
- retry=retry_if_exception(lambda _: True),
+ # No need to retry if TOO_MANY_OPEN_PARTITIONS occurs.
+ # Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry,
+            # because not all files are removed immediately after the first attempt to create the table
+ retry=retry_if_exception(
+ lambda e: False if catch_partitions_limit and "TOO_MANY_OPEN_PARTITIONS" in str(e) else True
+ ),
stop=stop_after_attempt(self._retry_config.attempt),
wait=wait_exponential(
multiplier=self._retry_config.attempt,
@@ -231,15 +239,37 @@ def open(cls, connection: Connection) -> Connection:
@classmethod
def get_response(cls, cursor: AthenaCursor) -> AthenaAdapterResponse:
code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR"
+ rowcount, data_scanned_in_bytes = cls.process_query_stats(cursor)
return AthenaAdapterResponse(
- _message=f"{code} {cursor.rowcount}",
- rows_affected=cursor.rowcount,
+ _message=f"{code} {rowcount}",
+ rows_affected=rowcount,
code=code,
- data_scanned_in_bytes=cursor.data_scanned_in_bytes,
+ data_scanned_in_bytes=data_scanned_in_bytes,
)
+ @staticmethod
+ def process_query_stats(cursor: AthenaCursor) -> Tuple[int, int]:
+ """
+ Helper function to parse query statistics from SELECT statements.
+        The function checks whether the query mentions both rowcount and data_scanned_in_bytes,
+        strips everything up to the last SELECT, and parses the JSON payload between curly brackets.
+ """
+ if all(map(cursor.query.__contains__, ["rowcount", "data_scanned_in_bytes"])):
+ try:
+ query_split = cursor.query.lower().split("select")[-1]
+ # query statistics are in the format {"rowcount":1, "data_scanned_in_bytes": 3}
+ # the following statement extract the content between { and }
+ query_stats = re.search("{(.*)}", query_split)
+ if query_stats:
+ stats = json.loads("{" + query_stats.group(1) + "}")
+ return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0)
+ except Exception as err:
+ logger.debug(f"There was an error parsing query stats {err}")
+ return -1, 0
+ return cursor.rowcount, cursor.data_scanned_in_bytes
+
def cancel(self, connection: Connection) -> None:
- connection.handle.cancel()
+ pass
def add_begin_query(self) -> None:
pass
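
A note on the `process_query_stats` helper added above: the batched insert/merge macros finish by running a query of the form `SELECT '{"rowcount":...,"data_scanned_in_bytes":...}'`, and the helper recovers those two numbers from the executed query text (falling back to the cursor's own counters otherwise). A minimal standalone sketch of the same parsing idea — the example query string below is invented for illustration:

```python
import json
import re
from typing import Tuple


def parse_query_stats(query: str) -> Tuple[int, int]:
    """Recover rowcount / data_scanned_in_bytes from a stats SELECT.

    Mirrors the idea of process_query_stats: keep the text after the last
    SELECT, grab the {...} payload, and parse it as JSON.
    """
    if "rowcount" in query and "data_scanned_in_bytes" in query:
        try:
            tail = query.lower().split("select")[-1]
            match = re.search(r"{(.*)}", tail)
            if match:
                stats = json.loads("{" + match.group(1) + "}")
                return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0)
        except Exception:
            pass
    return -1, 0


# Invented example of the query text produced by the macros
print(parse_query_stats("""SELECT '{"rowcount":212,"data_scanned_in_bytes":1024}'"""))  # (212, 1024)
```
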
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index fbc2491a..253e8d0c 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -1,6 +1,7 @@
import csv
import os
import posixpath as path
+import re
import tempfile
from itertools import chain
from textwrap import dedent
@@ -19,6 +20,7 @@
TableTypeDef,
TableVersionTypeDef,
)
+from pyathena.error import OperationalError
from dbt.adapters.athena import AthenaConnectionManager
from dbt.adapters.athena.column import AthenaColumn
@@ -42,7 +44,13 @@
get_table_type,
)
from dbt.adapters.athena.s3 import S3DataNaming
-from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id, get_chunks
+from dbt.adapters.athena.utils import (
+ AthenaCatalogType,
+ clean_sql_comment,
+ get_catalog_id,
+ get_catalog_type,
+ get_chunks,
+)
from dbt.adapters.base import ConstraintSupport, available
from dbt.adapters.base.relation import BaseRelation, InformationSchema
from dbt.adapters.sql import SQLAdapter
@@ -413,6 +421,29 @@ def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List
for idx, col in enumerate(table["StorageDescriptor"]["Columns"] + table.get("PartitionKeys", []))
]
+ def _get_one_table_for_non_glue_catalog(
+ self, table: TableTypeDef, schema: str, database: str
+ ) -> List[Dict[str, Any]]:
+ table_catalog = {
+ "table_database": database,
+ "table_schema": schema,
+ "table_name": table["Name"],
+ "table_type": RELATION_TYPE_MAP[table.get("TableType", "EXTERNAL_TABLE")].value,
+ "table_comment": table.get("Parameters", {}).get("comment", ""),
+ }
+ return [
+ {
+ **table_catalog,
+ **{
+ "column_name": col["Name"],
+ "column_index": idx,
+ "column_type": col["Type"],
+ "column_comment": col.get("Comment", ""),
+ },
+ }
+ for idx, col in enumerate(table["Columns"] + table.get("PartitionKeys", []))
+ ]
+
def _get_one_catalog(
self,
information_schema: InformationSchema,
@@ -420,29 +451,55 @@ def _get_one_catalog(
manifest: Manifest,
) -> agate.Table:
data_catalog = self._get_data_catalog(information_schema.path.database)
- catalog_id = get_catalog_id(data_catalog)
+ data_catalog_type = get_catalog_type(data_catalog)
+
conn = self.connections.get_thread_connection()
client = conn.handle
- with boto3_client_lock:
- glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
-
- catalog = []
- paginator = glue_client.get_paginator("get_tables")
- for schema, relations in schemas.items():
- kwargs = {
- "DatabaseName": schema,
- "MaxResults": 100,
- }
- # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id.
- if catalog_id:
- kwargs["CatalogId"] = catalog_id
+ if data_catalog_type == AthenaCatalogType.GLUE:
+ with boto3_client_lock:
+ glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())
+
+ catalog = []
+ paginator = glue_client.get_paginator("get_tables")
+ for schema, relations in schemas.items():
+ kwargs = {
+ "DatabaseName": schema,
+ "MaxResults": 100,
+ }
+ # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3
+ # infers it from the account Id.
+ catalog_id = get_catalog_id(data_catalog)
+ if catalog_id:
+ kwargs["CatalogId"] = catalog_id
+
+ for page in paginator.paginate(**kwargs):
+ for table in page["TableList"]:
+ if relations and table["Name"] in relations:
+ catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
+ table = agate.Table.from_object(catalog)
+ else:
+ with boto3_client_lock:
+ athena_client = client.session.client(
+ "athena", region_name=client.region_name, config=get_boto3_config()
+ )
- for page in paginator.paginate(**kwargs):
- for table in page["TableList"]:
- if relations and table["Name"] in relations:
- catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database))
+ catalog = []
+ paginator = athena_client.get_paginator("list_table_metadata")
+ for schema, relations in schemas.items():
+ for page in paginator.paginate(
+ CatalogName=information_schema.path.database,
+ DatabaseName=schema,
+ MaxResults=50, # Limit supported by this operation
+ ):
+ for table in page["TableMetadataList"]:
+ if relations and table["Name"] in relations:
+ catalog.extend(
+ self._get_one_table_for_non_glue_catalog(
+ table, schema, information_schema.path.database
+ )
+ )
+ table = agate.Table.from_object(catalog)
- table = agate.Table.from_object(catalog)
filtered_table = self._catalog_filter_table(table, manifest)
return self._join_catalog_table_owners(filtered_table, manifest)
@@ -912,3 +969,26 @@ def _get_table_input(table: TableTypeDef) -> TableInputTypeDef:
returned by get_table() method.
"""
return {k: v for k, v in table.items() if k in TableInputTypeDef.__annotations__}
+
+ @available
+ def run_query_with_partitions_limit_catching(self, sql: str) -> str:
+ conn = self.connections.get_thread_connection()
+ cursor = conn.handle.cursor()
+ try:
+ cursor.execute(sql, catch_partitions_limit=True)
+ except OperationalError as e:
+ LOGGER.debug(f"CAUGHT EXCEPTION: {e}")
+ if "TOO_MANY_OPEN_PARTITIONS" in str(e):
+ return "TOO_MANY_OPEN_PARTITIONS"
+ raise e
+ return f'{{"rowcount":{cursor.rowcount},"data_scanned_in_bytes":{cursor.data_scanned_in_bytes}}}'
+
+ @available
+ def format_partition_keys(self, partition_keys: List[str]) -> str:
+ return ", ".join([self.format_one_partition_key(k) for k in partition_keys])
+
+ @available
+ def format_one_partition_key(self, partition_key: str) -> str:
+ """Check if partition key uses Iceberg hidden partitioning"""
+ hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower())
+ return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower()
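
For the Iceberg hidden-partitioning handling in `format_one_partition_key` above: the regex rewrites transform expressions such as `DAY(date_column)` into the `date_trunc` form used when selecting distinct partition values, while plain columns are only lower-cased. A small self-contained sketch — the function body is lifted from the change, the example keys come from the partition tests:

```python
import re


def format_one_partition_key(partition_key: str) -> str:
    """Rewrite Iceberg hidden-partitioning transforms (hour/day/month/year)
    into the equivalent date_trunc expression; plain columns pass through."""
    hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower())
    return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower()


print(format_one_partition_key("DAY(date_column)"))  # date_trunc('day', date_column)
print(format_one_partition_key("doy"))               # doy
```
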
diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py
index bb07393f..680f9e09 100644
--- a/dbt/adapters/athena/relation.py
+++ b/dbt/adapters/athena/relation.py
@@ -79,6 +79,7 @@ def add(self, relation: AthenaRelation) -> None:
RELATION_TYPE_MAP = {
"EXTERNAL_TABLE": TableType.TABLE,
+ "EXTERNAL": TableType.TABLE, # type returned by federated query tables
"MANAGED_TABLE": TableType.TABLE,
"VIRTUAL_VIEW": TableType.VIEW,
"table": TableType.TABLE,
diff --git a/dbt/adapters/athena/utils.py b/dbt/adapters/athena/utils.py
index 778fb4c2..dcd74916 100644
--- a/dbt/adapters/athena/utils.py
+++ b/dbt/adapters/athena/utils.py
@@ -1,3 +1,4 @@
+from enum import Enum
from typing import Generator, List, Optional, TypeVar
from mypy_boto3_athena.type_defs import DataCatalogTypeDef
@@ -9,7 +10,17 @@ def clean_sql_comment(comment: str) -> str:
def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]:
- return catalog["Parameters"]["catalog-id"] if catalog else None
+ return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None
+
+
+class AthenaCatalogType(Enum):
+ GLUE = "GLUE"
+ LAMBDA = "LAMBDA"
+ HIVE = "HIVE"
+
+
+def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]:
+ return AthenaCatalogType(catalog["Type"]) if catalog else None
T = TypeVar("T")
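
To make the two catalog helpers above concrete: `get_catalog_id` now only yields a catalog id for Glue catalogs, and `get_catalog_type` maps Athena's `Type` field onto the new enum. A rough sketch with hand-written catalog dicts — the shapes are assumed from the unit tests, not taken from a live `get_data_catalog` response:

```python
from dbt.adapters.athena.utils import (
    AthenaCatalogType,
    get_catalog_id,
    get_catalog_type,
)

# Hand-written stand-ins for the DataCatalogTypeDef shape these helpers expect
glue_catalog = {"Name": "awsdatacatalog", "Type": "GLUE", "Parameters": {"catalog-id": "12345678910"}}
lambda_catalog = {"Name": "federated_query_data_source", "Type": "LAMBDA", "Parameters": {"catalog": "federated_query_data_source"}}

assert get_catalog_id(glue_catalog) == "12345678910"
assert get_catalog_id(lambda_catalog) is None  # non-Glue catalogs carry no Glue catalog id
assert get_catalog_type(lambda_catalog) is AthenaCatalogType.LAMBDA
assert get_catalog_type(None) is None  # no catalog resolved at all
```
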
diff --git a/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql
new file mode 100644
index 00000000..20773841
--- /dev/null
+++ b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql
@@ -0,0 +1,52 @@
+{% macro get_partition_batches(sql) -%}
+ {%- set partitioned_by = config.get('partitioned_by') -%}
+ {%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%}
+ {%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%}
+ {% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %}
+
+ {% call statement('get_partitions', fetch_result=True) %}
+ select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }};
+ {% endcall %}
+
+ {%- set table = load_result('get_partitions').table -%}
+ {%- set rows = table.rows -%}
+ {%- set partitions = {} -%}
+ {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %}
+ {%- set partitions_batches = [] -%}
+
+ {%- for row in rows -%}
+ {%- set single_partition = [] -%}
+ {%- for col in row -%}
+
+ {%- set column_type = adapter.convert_type(table, loop.index0) -%}
+ {%- if column_type == 'integer' -%}
+ {%- set value = col | string -%}
+ {%- elif column_type == 'string' -%}
+ {%- set value = "'" + col + "'" -%}
+ {%- elif column_type == 'date' -%}
+ {%- set value = "DATE'" + col | string + "'" -%}
+ {%- elif column_type == 'timestamp' -%}
+ {%- set value = "TIMESTAMP'" + col | string + "'" -%}
+ {%- else -%}
+ {%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
+ {%- endif -%}
+ {%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%}
+ {%- do single_partition.append(partition_key + '=' + value) -%}
+ {%- endfor -%}
+
+ {%- set single_partition_expression = single_partition | join(' and ') -%}
+
+ {%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%}
+ {% if not batch_number in partitions %}
+ {% do partitions.update({batch_number: []}) %}
+ {% endif %}
+
+ {%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%}
+ {%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%}
+ {%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%}
+ {%- endif -%}
+ {%- endfor -%}
+
+ {{ return(partitions_batches) }}
+
+{%- endmacro %}
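
The batching step of the `get_partition_batches` macro above is easier to follow as plain code: each distinct partition tuple becomes an `and`-joined predicate, and the predicates are grouped into `or`-joined batches no larger than `partitions_limit`. A rough Python equivalent of just that step — the column names and values are invented, and per-type value quoting is assumed to have happened already:

```python
from typing import List, Sequence, Tuple


def build_partition_batches(
    partition_keys: Sequence[str],
    rows: Sequence[Tuple[str, ...]],
    limit: int = 100,  # mirrors the partitions_limit default in the macro
) -> List[str]:
    """Group distinct partition value tuples into OR-joined WHERE clauses,
    each covering at most `limit` partitions."""
    batches: List[str] = []
    current: List[str] = []
    for row in rows:
        predicate = " and ".join(f"{key}={value}" for key, value in zip(partition_keys, row))
        current.append(f"({predicate})")
        if len(current) == limit:
            batches.append(" or ".join(current))
            current = []
    if current:
        batches.append(" or ".join(current))
    return batches


# Invented example: three partitions with a batch size of two -> two WHERE clauses
print(build_partition_batches(
    ["date_column", "doy"],
    [("DATE'2023-01-01'", "1"), ("DATE'2023-01-02'", "2"), ("DATE'2023-01-03'", "3")],
    limit=2,
))
```
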
diff --git a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql
index a965e85c..76c2ed73 100644
--- a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql
+++ b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql
@@ -22,19 +22,42 @@
{% endmacro %}
{% macro incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation, statement_name="main") %}
- {% set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) %}
- {% if not dest_columns %}
+ {%- set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%}
+ {%- if not dest_columns -%}
{%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%}
- {% endif %}
+ {%- endif -%}
{%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%}
- insert into {{ target_relation }} ({{ dest_cols_csv }})
- (
- select {{ dest_cols_csv }}
- from {{ tmp_relation }}
- );
+ {%- set insert_full -%}
+ insert into {{ target_relation }} ({{ dest_cols_csv }})
+ (
+ select {{ dest_cols_csv }}
+ from {{ tmp_relation }}
+ );
+ {%- endset -%}
+
+ {%- set query_result = adapter.run_query_with_partitions_limit_catching(insert_full) -%}
+ {%- do log('QUERY RESULT: ' ~ query_result) -%}
+ {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%}
+ {% set partitions_batches = get_partition_batches(tmp_relation) %}
+ {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %}
+ {%- for batch in partitions_batches -%}
+ {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches|length) -%}
+ {%- set insert_batch_partitions -%}
+ insert into {{ target_relation }} ({{ dest_cols_csv }})
+ (
+ select {{ dest_cols_csv }}
+ from {{ tmp_relation }}
+ where {{ batch }}
+ );
+ {%- endset -%}
+ {%- do run_query(insert_batch_partitions) -%}
+ {%- endfor -%}
+ {%- endif -%}
+  SELECT '{{ query_result }}'
{%- endmacro %}
+
{% macro delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %}
{%- set partitioned_keys = partitioned_by | tojson | replace('\"', '') | replace('[', '') | replace(']', '') -%}
{% call statement('get_partitions', fetch_result=True) %}
diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql
index 879aa335..b42c7cb9 100644
--- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql
+++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql
@@ -7,7 +7,7 @@
{% set lf_tags_config = config.get('lf_tags_config') %}
{% set lf_grants = config.get('lf_grants') %}
- {% set partitioned_by = config.get('partitioned_by', default=none) %}
+ {% set partitioned_by = config.get('partitioned_by') %}
{% set target_relation = this.incorporate(type='table') %}
{% set existing_relation = load_relation(this) %}
{% set tmp_relation = make_temp_relation(this) %}
@@ -24,16 +24,18 @@
{% set to_drop = [] %}
{% if existing_relation is none %}
- {% set build_sql = create_table_as(False, target_relation, sql) -%}
+ {% set query_result = safe_create_table_as(False, target_relation, sql) -%}
+ {% set build_sql = "select '{{ query_result }}'" -%}
{% elif existing_relation.is_view or should_full_refresh() %}
{% do drop_relation(existing_relation) %}
- {% set build_sql = create_table_as(False, target_relation, sql) -%}
+ {% set query_result = safe_create_table_as(False, target_relation, sql) -%}
+ {% set build_sql = "select '{{ query_result }}'" -%}
{% elif partitioned_by is not none and strategy == 'insert_overwrite' %}
{% set tmp_relation = make_temp_relation(target_relation) %}
{% if tmp_relation is not none %}
{% do drop_relation(tmp_relation) %}
{% endif %}
- {% do run_query(create_table_as(True, tmp_relation, sql)) %}
+ {% set query_result = safe_create_table_as(True, tmp_relation, sql) -%}
{% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %}
{% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %}
{% do to_drop.append(tmp_relation) %}
@@ -42,7 +44,7 @@
{% if tmp_relation is not none %}
{% do drop_relation(tmp_relation) %}
{% endif %}
- {% do run_query(create_table_as(True, tmp_relation, sql)) %}
+ {% set query_result = safe_create_table_as(True, tmp_relation, sql) -%}
{% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %}
{% do to_drop.append(tmp_relation) %}
{% elif strategy == 'merge' and table_type == 'iceberg' %}
@@ -67,7 +69,7 @@
{% if tmp_relation is not none %}
{% do drop_relation(tmp_relation) %}
{% endif %}
- {% do run_query(create_table_as(True, tmp_relation, sql)) %}
+ {% set query_result = safe_create_table_as(True, tmp_relation, sql) -%}
{% set build_sql = iceberg_merge(on_schema_change, tmp_relation, target_relation, unique_key, incremental_predicates, existing_relation, delete_condition) %}
{% do to_drop.append(tmp_relation) %}
{% endif %}
diff --git a/dbt/include/athena/macros/materializations/models/incremental/merge.sql b/dbt/include/athena/macros/materializations/models/incremental/merge.sql
index cd06cb03..07d8b886 100644
--- a/dbt/include/athena/macros/materializations/models/incremental/merge.sql
+++ b/dbt/include/athena/macros/materializations/models/incremental/merge.sql
@@ -73,33 +73,66 @@
{%- endfor -%}
{%- set update_columns = get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns_wo_keys) -%}
{%- set src_cols_csv = src_columns_quoted | join(', ') -%}
- merge into {{ target_relation }} as target using {{ tmp_relation }} as src
- on (
- {%- for key in unique_key_cols %}
- target.{{ key }} = src.{{ key }} {{ "and " if not loop.last }}
- {%- endfor %}
- )
- {% if incremental_predicates is not none -%}
- and (
- {%- for inc_predicate in incremental_predicates %}
- {{ inc_predicate }} {{ "and " if not loop.last }}
- {%- endfor %}
- )
- {%- endif %}
- {% if delete_condition is not none -%}
- when matched and ({{ delete_condition }})
- then delete
- {%- endif %}
- when matched
- then update set
- {%- for col in update_columns %}
- {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %}
- {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }}
- {%- else -%}
- {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }}
- {%- endif -%}
- {%- endfor %}
- when not matched
- then insert ({{ dest_cols_csv }})
- values ({{ src_cols_csv }});
+
+ {%- set src_part -%}
+ merge into {{ target_relation }} as target using {{ tmp_relation }} as src
+ {%- endset -%}
+
+ {%- set merge_part -%}
+ on (
+ {%- for key in unique_key_cols -%}
+ target.{{ key }} = src.{{ key }}
+ {{ " and " if not loop.last }}
+ {%- endfor -%}
+ {% if incremental_predicates is not none -%}
+ and (
+ {%- for inc_predicate in incremental_predicates %}
+ {{ inc_predicate }} {{ "and " if not loop.last }}
+ {%- endfor %}
+ )
+ {%- endif %}
+ )
+ {% if delete_condition is not none -%}
+ when matched and ({{ delete_condition }})
+ then delete
+ {%- endif %}
+ when matched
+ then update set
+ {%- for col in update_columns %}
+ {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %}
+ {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }}
+ {%- else -%}
+ {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }}
+ {%- endif -%}
+ {%- endfor %}
+ when not matched
+ then insert ({{ dest_cols_csv }})
+ values ({{ src_cols_csv }})
+ {%- endset -%}
+
+ {%- set merge_full -%}
+ {{ src_part }}
+ {{ merge_part }}
+ {%- endset -%}
+
+ {%- set query_result = adapter.run_query_with_partitions_limit_catching(merge_full) -%}
+ {%- do log('QUERY RESULT: ' ~ query_result) -%}
+ {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%}
+ {% set partitions_batches = get_partition_batches(tmp_relation) %}
+ {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %}
+ {%- for batch in partitions_batches -%}
+ {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%}
+ {%- set src_batch_part -%}
+ merge into {{ target_relation }} as target
+ using (select * from {{ tmp_relation }} where {{ batch }}) as src
+ {%- endset -%}
+ {%- set merge_batch -%}
+ {{ src_batch_part }}
+ {{ merge_part }}
+ {%- endset -%}
+ {%- do run_query(merge_batch) -%}
+ {%- endfor -%}
+ {%- endif -%}
+
+  SELECT '{{ query_result }}'
{%- endmacro %}
diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql
index af730a30..878a0378 100644
--- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql
+++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql
@@ -87,3 +87,44 @@
as
{{ sql }}
{% endmacro %}
+
+{% macro create_table_as_with_partitions(temporary, relation, sql) -%}
+
+ {% set partitions_batches = get_partition_batches(sql) %}
+ {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %}
+
+ {%- do log('CREATE EMPTY TABLE: ' ~ relation) -%}
+ {%- set create_empty_table_query -%}
+ {{ create_table_as(temporary, relation, sql) }}
+ limit 0
+ {%- endset -%}
+ {%- do run_query(create_empty_table_query) -%}
+ {%- set dest_columns = adapter.get_columns_in_relation(relation) -%}
+ {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%}
+
+ {%- for batch in partitions_batches -%}
+ {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%}
+
+ {%- set insert_batch_partitions -%}
+ insert into {{ relation }} ({{ dest_cols_csv }})
+ select {{ dest_cols_csv }}
+ from ({{ sql }})
+ where {{ batch }}
+ {%- endset -%}
+
+ {%- do run_query(insert_batch_partitions) -%}
+ {%- endfor -%}
+
+ select 'SUCCESSFULLY CREATED TABLE {{ relation }}'
+
+{%- endmacro %}
+
+{% macro safe_create_table_as(temporary, relation, sql) -%}
+ {%- set query_result = adapter.run_query_with_partitions_limit_catching(create_table_as(temporary, relation, sql)) -%}
+ {%- do log('QUERY RESULT: ' ~ query_result) -%}
+ {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%}
+ {%- do create_table_as_with_partitions(temporary, relation, sql) -%}
+ {%- set query_result = '{{ relation }} with many partitions created' -%}
+ {%- endif -%}
+ {{ return(query_result) }}
+{%- endmacro %}
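
Restating the control flow of `safe_create_table_as` and `create_table_as_with_partitions` above as a hedged Python-style sketch — the callables below stand in for the macro calls and are not real adapter APIs:

```python
from typing import Callable, List

TOO_MANY_OPEN_PARTITIONS = "TOO_MANY_OPEN_PARTITIONS"


def safe_create_table(
    run_ctas: Callable[[], str],
    create_empty_table: Callable[[], None],
    partition_batches: Callable[[], List[str]],
    insert_batch: Callable[[str], None],
) -> str:
    """Try a single CTAS first; only if Athena rejects it for opening too many
    partitions, create the table empty and load it one partition batch at a time."""
    result = run_ctas()
    if result == TOO_MANY_OPEN_PARTITIONS:
        create_empty_table()      # CTAS with `limit 0`, so the schema exists
        for batch in partition_batches():
            insert_batch(batch)   # INSERT ... WHERE <batch predicate>
        return "table with many partitions created"
    return result
```
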
diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql
index 7f9856e4..989bf63b 100644
--- a/dbt/include/athena/macros/materializations/models/table/table.sql
+++ b/dbt/include/athena/macros/materializations/models/table/table.sql
@@ -49,28 +49,22 @@
{%- endif -%}
-- create tmp table
- {% call statement('main') -%}
- {{ create_table_as(False, tmp_relation, sql) }}
- {%- endcall %}
+ {%- set query_result = safe_create_table_as(False, tmp_relation, sql) -%}
-- swap table
- {%- set swap_table = adapter.swap_table(tmp_relation,
- target_relation) -%}
+ {%- set swap_table = adapter.swap_table(tmp_relation, target_relation) -%}
-- delete glue tmp table, do not use drop_relation, as it will remove data of the target table
{%- do adapter.delete_from_glue_catalog(tmp_relation) -%}
- {% do adapter.expire_glue_table_versions(target_relation,
- versions_to_keep,
- True) %}
+ {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, True) %}
+
{%- else -%}
-- Here we are in the case of non-ha tables or ha tables but in case of full refresh.
{%- if old_relation is not none -%}
{{ drop_relation(old_relation) }}
{%- endif -%}
- {%- call statement('main') -%}
- {{ create_table_as(False, target_relation, sql) }}
- {%- endcall %}
+ {%- set query_result = safe_create_table_as(False, target_relation, sql) -%}
{%- endif -%}
{{ set_table_classification(target_relation) }}
@@ -78,14 +72,10 @@
{%- else -%}
{%- if old_relation is none -%}
- {%- call statement('main') -%}
- {{ create_table_as(False, target_relation, sql) }}
- {%- endcall %}
+ {%- set query_result = safe_create_table_as(False, target_relation, sql) -%}
{%- else -%}
{%- if old_relation.is_view -%}
- {%- call statement('main') -%}
- {{ create_table_as(False, tmp_relation, sql) }}
- {%- endcall -%}
+ {%- set query_result = safe_create_table_as(False, tmp_relation, sql) -%}
{%- do drop_relation(old_relation) -%}
{%- do rename_relation(tmp_relation, target_relation) -%}
{%- else -%}
@@ -103,9 +93,7 @@
{%- do drop_relation(old_relation_bkp) -%}
{%- endif -%}
- {%- call statement('main') -%}
- {{ create_table_as(False, tmp_relation, sql) }}
- {%- endcall -%}
+ {% set query_result = safe_create_table_as(False, tmp_relation, sql) %}
{{ rename_relation(old_relation, old_relation_bkp) }}
{{ rename_relation(tmp_relation, target_relation) }}
@@ -116,6 +104,10 @@
{%- endif -%}
+ {% call statement("main") %}
+ SELECT '{{ query_result }}';
+ {% endcall %}
+
{{ run_hooks(post_hooks) }}
{% if lf_tags_config is not none %}
diff --git a/dbt/include/athena/macros/utils/ddl_dml_data_type.sql b/dbt/include/athena/macros/utils/ddl_dml_data_type.sql
index ec726627..ddd0fcb0 100644
--- a/dbt/include/athena/macros/utils/ddl_dml_data_type.sql
+++ b/dbt/include/athena/macros/utils/ddl_dml_data_type.sql
@@ -18,8 +18,14 @@
{%- endif -%}
-- transform timestamp
- {%- if table_type == 'iceberg' and 'timestamp' in data_type -%}
- {% set data_type = 'timestamp' -%}
+ {%- if table_type == 'iceberg' -%}
+ {%- if 'timestamp' in data_type -%}
+ {% set data_type = 'timestamp' -%}
+ {%- endif -%}
+
+ {%- if 'binary' in data_type -%}
+ {% set data_type = 'binary' -%}
+ {%- endif -%}
{%- endif -%}
{{ return(data_type) }}
diff --git a/dev-requirements.txt b/dev-requirements.txt
index 3dfa7e77..41e3249f 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,14 +1,14 @@
autoflake~=1.7
black~=23.3
-dbt-tests-adapter~=1.5.4
-flake8~=5.0
+dbt-tests-adapter~=1.6.1
+flake8~=6.1
Flake8-pyproject~=1.2
isort~=5.11
-moto~=4.1.14
+moto~=4.2.0
pre-commit~=2.21
-pyparsing~=3.1.0
+pyparsing~=3.1.1
pytest~=7.4
pytest-cov~=4.1
pytest-dotenv~=0.5
pytest-xdist~=3.3
-pyupgrade~=3.3
+pyupgrade~=3.10
diff --git a/setup.py b/setup.py
index dae9a695..bfaacbfa 100644
--- a/setup.py
+++ b/setup.py
@@ -31,7 +31,7 @@ def _get_package_version() -> str:
return f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}'
-dbt_version = "1.5"
+dbt_version = "1.6"
package_version = _get_package_version()
description = "The athena adapter plugin for dbt (data build tool)"
@@ -55,9 +55,9 @@ def _get_package_version() -> str:
# In order to control dbt-core version and package version
"boto3~=1.26",
"boto3-stubs[athena,glue,lakeformation,sts]~=1.26",
- "dbt-core~=1.5.0",
+ "dbt-core~=1.6.0",
"pyathena>=2.25,<4.0",
- "pydantic~=1.10",
+ "pydantic>=1.10,<3.0",
"tenacity~=8.2",
],
classifiers=[
@@ -66,10 +66,10 @@ def _get_package_version() -> str:
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS :: MacOS X",
"Operating System :: POSIX :: Linux",
- "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
+ python_requires=">=3.8",
)
diff --git a/tests/conftest.py b/tests/conftest.py
index 827b862b..5f5f7f4c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@
from dbt.events.eventmgr import LineFormat, NoFilter
from dbt.events.functions import EVENT_MANAGER, _get_stdout_config
-# Import the fuctional fixtures as a plugin
+# Import the functional fixtures as a plugin
# Note: fixtures with session scope need to be local
pytest_plugins = ["dbt.tests.fixtures.project"]
@@ -42,11 +42,14 @@ def dbt_debug_caplog() -> StringIO:
def _setup_custom_caplog(name: str, level: EventLevel):
capture_config = _get_stdout_config(
- line_format=LineFormat.PlainText, level=level, use_colors=False, debug=True, log_cache_events=True, quiet=False
+ line_format=LineFormat.PlainText,
+ level=level,
+ use_colors=False,
+ log_cache_events=True,
)
capture_config.name = name
capture_config.filter = NoFilter
- stringbuf = StringIO()
- capture_config.output_stream = stringbuf
+ string_buf = StringIO()
+ capture_config.output_stream = string_buf
EVENT_MANAGER.add_logger(capture_config)
- return stringbuf
+ return string_buf
diff --git a/tests/functional/adapter/fixture_split_parts.py b/tests/functional/adapter/fixture_split_parts.py
new file mode 100644
index 00000000..c9ff3d7c
--- /dev/null
+++ b/tests/functional/adapter/fixture_split_parts.py
@@ -0,0 +1,39 @@
+models__test_split_part_sql = """
+with data as (
+
+ select * from {{ ref('data_split_part') }}
+
+)
+
+select
+ {{ split_part('parts', 'split_on', 1) }} as actual,
+ result_1 as expected
+
+from data
+
+union all
+
+select
+ {{ split_part('parts', 'split_on', 2) }} as actual,
+ result_2 as expected
+
+from data
+
+union all
+
+select
+ {{ split_part('parts', 'split_on', 3) }} as actual,
+ result_3 as expected
+
+from data
+"""
+
+models__test_split_part_yml = """
+version: 2
+models:
+ - name: test_split_part
+ tests:
+ - assert_equal:
+ actual: actual
+ expected: expected
+"""
diff --git a/tests/functional/adapter/test_partitions.py b/tests/functional/adapter/test_partitions.py
new file mode 100644
index 00000000..68e639e6
--- /dev/null
+++ b/tests/functional/adapter/test_partitions.py
@@ -0,0 +1,127 @@
+import pytest
+
+from dbt.contracts.results import RunStatus
+from dbt.tests.util import run_dbt
+
+# this query generates 212 records
+test_partitions_model_sql = """
+select
+ random() as rnd,
+ cast(date_column as date) as date_column,
+ doy(date_column) as doy
+from (
+ values (
+ sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day)
+ )
+) as t1(date_array)
+cross join unnest(date_array) as t2(date_column)
+"""
+
+
+class TestHiveTablePartitions:
+ @pytest.fixture(scope="class")
+ def project_config_update(self):
+ return {"models": {"+table_type": "hive", "+materialized": "table", "+partitioned_by": ["date_column", "doy"]}}
+
+ @pytest.fixture(scope="class")
+ def models(self):
+ return {
+ "test_hive_partitions.sql": test_partitions_model_sql,
+ }
+
+ def test__check_incremental_run_with_partitions(self, project):
+ relation_name = "test_hive_partitions"
+ model_run_result_row_count_query = "select count(*) as records from {}.{}".format(
+ project.test_schema, relation_name
+ )
+
+ first_model_run = run_dbt(["run", "--select", relation_name])
+ first_model_run_result = first_model_run.results[0]
+
+        # check that the model ran successfully
+ assert first_model_run_result.status == RunStatus.Success
+
+ records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+ assert records_count_first_run == 212
+
+
+class TestIcebergTablePartitions:
+ @pytest.fixture(scope="class")
+ def project_config_update(self):
+ return {
+ "models": {
+ "+table_type": "iceberg",
+ "+materialized": "table",
+ "+partitioned_by": ["DAY(date_column)", "doy"],
+ }
+ }
+
+ @pytest.fixture(scope="class")
+ def models(self):
+ return {
+ "test_iceberg_partitions.sql": test_partitions_model_sql,
+ }
+
+ def test__check_incremental_run_with_partitions(self, project):
+ relation_name = "test_iceberg_partitions"
+ model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+ first_model_run = run_dbt(["run", "--select", relation_name])
+ first_model_run_result = first_model_run.results[0]
+
+        # check that the model ran successfully
+ assert first_model_run_result.status == RunStatus.Success
+
+ records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+ assert records_count_first_run == 212
+
+
+class TestIcebergIncrementalPartitions:
+ @pytest.fixture(scope="class")
+ def project_config_update(self):
+ return {
+ "models": {
+ "+table_type": "iceberg",
+ "+materialized": "incremental",
+ "+incremental_strategy": "merge",
+ "+unique_key": "doy",
+ "+partitioned_by": ["DAY(date_column)", "doy"],
+ }
+ }
+
+ @pytest.fixture(scope="class")
+ def models(self):
+ return {
+ "test_iceberg_partitions_incremental.sql": test_partitions_model_sql,
+ }
+
+ def test__check_incremental_run_with_partitions(self, project):
+ """
+ Check that the incremental run works with iceberg and partitioned datasets
+ """
+
+ relation_name = "test_iceberg_partitions_incremental"
+ model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+
+ first_model_run = run_dbt(["run", "--select", relation_name, "--full-refresh"])
+ first_model_run_result = first_model_run.results[0]
+
+        # check that the model ran successfully
+ assert first_model_run_result.status == RunStatus.Success
+
+ records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+ assert records_count_first_run == 212
+
+ incremental_model_run = run_dbt(["run", "--select", relation_name])
+
+ incremental_model_run_result = incremental_model_run.results[0]
+
+        # check that the model ran successfully after the incremental run
+ assert incremental_model_run_result.status == RunStatus.Success
+
+ incremental_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+
+ assert incremental_records_count == 212
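
A quick sanity check on the "this query generates 212 records" comment above: the fixture unnests one row per day from 2023-01-01 through 2023-07-31 inclusive, i.e. 31 + 28 + 31 + 30 + 31 + 30 + 31 = 212 days, which is what the assertions expect:

```python
from datetime import date

# One row per day in [2023-01-01, 2023-07-31], endpoints included
assert (date(2023, 7, 31) - date(2023, 1, 1)).days + 1 == 212
```
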
diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py
index b998b6c7..83f743e9 100644
--- a/tests/functional/adapter/utils/test_utils.py
+++ b/tests/functional/adapter/utils/test_utils.py
@@ -3,6 +3,10 @@
models__test_datediff_sql,
seeds__data_datediff_csv,
)
+from tests.functional.adapter.fixture_split_parts import (
+ models__test_split_part_sql,
+ models__test_split_part_yml,
+)
from dbt.tests.adapter.utils.fixture_datediff import models__test_datediff_yml
from dbt.tests.adapter.utils.test_any_value import BaseAnyValue
@@ -100,7 +104,12 @@ class TestRight(BaseRight):
class TestSplitPart(BaseSplitPart):
- pass
+ @pytest.fixture(scope="class")
+ def models(self):
+ return {
+ "test_split_part.yml": models__test_split_part_yml,
+ "test_split_part.sql": self.interpolate_macro_namespace(models__test_split_part_sql, "split_part"),
+ }
class TestStringLiteral(BaseStringLiteral):
diff --git a/tests/unit/constants.py b/tests/unit/constants.py
index 0513ab69..8fe9aa82 100644
--- a/tests/unit/constants.py
+++ b/tests/unit/constants.py
@@ -1,6 +1,7 @@
CATALOG_ID = "12345678910"
DATA_CATALOG_NAME = "awsdatacatalog"
SHARED_DATA_CATALOG_NAME = "9876543210"
+FEDERATED_QUERY_CATALOG_NAME = "federated_query_data_source"
DATABASE_NAME = "test_dbt_athena"
BUCKET = "test-dbt-athena"
AWS_REGION = "eu-west-1"
diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py
index 30d39e4b..f14ae523 100644
--- a/tests/unit/test_adapter.py
+++ b/tests/unit/test_adapter.py
@@ -4,7 +4,10 @@
import agate
import boto3
+import botocore
import pytest
+
+# from botocore.client.BaseClient import _make_api_call
from moto import mock_athena, mock_glue, mock_s3, mock_sts
from moto.core import DEFAULT_ACCOUNT_ID
@@ -14,6 +17,7 @@
from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter
from dbt.adapters.athena.exceptions import S3LocationException
from dbt.adapters.athena.relation import AthenaRelation, TableType
+from dbt.adapters.athena.utils import AthenaCatalogType
from dbt.clients import agate_helper
from dbt.contracts.connection import ConnectionState
from dbt.contracts.files import FileHash
@@ -28,6 +32,7 @@
BUCKET,
DATA_CATALOG_NAME,
DATABASE_NAME,
+ FEDERATED_QUERY_CATALOG_NAME,
S3_STAGING_DIR,
SHARED_DATA_CATALOG_NAME,
)
@@ -66,6 +71,7 @@ def setup_method(self, _):
("awsdatacatalog", "quux"),
("awsdatacatalog", "baz"),
(SHARED_DATA_CATALOG_NAME, "foo"),
+ (FEDERATED_QUERY_CATALOG_NAME, "foo"),
}
self.mock_manifest.nodes = {
"model.root.model1": CompiledNode(
@@ -212,6 +218,42 @@ def setup_method(self, _):
raw_code="select * from source_table",
language="",
),
+ "model.root.model5": CompiledNode(
+ name="model5",
+ database=FEDERATED_QUERY_CATALOG_NAME,
+ schema="foo",
+ resource_type=NodeType.Model,
+ unique_id="model.root.model5",
+ alias="bar",
+ fqn=["root", "model5"],
+ package_name="root",
+ refs=[],
+ sources=[],
+ depends_on=DependsOn(),
+ config=NodeConfig.from_dict(
+ {
+ "enabled": True,
+ "materialized": "table",
+ "persist_docs": {},
+ "post-hook": [],
+ "pre-hook": [],
+ "vars": {},
+ "meta": {"owner": "data-engineers"},
+ "quoting": {},
+ "column_types": {},
+ "tags": [],
+ }
+ ),
+ tags=[],
+ path="model5.sql",
+ original_file_path="model5.sql",
+ compiled=True,
+ extra_ctes_injected=False,
+ extra_ctes=[],
+ checksum=FileHash.from_contents(""),
+ raw_code="select * from source_table",
+ language="",
+ ),
}
@property
@@ -612,9 +654,84 @@ def test__get_one_catalog_shared_catalog(self, mock_aws_service):
for row in actual.rows.values():
assert row.values() in expected_rows
+ @mock_athena
+ def test__get_one_catalog_federated_query_catalog(self, mock_aws_service):
+ mock_aws_service.create_data_catalog(
+ catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA
+ )
+ mock_information_schema = mock.MagicMock()
+ mock_information_schema.path.database = FEDERATED_QUERY_CATALOG_NAME
+
+ # Original botocore _make_api_call function
+ orig = botocore.client.BaseClient._make_api_call
+
+            # Mocking this because list_table_metadata and creating non-Glue tables are not supported by moto.
+ # Followed this guide: http://docs.getmoto.org/en/latest/docs/services/patching_other_services.html
+ def mock_athena_list_table_metadata(self, operation_name, kwarg):
+ if operation_name == "ListTableMetadata":
+ return {
+ "TableMetadataList": [
+ {
+ "Name": "bar",
+ "TableType": "EXTERNAL_TABLE",
+ "Columns": [
+ {
+ "Name": "id",
+ "Type": "string",
+ },
+ {
+ "Name": "country",
+ "Type": "string",
+ },
+ ],
+ "PartitionKeys": [
+ {
+ "Name": "dt",
+ "Type": "date",
+ },
+ ],
+ }
+ ],
+ }
+ # If we don't want to patch the API call
+ return orig(self, operation_name, kwarg)
+
+ self.adapter.acquire_connection("dummy")
+ with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata):
+ actual = self.adapter._get_one_catalog(
+ mock_information_schema,
+ {
+ "foo": {"bar"},
+ },
+ self.mock_manifest,
+ )
+
+ expected_column_names = (
+ "table_database",
+ "table_schema",
+ "table_name",
+ "table_type",
+ "table_comment",
+ "column_name",
+ "column_index",
+ "column_type",
+ "column_comment",
+ "table_owner",
+ )
+ expected_rows = [
+ (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"),
+ (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"),
+ (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"),
+ ]
+
+ assert actual.column_names == expected_column_names
+ assert len(actual.rows) == len(expected_rows)
+ for row in actual.rows.values():
+ assert row.values() in expected_rows
+
def test__get_catalog_schemas(self):
res = self.adapter._get_catalog_schemas(self.mock_manifest)
- assert len(res.keys()) == 2
+ assert len(res.keys()) == 3
information_schema_0 = list(res.keys())[0]
assert information_schema_0.name == "INFORMATION_SCHEMA"
@@ -632,6 +749,14 @@ def test__get_catalog_schemas(self):
assert set(relations.keys()) == {"foo"}
assert list(relations.values()) == [{"bar"}]
+ information_schema_1 = list(res.keys())[2]
+ assert information_schema_1.name == "INFORMATION_SCHEMA"
+ assert information_schema_1.schema is None
+ assert information_schema_1.database == FEDERATED_QUERY_CATALOG_NAME
+        relations = list(res.values())[2]
+ assert set(relations.keys()) == {"foo"}
+ assert list(relations.values()) == [{"bar"}]
+
@mock_athena
@mock_sts
def test__get_data_catalog(self, mock_aws_service):
@@ -696,7 +821,7 @@ def test_list_relations_without_caching_with_non_glue_data_catalog(
self, parent_list_relations_without_caching, mock_aws_service
):
data_catalog_name = "other_data_catalog"
- mock_aws_service.create_data_catalog(data_catalog_name, "HIVE")
+ mock_aws_service.create_data_catalog(data_catalog_name, AthenaCatalogType.HIVE)
schema_relation = self.adapter.Relation.create(
database=data_catalog_name,
schema=DATABASE_NAME,
diff --git a/tests/unit/utils.py b/tests/unit/utils.py
index 60ec02ed..097f6ce9 100644
--- a/tests/unit/utils.py
+++ b/tests/unit/utils.py
@@ -5,6 +5,7 @@
import agate
import boto3
+from dbt.adapters.athena.utils import AthenaCatalogType
from dbt.config.project import PartialProject
from .constants import AWS_REGION, BUCKET, CATALOG_ID, DATA_CATALOG_NAME, DATABASE_NAME
@@ -146,10 +147,18 @@ def _make_table_of(self, rows, column_types):
class MockAWSService:
def create_data_catalog(
- self, catalog_name: str = DATA_CATALOG_NAME, catalog_type: str = "GLUE", catalog_id: str = CATALOG_ID
+ self,
+ catalog_name: str = DATA_CATALOG_NAME,
+ catalog_type: AthenaCatalogType = AthenaCatalogType.GLUE,
+ catalog_id: str = CATALOG_ID,
):
athena = boto3.client("athena", region_name=AWS_REGION)
- athena.create_data_catalog(Name=catalog_name, Type=catalog_type, Parameters={"catalog-id": catalog_id})
+ parameters = {}
+ if catalog_type == AthenaCatalogType.GLUE:
+ parameters = {"catalog-id": catalog_id}
+ else:
+ parameters = {"catalog": catalog_name}
+ athena.create_data_catalog(Name=catalog_name, Type=catalog_type.value, Parameters=parameters)
def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID):
glue = boto3.client("glue", region_name=AWS_REGION)