diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c6354a37..8cdee2f4 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko @thenaturalist +* @jessedobbelaere @Jrmyy @mattiamatrix @nicor88 @svdimchenko diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0e2fef2a..ab75579a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.8', '3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index ccc82e72..133290a6 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,19 @@

[badge/image lines at the top of the README updated; image markup not preserved]

## Features -* Supports dbt version `1.5.*` +* Supports dbt version `1.6.*` +* Supports Python versions from 3.8 to 3.11 * Supports [seeds][seeds] * Correctly detects views and their columns * Supports [table materialization][table] diff --git a/dbt/adapters/athena/__version__.py b/dbt/adapters/athena/__version__.py index 0c46db4f..38ec8ede 100644 --- a/dbt/adapters/athena/__version__.py +++ b/dbt/adapters/athena/__version__.py @@ -1 +1 @@ -version = "1.5.1" +version = "1.6.0" diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index 67e2e71e..8645f8bc 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -1,4 +1,6 @@ import hashlib +import json +import re import time from concurrent.futures.thread import ThreadPoolExecutor from contextlib import contextmanager @@ -132,6 +134,7 @@ def execute( # type: ignore endpoint_url: Optional[str] = None, cache_size: int = 0, cache_expiration_time: int = 0, + catch_partitions_limit: bool = False, **kwargs, ): def inner() -> AthenaCursor: @@ -158,7 +161,12 @@ def inner() -> AthenaCursor: return self retry = tenacity.Retrying( - retry=retry_if_exception(lambda _: True), + # No need to retry if TOO_MANY_OPEN_PARTITIONS occurs. + # Otherwise, Athena throws ICEBERG_FILESYSTEM_ERROR after retry, + # because not all files are removed immediately after the first attempt to create the table + retry=retry_if_exception( + lambda e: False if catch_partitions_limit and "TOO_MANY_OPEN_PARTITIONS" in str(e) else True + ), stop=stop_after_attempt(self._retry_config.attempt), wait=wait_exponential( multiplier=self._retry_config.attempt, @@ -231,15 +239,37 @@ def open(cls, connection: Connection) -> Connection: @classmethod def get_response(cls, cursor: AthenaCursor) -> AthenaAdapterResponse: code = "OK" if cursor.state == AthenaQueryExecution.STATE_SUCCEEDED else "ERROR" + rowcount, data_scanned_in_bytes = cls.process_query_stats(cursor) return AthenaAdapterResponse( - _message=f"{code} {cursor.rowcount}", - rows_affected=cursor.rowcount, + _message=f"{code} {rowcount}", + rows_affected=rowcount, code=code, - data_scanned_in_bytes=cursor.data_scanned_in_bytes, + data_scanned_in_bytes=data_scanned_in_bytes, ) + @staticmethod + def process_query_stats(cursor: AthenaCursor) -> Tuple[int, int]: + """ + Helper function to parse query statistics from SELECT statements. + The function looks for all statements that contain rowcount or data_scanned_in_bytes, + then strips the SELECT statement and picks the value between the curly brackets.
+ """ + if all(map(cursor.query.__contains__, ["rowcount", "data_scanned_in_bytes"])): + try: + query_split = cursor.query.lower().split("select")[-1] + # query statistics are in the format {"rowcount":1, "data_scanned_in_bytes": 3} + # the following statement extract the content between { and } + query_stats = re.search("{(.*)}", query_split) + if query_stats: + stats = json.loads("{" + query_stats.group(1) + "}") + return stats.get("rowcount", -1), stats.get("data_scanned_in_bytes", 0) + except Exception as err: + logger.debug(f"There was an error parsing query stats {err}") + return -1, 0 + return cursor.rowcount, cursor.data_scanned_in_bytes + def cancel(self, connection: Connection) -> None: - connection.handle.cancel() + pass def add_begin_query(self) -> None: pass diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index fbc2491a..253e8d0c 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -1,6 +1,7 @@ import csv import os import posixpath as path +import re import tempfile from itertools import chain from textwrap import dedent @@ -19,6 +20,7 @@ TableTypeDef, TableVersionTypeDef, ) +from pyathena.error import OperationalError from dbt.adapters.athena import AthenaConnectionManager from dbt.adapters.athena.column import AthenaColumn @@ -42,7 +44,13 @@ get_table_type, ) from dbt.adapters.athena.s3 import S3DataNaming -from dbt.adapters.athena.utils import clean_sql_comment, get_catalog_id, get_chunks +from dbt.adapters.athena.utils import ( + AthenaCatalogType, + clean_sql_comment, + get_catalog_id, + get_catalog_type, + get_chunks, +) from dbt.adapters.base import ConstraintSupport, available from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter @@ -413,6 +421,29 @@ def _get_one_table_for_catalog(self, table: TableTypeDef, database: str) -> List for idx, col in enumerate(table["StorageDescriptor"]["Columns"] + table.get("PartitionKeys", [])) ] + def _get_one_table_for_non_glue_catalog( + self, table: TableTypeDef, schema: str, database: str + ) -> List[Dict[str, Any]]: + table_catalog = { + "table_database": database, + "table_schema": schema, + "table_name": table["Name"], + "table_type": RELATION_TYPE_MAP[table.get("TableType", "EXTERNAL_TABLE")].value, + "table_comment": table.get("Parameters", {}).get("comment", ""), + } + return [ + { + **table_catalog, + **{ + "column_name": col["Name"], + "column_index": idx, + "column_type": col["Type"], + "column_comment": col.get("Comment", ""), + }, + } + for idx, col in enumerate(table["Columns"] + table.get("PartitionKeys", [])) + ] + def _get_one_catalog( self, information_schema: InformationSchema, @@ -420,29 +451,55 @@ def _get_one_catalog( manifest: Manifest, ) -> agate.Table: data_catalog = self._get_data_catalog(information_schema.path.database) - catalog_id = get_catalog_id(data_catalog) + data_catalog_type = get_catalog_type(data_catalog) + conn = self.connections.get_thread_connection() client = conn.handle - with boto3_client_lock: - glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) - - catalog = [] - paginator = glue_client.get_paginator("get_tables") - for schema, relations in schemas.items(): - kwargs = { - "DatabaseName": schema, - "MaxResults": 100, - } - # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id. 
- if catalog_id: - kwargs["CatalogId"] = catalog_id + if data_catalog_type == AthenaCatalogType.GLUE: + with boto3_client_lock: + glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) + + catalog = [] + paginator = glue_client.get_paginator("get_tables") + for schema, relations in schemas.items(): + kwargs = { + "DatabaseName": schema, + "MaxResults": 100, + } + # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 + # infers it from the account Id. + catalog_id = get_catalog_id(data_catalog) + if catalog_id: + kwargs["CatalogId"] = catalog_id + + for page in paginator.paginate(**kwargs): + for table in page["TableList"]: + if relations and table["Name"] in relations: + catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database)) + table = agate.Table.from_object(catalog) + else: + with boto3_client_lock: + athena_client = client.session.client( + "athena", region_name=client.region_name, config=get_boto3_config() + ) - for page in paginator.paginate(**kwargs): - for table in page["TableList"]: - if relations and table["Name"] in relations: - catalog.extend(self._get_one_table_for_catalog(table, information_schema.path.database)) + catalog = [] + paginator = athena_client.get_paginator("list_table_metadata") + for schema, relations in schemas.items(): + for page in paginator.paginate( + CatalogName=information_schema.path.database, + DatabaseName=schema, + MaxResults=50, # Limit supported by this operation + ): + for table in page["TableMetadataList"]: + if relations and table["Name"] in relations: + catalog.extend( + self._get_one_table_for_non_glue_catalog( + table, schema, information_schema.path.database + ) + ) + table = agate.Table.from_object(catalog) - table = agate.Table.from_object(catalog) filtered_table = self._catalog_filter_table(table, manifest) return self._join_catalog_table_owners(filtered_table, manifest) @@ -912,3 +969,26 @@ def _get_table_input(table: TableTypeDef) -> TableInputTypeDef: returned by get_table() method. 
""" return {k: v for k, v in table.items() if k in TableInputTypeDef.__annotations__} + + @available + def run_query_with_partitions_limit_catching(self, sql: str) -> str: + conn = self.connections.get_thread_connection() + cursor = conn.handle.cursor() + try: + cursor.execute(sql, catch_partitions_limit=True) + except OperationalError as e: + LOGGER.debug(f"CAUGHT EXCEPTION: {e}") + if "TOO_MANY_OPEN_PARTITIONS" in str(e): + return "TOO_MANY_OPEN_PARTITIONS" + raise e + return f'{{"rowcount":{cursor.rowcount},"data_scanned_in_bytes":{cursor.data_scanned_in_bytes}}}' + + @available + def format_partition_keys(self, partition_keys: List[str]) -> str: + return ", ".join([self.format_one_partition_key(k) for k in partition_keys]) + + @available + def format_one_partition_key(self, partition_key: str) -> str: + """Check if partition key uses Iceberg hidden partitioning""" + hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower()) + return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower() diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py index bb07393f..680f9e09 100644 --- a/dbt/adapters/athena/relation.py +++ b/dbt/adapters/athena/relation.py @@ -79,6 +79,7 @@ def add(self, relation: AthenaRelation) -> None: RELATION_TYPE_MAP = { "EXTERNAL_TABLE": TableType.TABLE, + "EXTERNAL": TableType.TABLE, # type returned by federated query tables "MANAGED_TABLE": TableType.TABLE, "VIRTUAL_VIEW": TableType.VIEW, "table": TableType.TABLE, diff --git a/dbt/adapters/athena/utils.py b/dbt/adapters/athena/utils.py index 778fb4c2..dcd74916 100644 --- a/dbt/adapters/athena/utils.py +++ b/dbt/adapters/athena/utils.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Generator, List, Optional, TypeVar from mypy_boto3_athena.type_defs import DataCatalogTypeDef @@ -9,7 +10,17 @@ def clean_sql_comment(comment: str) -> str: def get_catalog_id(catalog: Optional[DataCatalogTypeDef]) -> Optional[str]: - return catalog["Parameters"]["catalog-id"] if catalog else None + return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None + + +class AthenaCatalogType(Enum): + GLUE = "GLUE" + LAMBDA = "LAMBDA" + HIVE = "HIVE" + + +def get_catalog_type(catalog: Optional[DataCatalogTypeDef]) -> Optional[AthenaCatalogType]: + return AthenaCatalogType(catalog["Type"]) if catalog else None T = TypeVar("T") diff --git a/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql new file mode 100644 index 00000000..20773841 --- /dev/null +++ b/dbt/include/athena/macros/materializations/models/helpers/get_partition_batches.sql @@ -0,0 +1,52 @@ +{% macro get_partition_batches(sql) -%} + {%- set partitioned_by = config.get('partitioned_by') -%} + {%- set athena_partitions_limit = config.get('partitions_limit', 100) | int -%} + {%- set partitioned_keys = adapter.format_partition_keys(partitioned_by) -%} + {% do log('PARTITIONED KEYS: ' ~ partitioned_keys) %} + + {% call statement('get_partitions', fetch_result=True) %} + select distinct {{ partitioned_keys }} from ({{ sql }}) order by {{ partitioned_keys }}; + {% endcall %} + + {%- set table = load_result('get_partitions').table -%} + {%- set rows = table.rows -%} + {%- set partitions = {} -%} + {% do log('TOTAL PARTITIONS TO PROCESS: ' ~ rows | length) %} + {%- set partitions_batches = [] -%} + + {%- for row in rows -%} + {%- set 
single_partition = [] -%} + {%- for col in row -%} + + {%- set column_type = adapter.convert_type(table, loop.index0) -%} + {%- if column_type == 'integer' -%} + {%- set value = col | string -%} + {%- elif column_type == 'string' -%} + {%- set value = "'" + col + "'" -%} + {%- elif column_type == 'date' -%} + {%- set value = "DATE'" + col | string + "'" -%} + {%- elif column_type == 'timestamp' -%} + {%- set value = "TIMESTAMP'" + col | string + "'" -%} + {%- else -%} + {%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%} + {%- endif -%} + {%- set partition_key = adapter.format_one_partition_key(partitioned_by[loop.index0]) -%} + {%- do single_partition.append(partition_key + '=' + value) -%} + {%- endfor -%} + + {%- set single_partition_expression = single_partition | join(' and ') -%} + + {%- set batch_number = (loop.index0 / athena_partitions_limit) | int -%} + {% if not batch_number in partitions %} + {% do partitions.update({batch_number: []}) %} + {% endif %} + + {%- do partitions[batch_number].append('(' + single_partition_expression + ')') -%} + {%- if partitions[batch_number] | length == athena_partitions_limit or loop.last -%} + {%- do partitions_batches.append(partitions[batch_number] | join(' or ')) -%} + {%- endif -%} + {%- endfor -%} + + {{ return(partitions_batches) }} + +{%- endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql index a965e85c..76c2ed73 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql @@ -22,19 +22,42 @@ {% endmacro %} {% macro incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation, statement_name="main") %} - {% set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) %} - {% if not dest_columns %} + {%- set dest_columns = process_schema_changes(on_schema_change, tmp_relation, existing_relation) -%} + {%- if not dest_columns -%} {%- set dest_columns = adapter.get_columns_in_relation(target_relation) -%} - {% endif %} + {%- endif -%} {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} - insert into {{ target_relation }} ({{ dest_cols_csv }}) - ( - select {{ dest_cols_csv }} - from {{ tmp_relation }} - ); + {%- set insert_full -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + ); + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(insert_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches|length) -%} + {%- set insert_batch_partitions -%} + insert into {{ target_relation }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ tmp_relation }} + where {{ batch }} + ); + {%- endset -%} + {%- do run_query(insert_batch_partitions) -%} + {%- endfor -%} + {%- endif -%} + SELECT '{{query_result}}' {%- endmacro %} + {% macro delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {%- set partitioned_keys = partitioned_by | tojson | 
replace('\"', '') | replace('[', '') | replace(']', '') -%} {% call statement('get_partitions', fetch_result=True) %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index 879aa335..b42c7cb9 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -7,7 +7,7 @@ {% set lf_tags_config = config.get('lf_tags_config') %} {% set lf_grants = config.get('lf_grants') %} - {% set partitioned_by = config.get('partitioned_by', default=none) %} + {% set partitioned_by = config.get('partitioned_by') %} {% set target_relation = this.incorporate(type='table') %} {% set existing_relation = load_relation(this) %} {% set tmp_relation = make_temp_relation(this) %} @@ -24,16 +24,18 @@ {% set to_drop = [] %} {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, sql) -%} + {% set query_result = safe_create_table_as(False, target_relation, sql) -%} + {% set build_sql = "select '{{ query_result }}'" -%} {% elif existing_relation.is_view or should_full_refresh() %} {% do drop_relation(existing_relation) %} - {% set build_sql = create_table_as(False, target_relation, sql) -%} + {% set query_result = safe_create_table_as(False, target_relation, sql) -%} + {% set build_sql = "select '{{ query_result }}'" -%} {% elif partitioned_by is not none and strategy == 'insert_overwrite' %} {% set tmp_relation = make_temp_relation(target_relation) %} {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% set query_result = safe_create_table_as(True, tmp_relation, sql) -%} {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} @@ -42,7 +44,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% set query_result = safe_create_table_as(True, tmp_relation, sql) -%} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% elif strategy == 'merge' and table_type == 'iceberg' %} @@ -67,7 +69,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% set query_result = safe_create_table_as(True, tmp_relation, sql) -%} {% set build_sql = iceberg_merge(on_schema_change, tmp_relation, target_relation, unique_key, incremental_predicates, existing_relation, delete_condition) %} {% do to_drop.append(tmp_relation) %} {% endif %} diff --git a/dbt/include/athena/macros/materializations/models/incremental/merge.sql b/dbt/include/athena/macros/materializations/models/incremental/merge.sql index cd06cb03..07d8b886 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/merge.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/merge.sql @@ -73,33 +73,66 @@ {%- endfor -%} {%- set update_columns = get_merge_update_columns(merge_update_columns, merge_exclude_columns, dest_columns_wo_keys) -%} {%- set src_cols_csv = src_columns_quoted | join(', ') -%} - merge into {{ target_relation 
}} as target using {{ tmp_relation }} as src - on ( - {%- for key in unique_key_cols %} - target.{{ key }} = src.{{ key }} {{ "and " if not loop.last }} - {%- endfor %} - ) - {% if incremental_predicates is not none -%} - and ( - {%- for inc_predicate in incremental_predicates %} - {{ inc_predicate }} {{ "and " if not loop.last }} - {%- endfor %} - ) - {%- endif %} - {% if delete_condition is not none -%} - when matched and ({{ delete_condition }}) - then delete - {%- endif %} - when matched - then update set - {%- for col in update_columns %} - {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %} - {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }} - {%- else -%} - {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }} - {%- endif -%} - {%- endfor %} - when not matched - then insert ({{ dest_cols_csv }}) - values ({{ src_cols_csv }}); + + {%- set src_part -%} + merge into {{ target_relation }} as target using {{ tmp_relation }} as src + {%- endset -%} + + {%- set merge_part -%} + on ( + {%- for key in unique_key_cols -%} + target.{{ key }} = src.{{ key }} + {{ " and " if not loop.last }} + {%- endfor -%} + {% if incremental_predicates is not none -%} + and ( + {%- for inc_predicate in incremental_predicates %} + {{ inc_predicate }} {{ "and " if not loop.last }} + {%- endfor %} + ) + {%- endif %} + ) + {% if delete_condition is not none -%} + when matched and ({{ delete_condition }}) + then delete + {%- endif %} + when matched + then update set + {%- for col in update_columns %} + {%- if merge_update_columns_rules and col.name in merge_update_columns_rules %} + {{ get_update_statement(col, merge_update_columns_rules[col.name], loop.last) }} + {%- else -%} + {{ get_update_statement(col, merge_update_columns_default_rule, loop.last) }} + {%- endif -%} + {%- endfor %} + when not matched + then insert ({{ dest_cols_csv }}) + values ({{ src_cols_csv }}) + {%- endset -%} + + {%- set merge_full -%} + {{ src_part }} + {{ merge_part }} + {%- endset -%} + + {%- set query_result = adapter.run_query_with_partitions_limit_catching(merge_full) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {% set partitions_batches = get_partition_batches(tmp_relation) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + {%- set src_batch_part -%} + merge into {{ target_relation }} as target + using (select * from {{ tmp_relation }} where {{ batch }}) as src + {%- endset -%} + {%- set merge_batch -%} + {{ src_batch_part }} + {{ merge_part }} + {%- endset -%} + {%- do run_query(merge_batch) -%} + {%- endfor -%} + {%- endif -%} + + SELECT '{{query_result}}' {%- endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index af730a30..878a0378 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -87,3 +87,44 @@ as {{ sql }} {% endmacro %} + +{% macro create_table_as_with_partitions(temporary, relation, sql) -%} + + {% set partitions_batches = get_partition_batches(sql) %} + {% do log('BATCHES TO PROCESS: ' ~ partitions_batches | length) %} + + {%- do log('CREATE EMPTY TABLE: ' ~ relation) -%} 
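# Illustrative sketch (not part of the upstream diff): the batching above selects distinct partition
# values with the same expressions the table is partitioned by. For Iceberg hidden partitioning,
# adapter.format_one_partition_key() (added in impl.py) rewrites transforms such as day(date_column)
# into a date_trunc() call that Athena can group by. A minimal sketch mirroring that added code;
# the sample inputs are hypothetical.
import re


def format_one_partition_key(partition_key: str) -> str:
    """Translate an Iceberg hidden-partitioning transform into a date_trunc() expression."""
    hidden = re.search(r"^(hour|day|month|year)\((.+)\)", partition_key.lower())
    return f"date_trunc('{hidden.group(1)}', {hidden.group(2)})" if hidden else partition_key.lower()


print(format_one_partition_key("DAY(date_column)"))  # -> date_trunc('day', date_column)
print(format_one_partition_key("doy"))               # -> doy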
+ {%- set create_empty_table_query -%} + {{ create_table_as(temporary, relation, sql) }} + limit 0 + {%- endset -%} + {%- do run_query(create_empty_table_query) -%} + {%- set dest_columns = adapter.get_columns_in_relation(relation) -%} + {%- set dest_cols_csv = dest_columns | map(attribute='quoted') | join(', ') -%} + + {%- for batch in partitions_batches -%} + {%- do log('BATCH PROCESSING: ' ~ loop.index ~ ' OF ' ~ partitions_batches | length) -%} + + {%- set insert_batch_partitions -%} + insert into {{ relation }} ({{ dest_cols_csv }}) + select {{ dest_cols_csv }} + from ({{ sql }}) + where {{ batch }} + {%- endset -%} + + {%- do run_query(insert_batch_partitions) -%} + {%- endfor -%} + + select 'SUCCESSFULLY CREATED TABLE {{ relation }}' + +{%- endmacro %} + +{% macro safe_create_table_as(temporary, relation, sql) -%} + {%- set query_result = adapter.run_query_with_partitions_limit_catching(create_table_as(temporary, relation, sql)) -%} + {%- do log('QUERY RESULT: ' ~ query_result) -%} + {%- if query_result == 'TOO_MANY_OPEN_PARTITIONS' -%} + {%- do create_table_as_with_partitions(temporary, relation, sql) -%} + {%- set query_result = '{{ relation }} with many partitions created' -%} + {%- endif -%} + {{ return(query_result) }} +{%- endmacro %} diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql index 7f9856e4..989bf63b 100644 --- a/dbt/include/athena/macros/materializations/models/table/table.sql +++ b/dbt/include/athena/macros/materializations/models/table/table.sql @@ -49,28 +49,22 @@ {%- endif -%} -- create tmp table - {% call statement('main') -%} - {{ create_table_as(False, tmp_relation, sql) }} - {%- endcall %} + {%- set query_result = safe_create_table_as(False, tmp_relation, sql) -%} -- swap table - {%- set swap_table = adapter.swap_table(tmp_relation, - target_relation) -%} + {%- set swap_table = adapter.swap_table(tmp_relation, target_relation) -%} -- delete glue tmp table, do not use drop_relation, as it will remove data of the target table {%- do adapter.delete_from_glue_catalog(tmp_relation) -%} - {% do adapter.expire_glue_table_versions(target_relation, - versions_to_keep, - True) %} + {% do adapter.expire_glue_table_versions(target_relation, versions_to_keep, True) %} + {%- else -%} -- Here we are in the case of non-ha tables or ha tables but in case of full refresh. 
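# Illustrative sketch (not part of the upstream diff): the get_partition_batches macro used above
# groups the distinct partition tuples into OR'ed predicates of at most `partitions_limit`
# partitions (default 100), and each group becomes one INSERT ... WHERE <batch>, keeping every
# statement under Athena's open-partitions limit. A hypothetical Python rendering of that grouping,
# using the date_column/doy partitioning from the functional tests further below.
from typing import List, Tuple


def build_partition_batches(rows: List[Tuple[str, int]], partitions_limit: int = 100) -> List[str]:
    """Group (date, doy) partition values into WHERE fragments of at most `partitions_limit` partitions."""
    batches: List[str] = []
    for start in range(0, len(rows), partitions_limit):
        chunk = rows[start : start + partitions_limit]
        predicates = [f"(date_column=DATE'{d}' and doy={doy})" for d, doy in chunk]
        batches.append(" or ".join(predicates))
    return batches


sample = [("2023-01-01", doy) for doy in range(1, 213)]  # 212 partitions
print(len(build_partition_batches(sample)))  # -> 3 batches: 100 + 100 + 12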
{%- if old_relation is not none -%} {{ drop_relation(old_relation) }} {%- endif -%} - {%- call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall %} + {%- set query_result = safe_create_table_as(False, target_relation, sql) -%} {%- endif -%} {{ set_table_classification(target_relation) }} @@ -78,14 +72,10 @@ {%- else -%} {%- if old_relation is none -%} - {%- call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} - {%- endcall %} + {%- set query_result = safe_create_table_as(False, target_relation, sql) -%} {%- else -%} {%- if old_relation.is_view -%} - {%- call statement('main') -%} - {{ create_table_as(False, tmp_relation, sql) }} - {%- endcall -%} + {%- set query_result = safe_create_table_as(False, tmp_relation, sql) -%} {%- do drop_relation(old_relation) -%} {%- do rename_relation(tmp_relation, target_relation) -%} {%- else -%} @@ -103,9 +93,7 @@ {%- do drop_relation(old_relation_bkp) -%} {%- endif -%} - {%- call statement('main') -%} - {{ create_table_as(False, tmp_relation, sql) }} - {%- endcall -%} + {% set query_result = safe_create_table_as(False, tmp_relation, sql) %} {{ rename_relation(old_relation, old_relation_bkp) }} {{ rename_relation(tmp_relation, target_relation) }} @@ -116,6 +104,10 @@ {%- endif -%} + {% call statement("main") %} + SELECT '{{ query_result }}'; + {% endcall %} + {{ run_hooks(post_hooks) }} {% if lf_tags_config is not none %} diff --git a/dbt/include/athena/macros/utils/ddl_dml_data_type.sql b/dbt/include/athena/macros/utils/ddl_dml_data_type.sql index ec726627..ddd0fcb0 100644 --- a/dbt/include/athena/macros/utils/ddl_dml_data_type.sql +++ b/dbt/include/athena/macros/utils/ddl_dml_data_type.sql @@ -18,8 +18,14 @@ {%- endif -%} -- transform timestamp - {%- if table_type == 'iceberg' and 'timestamp' in data_type -%} - {% set data_type = 'timestamp' -%} + {%- if table_type == 'iceberg' -%} + {%- if 'timestamp' in data_type -%} + {% set data_type = 'timestamp' -%} + {%- endif -%} + + {%- if 'binary' in data_type -%} + {% set data_type = 'binary' -%} + {%- endif -%} {%- endif -%} {{ return(data_type) }} diff --git a/dev-requirements.txt b/dev-requirements.txt index 3dfa7e77..41e3249f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,14 +1,14 @@ autoflake~=1.7 black~=23.3 -dbt-tests-adapter~=1.5.4 -flake8~=5.0 +dbt-tests-adapter~=1.6.1 +flake8~=6.1 Flake8-pyproject~=1.2 isort~=5.11 -moto~=4.1.14 +moto~=4.2.0 pre-commit~=2.21 -pyparsing~=3.1.0 +pyparsing~=3.1.1 pytest~=7.4 pytest-cov~=4.1 pytest-dotenv~=0.5 pytest-xdist~=3.3 -pyupgrade~=3.3 +pyupgrade~=3.10 diff --git a/setup.py b/setup.py index dae9a695..bfaacbfa 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ def _get_package_version() -> str: return f'{parts["major"]}.{parts["minor"]}.{parts["patch"]}' -dbt_version = "1.5" +dbt_version = "1.6" package_version = _get_package_version() description = "The athena adapter plugin for dbt (data build tool)" @@ -55,9 +55,9 @@ def _get_package_version() -> str: # In order to control dbt-core version and package version "boto3~=1.26", "boto3-stubs[athena,glue,lakeformation,sts]~=1.26", - "dbt-core~=1.5.0", + "dbt-core~=1.6.0", "pyathena>=2.25,<4.0", - "pydantic~=1.10", + "pydantic>=1.10,<3.0", "tenacity~=8.2", ], classifiers=[ @@ -66,10 +66,10 @@ def _get_package_version() -> str: "Operating System :: Microsoft :: Windows", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", 
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", ], + python_requires=">=3.8", ) diff --git a/tests/conftest.py b/tests/conftest.py index 827b862b..5f5f7f4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from dbt.events.eventmgr import LineFormat, NoFilter from dbt.events.functions import EVENT_MANAGER, _get_stdout_config -# Import the fuctional fixtures as a plugin +# Import the functional fixtures as a plugin # Note: fixtures with session scope need to be local pytest_plugins = ["dbt.tests.fixtures.project"] @@ -42,11 +42,14 @@ def dbt_debug_caplog() -> StringIO: def _setup_custom_caplog(name: str, level: EventLevel): capture_config = _get_stdout_config( - line_format=LineFormat.PlainText, level=level, use_colors=False, debug=True, log_cache_events=True, quiet=False + line_format=LineFormat.PlainText, + level=level, + use_colors=False, + log_cache_events=True, ) capture_config.name = name capture_config.filter = NoFilter - stringbuf = StringIO() - capture_config.output_stream = stringbuf + string_buf = StringIO() + capture_config.output_stream = string_buf EVENT_MANAGER.add_logger(capture_config) - return stringbuf + return string_buf diff --git a/tests/functional/adapter/fixture_split_parts.py b/tests/functional/adapter/fixture_split_parts.py new file mode 100644 index 00000000..c9ff3d7c --- /dev/null +++ b/tests/functional/adapter/fixture_split_parts.py @@ -0,0 +1,39 @@ +models__test_split_part_sql = """ +with data as ( + + select * from {{ ref('data_split_part') }} + +) + +select + {{ split_part('parts', 'split_on', 1) }} as actual, + result_1 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 2) }} as actual, + result_2 as expected + +from data + +union all + +select + {{ split_part('parts', 'split_on', 3) }} as actual, + result_3 as expected + +from data +""" + +models__test_split_part_yml = """ +version: 2 +models: + - name: test_split_part + tests: + - assert_equal: + actual: actual + expected: expected +""" diff --git a/tests/functional/adapter/test_partitions.py b/tests/functional/adapter/test_partitions.py new file mode 100644 index 00000000..68e639e6 --- /dev/null +++ b/tests/functional/adapter/test_partitions.py @@ -0,0 +1,127 @@ +import pytest + +from dbt.contracts.results import RunStatus +from dbt.tests.util import run_dbt + +# this query generates 212 records +test_partitions_model_sql = """ +select + random() as rnd, + cast(date_column as date) as date_column, + doy(date_column) as doy +from ( + values ( + sequence(from_iso8601_date('2023-01-01'), from_iso8601_date('2023-07-31'), interval '1' day) + ) +) as t1(date_array) +cross join unnest(date_array) as t2(date_column) +""" + + +class TestHiveTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"models": {"+table_type": "hive", "+materialized": "table", "+partitioned_by": ["date_column", "doy"]}} + + @pytest.fixture(scope="class") + def models(self): + return { + "test_hive_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_hive_partitions" + model_run_result_row_count_query = "select count(*) as records from {}.{}".format( + project.test_schema, relation_name + ) + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert 
first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergTablePartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "table", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + relation_name = "test_iceberg_partitions" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + +class TestIcebergIncrementalPartitions: + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "models": { + "+table_type": "iceberg", + "+materialized": "incremental", + "+incremental_strategy": "merge", + "+unique_key": "doy", + "+partitioned_by": ["DAY(date_column)", "doy"], + } + } + + @pytest.fixture(scope="class") + def models(self): + return { + "test_iceberg_partitions_incremental.sql": test_partitions_model_sql, + } + + def test__check_incremental_run_with_partitions(self, project): + """ + Check that the incremental run works with iceberg and partitioned datasets + """ + + relation_name = "test_iceberg_partitions_incremental" + model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}" + + first_model_run = run_dbt(["run", "--select", relation_name, "--full-refresh"]) + first_model_run_result = first_model_run.results[0] + + # check that the model run successfully + assert first_model_run_result.status == RunStatus.Success + + records_count_first_run = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert records_count_first_run == 212 + + incremental_model_run = run_dbt(["run", "--select", relation_name]) + + incremental_model_run_result = incremental_model_run.results[0] + + # check that the model run successfully after incremental run + assert incremental_model_run_result.status == RunStatus.Success + + incremental_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0] + + assert incremental_records_count == 212 diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py index b998b6c7..83f743e9 100644 --- a/tests/functional/adapter/utils/test_utils.py +++ b/tests/functional/adapter/utils/test_utils.py @@ -3,6 +3,10 @@ models__test_datediff_sql, seeds__data_datediff_csv, ) +from tests.functional.adapter.fixture_split_parts import ( + models__test_split_part_sql, + models__test_split_part_yml, +) from dbt.tests.adapter.utils.fixture_datediff import models__test_datediff_yml from dbt.tests.adapter.utils.test_any_value import BaseAnyValue @@ -100,7 +104,12 @@ class TestRight(BaseRight): class TestSplitPart(BaseSplitPart): - pass + @pytest.fixture(scope="class") + def models(self): + return { + "test_split_part.yml": models__test_split_part_yml, + "test_split_part.sql": 
self.interpolate_macro_namespace(models__test_split_part_sql, "split_part"), + } class TestStringLiteral(BaseStringLiteral): diff --git a/tests/unit/constants.py b/tests/unit/constants.py index 0513ab69..8fe9aa82 100644 --- a/tests/unit/constants.py +++ b/tests/unit/constants.py @@ -1,6 +1,7 @@ CATALOG_ID = "12345678910" DATA_CATALOG_NAME = "awsdatacatalog" SHARED_DATA_CATALOG_NAME = "9876543210" +FEDERATED_QUERY_CATALOG_NAME = "federated_query_data_source" DATABASE_NAME = "test_dbt_athena" BUCKET = "test-dbt-athena" AWS_REGION = "eu-west-1" diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 30d39e4b..f14ae523 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -4,7 +4,10 @@ import agate import boto3 +import botocore import pytest + +# from botocore.client.BaseClient import _make_api_call from moto import mock_athena, mock_glue, mock_s3, mock_sts from moto.core import DEFAULT_ACCOUNT_ID @@ -14,6 +17,7 @@ from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter from dbt.adapters.athena.exceptions import S3LocationException from dbt.adapters.athena.relation import AthenaRelation, TableType +from dbt.adapters.athena.utils import AthenaCatalogType from dbt.clients import agate_helper from dbt.contracts.connection import ConnectionState from dbt.contracts.files import FileHash @@ -28,6 +32,7 @@ BUCKET, DATA_CATALOG_NAME, DATABASE_NAME, + FEDERATED_QUERY_CATALOG_NAME, S3_STAGING_DIR, SHARED_DATA_CATALOG_NAME, ) @@ -66,6 +71,7 @@ def setup_method(self, _): ("awsdatacatalog", "quux"), ("awsdatacatalog", "baz"), (SHARED_DATA_CATALOG_NAME, "foo"), + (FEDERATED_QUERY_CATALOG_NAME, "foo"), } self.mock_manifest.nodes = { "model.root.model1": CompiledNode( @@ -212,6 +218,42 @@ def setup_method(self, _): raw_code="select * from source_table", language="", ), + "model.root.model5": CompiledNode( + name="model5", + database=FEDERATED_QUERY_CATALOG_NAME, + schema="foo", + resource_type=NodeType.Model, + unique_id="model.root.model5", + alias="bar", + fqn=["root", "model5"], + package_name="root", + refs=[], + sources=[], + depends_on=DependsOn(), + config=NodeConfig.from_dict( + { + "enabled": True, + "materialized": "table", + "persist_docs": {}, + "post-hook": [], + "pre-hook": [], + "vars": {}, + "meta": {"owner": "data-engineers"}, + "quoting": {}, + "column_types": {}, + "tags": [], + } + ), + tags=[], + path="model5.sql", + original_file_path="model5.sql", + compiled=True, + extra_ctes_injected=False, + extra_ctes=[], + checksum=FileHash.from_contents(""), + raw_code="select * from source_table", + language="", + ), } @property @@ -612,9 +654,84 @@ def test__get_one_catalog_shared_catalog(self, mock_aws_service): for row in actual.rows.values(): assert row.values() in expected_rows + @mock_athena + def test__get_one_catalog_federated_query_catalog(self, mock_aws_service): + mock_aws_service.create_data_catalog( + catalog_name=FEDERATED_QUERY_CATALOG_NAME, catalog_type=AthenaCatalogType.LAMBDA + ) + mock_information_schema = mock.MagicMock() + mock_information_schema.path.database = FEDERATED_QUERY_CATALOG_NAME + + # Original botocore _make_api_call function + orig = botocore.client.BaseClient._make_api_call + + # Mocking this as list_table_metadata and creating non glue tables is not supported by moto. 
+ # Followed this guide: http://docs.getmoto.org/en/latest/docs/services/patching_other_services.html + def mock_athena_list_table_metadata(self, operation_name, kwarg): + if operation_name == "ListTableMetadata": + return { + "TableMetadataList": [ + { + "Name": "bar", + "TableType": "EXTERNAL_TABLE", + "Columns": [ + { + "Name": "id", + "Type": "string", + }, + { + "Name": "country", + "Type": "string", + }, + ], + "PartitionKeys": [ + { + "Name": "dt", + "Type": "date", + }, + ], + } + ], + } + # If we don't want to patch the API call + return orig(self, operation_name, kwarg) + + self.adapter.acquire_connection("dummy") + with patch("botocore.client.BaseClient._make_api_call", new=mock_athena_list_table_metadata): + actual = self.adapter._get_one_catalog( + mock_information_schema, + { + "foo": {"bar"}, + }, + self.mock_manifest, + ) + + expected_column_names = ( + "table_database", + "table_schema", + "table_name", + "table_type", + "table_comment", + "column_name", + "column_index", + "column_type", + "column_comment", + "table_owner", + ) + expected_rows = [ + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "id", 0, "string", None, "data-engineers"), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "country", 1, "string", None, "data-engineers"), + (FEDERATED_QUERY_CATALOG_NAME, "foo", "bar", "table", None, "dt", 2, "date", None, "data-engineers"), + ] + + assert actual.column_names == expected_column_names + assert len(actual.rows) == len(expected_rows) + for row in actual.rows.values(): + assert row.values() in expected_rows + def test__get_catalog_schemas(self): res = self.adapter._get_catalog_schemas(self.mock_manifest) - assert len(res.keys()) == 2 + assert len(res.keys()) == 3 information_schema_0 = list(res.keys())[0] assert information_schema_0.name == "INFORMATION_SCHEMA" @@ -632,6 +749,14 @@ def test__get_catalog_schemas(self): assert set(relations.keys()) == {"foo"} assert list(relations.values()) == [{"bar"}] + information_schema_1 = list(res.keys())[2] + assert information_schema_1.name == "INFORMATION_SCHEMA" + assert information_schema_1.schema is None + assert information_schema_1.database == FEDERATED_QUERY_CATALOG_NAME + relations = list(res.values())[1] + assert set(relations.keys()) == {"foo"} + assert list(relations.values()) == [{"bar"}] + @mock_athena @mock_sts def test__get_data_catalog(self, mock_aws_service): @@ -696,7 +821,7 @@ def test_list_relations_without_caching_with_non_glue_data_catalog( self, parent_list_relations_without_caching, mock_aws_service ): data_catalog_name = "other_data_catalog" - mock_aws_service.create_data_catalog(data_catalog_name, "HIVE") + mock_aws_service.create_data_catalog(data_catalog_name, AthenaCatalogType.HIVE) schema_relation = self.adapter.Relation.create( database=data_catalog_name, schema=DATABASE_NAME, diff --git a/tests/unit/utils.py b/tests/unit/utils.py index 60ec02ed..097f6ce9 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -5,6 +5,7 @@ import agate import boto3 +from dbt.adapters.athena.utils import AthenaCatalogType from dbt.config.project import PartialProject from .constants import AWS_REGION, BUCKET, CATALOG_ID, DATA_CATALOG_NAME, DATABASE_NAME @@ -146,10 +147,18 @@ def _make_table_of(self, rows, column_types): class MockAWSService: def create_data_catalog( - self, catalog_name: str = DATA_CATALOG_NAME, catalog_type: str = "GLUE", catalog_id: str = CATALOG_ID + self, + catalog_name: str = DATA_CATALOG_NAME, + catalog_type: AthenaCatalogType = AthenaCatalogType.GLUE, + 
catalog_id: str = CATALOG_ID, ): athena = boto3.client("athena", region_name=AWS_REGION) - athena.create_data_catalog(Name=catalog_name, Type=catalog_type, Parameters={"catalog-id": catalog_id}) + parameters = {} + if catalog_type == AthenaCatalogType.GLUE: + parameters = {"catalog-id": catalog_id} + else: + parameters = {"catalog": catalog_name} + athena.create_data_catalog(Name=catalog_name, Type=catalog_type.value, Parameters=parameters) def create_database(self, name: str = DATABASE_NAME, catalog_id: str = CATALOG_ID): glue = boto3.client("glue", region_name=AWS_REGION)
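# Illustrative sketch (not part of the upstream diff): how the utils.py helpers added above behave
# for the two kinds of data catalogs exercised in these tests. For a GLUE catalog get_catalog_id()
# returns the account id stored in Parameters["catalog-id"]; for LAMBDA/HIVE catalogs it returns
# None and the adapter falls back to the Athena list_table_metadata path. The sample dicts below
# are hypothetical.
from enum import Enum
from typing import Optional


class AthenaCatalogType(Enum):
    GLUE = "GLUE"
    LAMBDA = "LAMBDA"
    HIVE = "HIVE"


def get_catalog_id(catalog: Optional[dict]) -> Optional[str]:
    return catalog["Parameters"]["catalog-id"] if catalog and catalog["Type"] == AthenaCatalogType.GLUE.value else None


def get_catalog_type(catalog: Optional[dict]) -> Optional[AthenaCatalogType]:
    return AthenaCatalogType(catalog["Type"]) if catalog else None


glue_catalog = {"Name": "9876543210", "Type": "GLUE", "Parameters": {"catalog-id": "9876543210"}}
lambda_catalog = {"Name": "federated_query_data_source", "Type": "LAMBDA", "Parameters": {"catalog": "federated_query_data_source"}}

print(get_catalog_id(glue_catalog), get_catalog_type(glue_catalog))      # 9876543210 AthenaCatalogType.GLUE
print(get_catalog_id(lambda_catalog), get_catalog_type(lambda_catalog))  # None AthenaCatalogType.LAMBDA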