diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3bc6def4..88a39a81 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,2 +1,12 @@
+## v1.0.4
+
+### Bugfixes
+* Add support for partition fields of type timestamp
+* Use correct escaper for INSERT queries
+* Share the same boto3 session across all calls
+
+### Features
+* Get model owner from manifest
+
 ## v1.0.3
 * Fix issue on fetching partitions from glue, using pagination
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index dab4bf95..a2bb5374 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -27,6 +27,8 @@
 from tenacity.stop import stop_after_attempt
 from tenacity.wait import wait_exponential

+from dbt.adapters.athena.session import get_boto3_session
+
 logger = AdapterLogger("Athena")
@@ -51,7 +53,8 @@ def unique_field(self):
         return self.host

     def _connection_keys(self) -> Tuple[str, ...]:
-        return "s3_staging_dir", "work_group", "region_name", "database", "schema", "poll_interval", "aws_profile_name", "endpoing_url"
+        return "s3_staging_dir", "work_group", "region_name", "database", "schema", "poll_interval", \
+            "aws_profile_name", "endpoing_url"


 class AthenaCursor(Cursor):
@@ -140,13 +143,12 @@ def open(cls, connection: Connection) -> Connection:
         handle = AthenaConnection(
             s3_staging_dir=creds.s3_staging_dir,
             endpoint_url=creds.endpoint_url,
-            region_name=creds.region_name,
             schema_name=creds.schema,
             work_group=creds.work_group,
             cursor_class=AthenaCursor,
             formatter=AthenaParameterFormatter(),
             poll_interval=creds.poll_interval,
-            profile_name=creds.aws_profile_name,
+            session=get_boto3_session(connection),
             retry_config=RetryConfig(
                 attempt=creds.num_retries,
                 exceptions=(
@@ -213,9 +215,7 @@ def format(
             raise ProgrammingError("Query is none or empty.")
         operation = operation.strip()

-        if operation.upper().startswith("SELECT") or operation.upper().startswith(
-            "WITH"
-        ):
+        if operation.upper().startswith(("SELECT", "WITH", "INSERT")):
             escaper = _escape_presto
         else:
             # Fixes ParseException that comes with newer version of PyAthena
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index c43d3f80..d79ff812 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -1,6 +1,5 @@
 import agate
 import re
-import boto3
 from botocore.exceptions import ClientError
 from itertools import chain
 from threading import Lock
@@ -20,6 +19,7 @@
 boto3_client_lock = Lock()

+
 class AthenaAdapter(SQLAdapter):
     ConnectionManager = AthenaConnectionManager
     Relation = AthenaRelation
@@ -61,8 +61,8 @@ def clean_up_partitions(
         client = conn.handle
         with boto3_client_lock:
-            glue_client = boto3.client('glue', region_name=client.region_name)
-            s3_resource = boto3.resource('s3', region_name=client.region_name)
+            glue_client = client.session.client('glue', region_name=client.region_name)
+            s3_resource = client.session.resource('s3', region_name=client.region_name)

         paginator = glue_client.get_paginator("get_partitions")
         partition_params = {
             "DatabaseName": database_name,
@@ -74,7 +74,8 @@
         partitions = partition_pg.build_full_result().get('Partitions')
         s3_rg = re.compile('s3://([^/]*)/(.*)')
         for partition in partitions:
-            logger.debug("Deleting objects for partition '{}' at '{}'", partition["Values"], partition["StorageDescriptor"]["Location"])
+            logger.debug("Deleting objects for partition '{}' at '{}'", partition["Values"],
+                         partition["StorageDescriptor"]["Location"])
             m = s3_rg.match(partition["StorageDescriptor"]["Location"])
             if m is not None:
                bucket_name = m.group(1)
@@ -90,7 +91,7 @@ def clean_up_table(
         conn = self.connections.get_thread_connection()
         client = conn.handle
         with boto3_client_lock:
-            glue_client = boto3.client('glue', region_name=client.region_name)
+            glue_client = client.session.client('glue', region_name=client.region_name)
         try:
             table = glue_client.get_table(
                 DatabaseName=database_name,
@@ -108,7 +109,7 @@
         if m is not None:
             bucket_name = m.group(1)
             prefix = m.group(2)
-            s3_resource = boto3.resource('s3', region_name=client.region_name)
+            s3_resource = client.session.resource('s3', region_name=client.region_name)
             s3_bucket = s3_resource.Bucket(bucket_name)
             s3_bucket.objects.filter(Prefix=prefix).delete()
@@ -118,13 +119,33 @@ def quote_seed_column(
     ) -> str:
         return super().quote_seed_column(column, False)

+    def _join_catalog_table_owners(self, table: agate.Table, manifest: Manifest) -> agate.Table:
+        owners = []
+        # Get the owner for each model from the manifest
+        for node in manifest.nodes.values():
+            if node.resource_type == "model":
+                owners.append({
+                    "table_database": node.database,
+                    "table_schema": node.schema,
+                    "table_name": node.alias,
+                    "table_owner": node.config.meta.get("owner"),
+                })
+        owners_table = agate.Table.from_object(owners)
+
+        # Join owners with the results from catalog
+        join_keys = ["table_database", "table_schema", "table_name"]
+        return table.join(
+            right_table=owners_table,
+            left_key=join_keys,
+            right_key=join_keys,
+        )
+
     def _get_one_catalog(
         self,
         information_schema: InformationSchema,
         schemas: Dict[str, Optional[Set[str]]],
         manifest: Manifest,
     ) -> agate.Table:
-
         kwargs = {"information_schema": information_schema, "schemas": schemas}
         table = self.execute_macro(
             GET_CATALOG_MACRO_NAME,
@@ -134,9 +155,8 @@
             manifest=manifest,
         )

-        results = self._catalog_filter_table(table, manifest)
-        return results
-
+        filtered_table = self._catalog_filter_table(table, manifest)
+        return self._join_catalog_table_owners(filtered_table, manifest)

     def _get_catalog_schemas(self, manifest: Manifest) -> AthenaSchemaSearchMap:
         info_schema_name_map = AthenaSchemaSearchMap()
@@ -155,8 +175,8 @@ def _get_data_catalog(self, catalog_name):
         conn = self.connections.get_thread_connection()
         client = conn.handle
         with boto3_client_lock:
-            athena_client = boto3.client('athena', region_name=client.region_name)
-
+            athena_client = client.session.client('athena', region_name=client.region_name)
+
         response = athena_client.get_data_catalog(Name=catalog_name)
         return response['DataCatalog']
@@ -175,7 +195,7 @@ def list_relations_without_caching(
         conn = self.connections.get_thread_connection()
         client = conn.handle
         with boto3_client_lock:
-            glue_client = boto3.client('glue', region_name=client.region_name)
+            glue_client = client.session.client('glue', region_name=client.region_name)

         paginator = glue_client.get_paginator('get_tables')
         kwargs = {
@@ -183,7 +203,7 @@
         }
         # If the catalog is `awsdatacatalog` we don't need to pass CatalogId as boto3 infers it from the account Id.
         if catalog_id:
-          kwargs['CatalogId'] = catalog_id
+            kwargs['CatalogId'] = catalog_id
         page_iterator = paginator.paginate(**kwargs)

         relations = []
diff --git a/dbt/adapters/athena/query_headers.py b/dbt/adapters/athena/query_headers.py
index d5ed74db..a871550e 100644
--- a/dbt/adapters/athena/query_headers.py
+++ b/dbt/adapters/athena/query_headers.py
@@ -1,9 +1,10 @@
 import dbt.adapters.base.query_headers

+
 class _QueryComment(dbt.adapters.base.query_headers._QueryComment):
     """
-    Athena DDL does not always respect /* ... */ block quotations. 
-    This function is the same as _QueryComment.add except that 
+    Athena DDL does not always respect /* ... */ block quotations.
+    This function is the same as _QueryComment.add except that
     a leading "-- " is prepended to the query_comment and any newlines
     in the query_comment are replaced with " ". This allows the default
     query_comment to be added to `create external table` statements.
diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py
index d3dd7bea..c6262cd6 100644
--- a/dbt/adapters/athena/relation.py
+++ b/dbt/adapters/athena/relation.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Set 
+from typing import Dict, Optional, Set

 from dbt.adapters.base.relation import BaseRelation, InformationSchema, Policy
@@ -16,6 +16,7 @@ class AthenaRelation(BaseRelation):
     quote_character: str = ""
     include_policy: Policy = AthenaIncludePolicy()

+
 class AthenaSchemaSearchMap(Dict[InformationSchema, Dict[str, Set[Optional[str]]]]):
     """A utility class to keep track of what information_schema tables to search
     for what schemas and relations. The schema and relation values are all
diff --git a/dbt/adapters/athena/session.py b/dbt/adapters/athena/session.py
new file mode 100644
index 00000000..b4b55f34
--- /dev/null
+++ b/dbt/adapters/athena/session.py
@@ -0,0 +1,25 @@
+from typing import Optional
+
+import boto3.session
+from dbt.contracts.connection import Connection
+
+
+__BOTO3_SESSION__: Optional[boto3.session.Session] = None
+
+
+def get_boto3_session(connection: Connection = None) -> boto3.session.Session:
+    def init_session():
+        global __BOTO3_SESSION__
+        __BOTO3_SESSION__ = boto3.session.Session(
+            region_name=connection.credentials.region_name,
+            profile_name=connection.credentials.aws_profile_name,
+        )
+
+    if not __BOTO3_SESSION__:
+        if connection is None:
+            raise RuntimeError(
+                'A Connection object needs to be passed to initialize the boto3 session for the first time'
+            )
+        init_session()
+
+    return __BOTO3_SESSION__
diff --git a/dbt/include/athena/macros/adapters/metadata.sql b/dbt/include/athena/macros/adapters/metadata.sql
index b4c67228..b76dd3ce 100644
--- a/dbt/include/athena/macros/adapters/metadata.sql
+++ b/dbt/include/athena/macros/adapters/metadata.sql
@@ -17,7 +17,6 @@
                 else table_type
             end as table_type,
-            null as table_owner,
             null as table_comment
         from {{ information_schema }}.tables
@@ -54,8 +53,7 @@
             columns.column_name,
             columns.column_index,
             columns.column_type,
-            columns.column_comment,
-            tables.table_owner
+            columns.column_comment
         from tables
         join columns
diff --git a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql
index 3f6b1f59..5ec0cdf0 100644
--- a/dbt/include/athena/macros/materializations/models/incremental/helpers.sql
+++ b/dbt/include/athena/macros/materializations/models/incremental/helpers.sql
@@ -40,6 +40,8 @@
       {%- set value = "'" + col + "'" -%}
     {%- elif column_type == 'date' -%}
       {%- set value = "'" + col|string + "'" -%}
+    {%- elif column_type == 'timestamp' -%}
+      {%- set value = "'" + col|string + "'" -%}
     {%- else -%}
       {%- do exceptions.raise_compiler_error('Need to add support for column type ' + column_type) -%}
     {%- endif -%}
diff --git a/setup.py b/setup.py
index 2729b804..ea62a46f 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 package_name = "dbt-athena-community"
 dbt_version = "1.0"
-package_version = "1.0.3"
+package_version = "1.0.4"
 description = """The athena adapter plugin for dbt (data build tool)"""

 if not package_version.startswith(dbt_version):
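As a rough illustration of the new owner feature (not part of the patch itself): the `get_catalog` macro no longer returns `table_owner`, so `_join_catalog_table_owners` attaches the owner declared in each model's `meta: {owner: ...}` config by joining on database, schema, and table name. The sketch below reproduces that join with agate; all sample table and owner values are invented.

```python
# Hypothetical, minimal reproduction of the owner join added in impl.py.
# The catalog rows and the owner value are made-up examples.
import agate

# One row as it might come back from the get_catalog macro (no table_owner column).
catalog = agate.Table.from_object([
    {"table_database": "awsdatacatalog", "table_schema": "analytics",
     "table_name": "orders", "table_type": "table", "table_comment": None},
])

# In the adapter this list is built from manifest.nodes, using
# node.config.meta.get("owner") for each model node.
owners = agate.Table.from_object([
    {"table_database": "awsdatacatalog", "table_schema": "analytics",
     "table_name": "orders", "table_owner": "data-platform"},
])

join_keys = ["table_database", "table_schema", "table_name"]
enriched = catalog.join(owners, left_key=join_keys, right_key=join_keys)
enriched.print_table()  # catalog columns plus the joined table_owner column
```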