[DO NOT MERGE] patched 1.0.1 #3

Draft · wants to merge 10 commits into `master`
37 changes: 23 additions & 14 deletions README.md
@@ -41,16 +41,19 @@ stored login info. You can configure the AWS profile name to use via `aws_profile_name`.

A dbt profile can be configured to run against AWS Athena using the following configuration:

| Option | Description | Required? | Example |
|---------------- |-------------------------------------------------------------------------------- |----------- |---------------------- |
| s3_staging_dir | S3 location to store Athena query results and metadata | Required | `s3://bucket/dbt/` |
| region_name | AWS region of your Athena instance | Required | `eu-west-1` |
| schema | Specify the schema (Athena database) to build models into (lowercase **only**) | Required | `dbt` |
| database | Specify the database (Data catalog) to build models into (lowercase **only**) | Required | `awsdatacatalog` |
| poll_interval | Interval in seconds to use for polling the status of query results in Athena | Optional | `5` |
| aws_profile_name| Profile to use from your AWS shared credentials file. | Optional | `my-profile` |
| work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` |
| num_retries | Number of times to retry a failing query | Optional | `3` |
| s3_data_dir | Prefix for storing tables, if different from the connection's `s3_staging_dir` | Optional | `s3://bucket2/dbt/` |
| s3_data_naming  | How to generate table paths in `s3_data_dir` (`uuid` or `schema_table`)         | Optional   | `uuid`                |


**Example profiles.yml entry:**
```yaml
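# Illustrative entry assembled from the options table above; the profile
# name, target name, and bucket paths are placeholders.
athena:
  target: dev
  outputs:
    dev:
      type: athena
      s3_staging_dir: s3://bucket/dbt/
      region_name: eu-west-1
      schema: dbt
      database: awsdatacatalog
      num_retries: 3                   # optional
      aws_profile_name: my-profile     # optional
      work_group: my-custom-workgroup  # optional
      s3_data_dir: s3://bucket2/dbt/   # optional
      s3_data_naming: uuid             # optional, uuid or schema_table
```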
@@ -78,9 +81,7 @@ _Additional information_
#### Table Configuration

* `external_location` (`default=none`)
* If set, the full S3 path in which the table will be saved.
* `partitioned_by` (`default=none`)
* An array list of columns by which the table will be partitioned
* Limited to creation of 100 partitions (_currently_)
@@ -93,7 +94,15 @@ _Additional information_
* Supports `ORC`, `PARQUET`, `AVRO`, `JSON`, or `TEXTFILE`
* `field_delimiter` (`default=none`)
* Custom field delimiter, for when format is set to `TEXTFILE`


The location in which a table is saved is determined by:

1. If `external_location` is defined, that value is used.
2. If `s3_data_dir` is defined, the path is determined by that and `s3_data_naming`:
+ `s3_data_naming=uuid`: `{s3_data_dir}/{uuid4()}/`
+ `s3_data_naming=schema_table`: `{s3_data_dir}/{schema}/{table}/`
3. Otherwise, the default location for a CTAS query is used, which will depend on how your workgroup is configured.
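
As a rough sketch (a hypothetical helper, not the adapter's actual code), the resolution order above could be expressed as:

```python
from uuid import uuid4

def resolve_table_location(creds, schema, table, external_location=None):
    """Sketch of the location precedence; `creds` stands in for the
    connection credentials carrying s3_data_dir / s3_data_naming."""
    # 1. An explicit external_location on the model always wins.
    if external_location is not None:
        return external_location
    # 2. Fall back to the connection-level s3_data_dir, if configured.
    if creds.s3_data_dir is not None:
        if creds.s3_data_naming == "schema_table":
            return f"{creds.s3_data_dir}{schema}/{table}/"
        return f"{creds.s3_data_dir}{uuid4()}/"
    # 3. Neither set: defer to the workgroup's default CTAS location.
    return None
```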

More information: [CREATE TABLE AS][create-table-as]

[run_started_at]: https://docs.getdbt.com/reference/dbt-jinja-functions/run_started_at
2 changes: 1 addition & 1 deletion dbt/adapters/athena/__version__.py
@@ -1 +1 @@
version = "1.0.1"
version = "1.0.1+nvlt3"
2 changes: 2 additions & 0 deletions dbt/adapters/athena/connections.py
@@ -40,6 +40,8 @@ class AthenaCredentials(Credentials):
poll_interval: float = 1.0
_ALIASES = {"catalog": "database"}
num_retries: Optional[int] = 5
s3_data_dir: Optional[str] = None
s3_data_naming: Optional[str] = "uuid"

@property
def type(self) -> str:
158 changes: 143 additions & 15 deletions dbt/adapters/athena/impl.py
@@ -1,15 +1,16 @@
from uuid import uuid4
import agate
import re
import boto3
import boto3.session
from botocore.exceptions import ClientError
from typing import Optional, List

from dbt.adapters.base import available, Column
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.athena import AthenaConnectionManager
from dbt.adapters.athena.relation import AthenaRelation
from dbt.events import AdapterLogger
from dbt.contracts.relation import RelationType
logger = AdapterLogger("Athena")

class AthenaAdapter(SQLAdapter):
@@ -38,30 +39,94 @@ def convert_datetime_type(
return "timestamp"

@available
def s3_table_prefix(self) -> str:
"""
Returns the root location for storing tables in S3.

This is `s3_data_dir` if set, and `{s3_staging_dir}tables/` otherwise.

We generate a value here even if `s3_data_dir` is not set,
since creating a seed table requires a non-default location.
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
if creds.s3_data_dir is not None:
return creds.s3_data_dir
else:
return f"{creds.s3_staging_dir}tables/"

@available
def s3_uuid_table_location(self) -> str:
"""
Returns a random location for storing a table, using a UUID as
the final directory part
"""
return f"{self.s3_table_prefix()}{str(uuid4())}/"


@available
def temp_table_suffix(self, initial="__dbt_tmp", length=8):
return f"{initial}_{str(uuid4())[:length]}"


@available
def s3_schema_table_location(self, schema_name: str, table_name: str) -> str:
"""
Returns a fixed location for storing a table determined by the
(athena) schema and table name
"""
return f"{self.s3_table_prefix()}{schema_name}/{table_name}/"

@available
def s3_table_location(self, schema_name: str, table_name: str) -> str:
"""
Returns either a UUID-based or schema/table-based location for storing
a table, depending on the value of `s3_data_naming`
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
if creds.s3_data_naming == "schema_table":
return self.s3_schema_table_location(schema_name, table_name)
elif creds.s3_data_naming == "uuid":
return self.s3_uuid_table_location()
else:
raise ValueError(f"Unknown value for s3_data_naming: {creds.s3_data_naming}")
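# Illustrative outcomes, assuming s3_data_dir="s3://bucket2/dbt/" and schema "dbt":
#   s3_data_naming="uuid"         -> s3://bucket2/dbt/<random-uuid>/  (fresh path per run)
#   s3_data_naming="schema_table" -> s3://bucket2/dbt/dbt/my_table/   (stable path, data replaced in place)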

@available
def has_s3_data_dir(self) -> bool:
"""
Returns true if the user has specified `s3_data_dir`, and
we should set `external_location`
"""
conn = self.connections.get_thread_connection()
creds = conn.credentials
return creds.s3_data_dir is not None

return f"{client.s3_staging_dir}tables/{str(uuid4())}/"

@available
def clean_up_partitions(
self, database_name: str, table_name: str, where_condition: str
):
# Look up Glue partitions & clean up
conn = self.connections.get_thread_connection()
creds = conn.credentials
session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name)

glue_client = session.client('glue')
s3_resource = session.resource('s3')
paginator = glue_client.get_paginator("get_partitions")
partition_pages = paginator.paginate(
# CatalogId='123456789012', # Need to make this configurable if it is different from default AWS Account ID
DatabaseName=database_name,
TableName=table_name,
Expression=where_condition,
ExcludeColumnSchema=True,
)
partitions = []
for page in partition_pages:
partitions.extend(page["Partitions"])
p = re.compile('s3://([^/]*)/(.*)')
for partition in partitions["Partitions"]:
for partition in partitions:
logger.debug("Deleting objects for partition '{}' at '{}'", partition["Values"], partition["StorageDescriptor"]["Location"])
m = p.match(partition["StorageDescriptor"]["Location"])
if m is not None:
@@ -76,8 +141,10 @@ def clean_up_table(
):
# Look up Glue partitions & clean up
conn = self.connections.get_thread_connection()
creds = conn.credentials
session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name)

glue_client = session.client('glue')
try:
table = glue_client.get_table(
DatabaseName=database_name,
@@ -95,7 +162,7 @@ def clean_up_table(
if m is not None:
bucket_name = m.group(1)
prefix = m.group(2)
s3_resource = session.resource('s3')
s3_bucket = s3_resource.Bucket(bucket_name)
s3_bucket.objects.filter(Prefix=prefix).delete()

@@ -104,3 +171,64 @@ def quote_seed_column(
self, column: str, quote_config: Optional[bool]
) -> str:
return super().quote_seed_column(column, False)

def get_columns_in_relation(self, relation: AthenaRelation) -> List[Column]:
conn = self.connections.get_thread_connection()
creds = conn.credentials
session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name)
glue_client = session.client('glue')

table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)
return [Column(c["Name"], c["Type"]) for c in table["Table"]["StorageDescriptor"]["Columns"] + table["Table"]["PartitionKeys"]]

def list_schemas(self, database: str) -> List[str]:
conn = self.connections.get_thread_connection()
creds = conn.credentials
session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name)
glue_client = session.client('glue')
paginator = glue_client.get_paginator("get_databases")

result = []
logger.debug("CALL glue.get_databases()")
for page in paginator.paginate():
for db in page["DatabaseList"]:
result.append(db["Name"])
return result

def list_relations_without_caching(self, schema_relation: AthenaRelation) -> List[AthenaRelation]:
conn = self.connections.get_thread_connection()
creds = conn.credentials
session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name)
glue_client = session.client('glue')
paginator = glue_client.get_paginator("get_tables")

result = []
logger.debug("CALL glue.get_tables('{}')", schema_relation.schema)
for page in paginator.paginate(DatabaseName=schema_relation.schema):
for table in page["TableList"]:
if table["TableType"] == "EXTERNAL_TABLE":
table_type = RelationType.Table
elif table["TableType"] == "VIRTUAL_VIEW":
table_type = RelationType.View
else:
raise ValueError(f"Unknown TableType for {table['Name']}: {table['TableType']}")
rel = AthenaRelation.create(schema=table["DatabaseName"], identifier=table["Name"], database=schema_relation.database, type=table_type)
result.append(rel)

return result

@available
def delete_table(self, relation: AthenaRelation):
conn = self.connections.get_thread_connection()
creds = conn.credentials
session = boto3.session.Session(region_name=creds.region_name, profile_name=creds.aws_profile_name)
glue_client = session.client('glue')

logger.debug("CALL glue.delete_table({}, {})", relation.schema, relation.identifier)
try:
glue_client.delete_table(DatabaseName=relation.schema, Name=relation.identifier)
except ClientError as e:
if e.response['Error']['Code'] == 'EntityNotFoundException':
logger.debug("Table '{}' does not exists - Ignoring", relation)
else:
raise
21 changes: 1 addition & 20 deletions dbt/include/athena/macros/adapters/columns.sql
@@ -1,22 +1,3 @@
{% macro athena__get_columns_in_relation(relation) -%}
{{ return(adapter.get_columns_in_relation(relation)) }}
{% endmacro %}
36 changes: 2 additions & 34 deletions dbt/include/athena/macros/adapters/metadata.sql
@@ -79,42 +79,10 @@


{% macro athena__list_schemas(database) -%}
{{ return(adapter.list_schemas(database)) }}
{% endmacro %}


{% macro athena__list_relations_without_caching(schema_relation) %}
{{ return(adapter.list_relations_without_caching(schema_relation)) }}
{% endmacro %}
8 changes: 2 additions & 6 deletions dbt/include/athena/macros/adapters/relation.sql
@@ -1,8 +1,4 @@
{% macro athena__drop_relation(relation) -%}
{%- do adapter.clean_up_table(relation.schema, relation.table) -%}
{%- do adapter.delete_table(relation) -%}
{% endmacro %}
@@ -14,7 +14,8 @@
{% set partitioned_by = config.get('partitioned_by', default=none) %}
{% set target_relation = this.incorporate(type='table') %}
{% set existing_relation = load_relation(this) %}
{% set temp_table_suffix = adapter.temp_table_suffix() %}
{% set tmp_relation = make_temp_relation(this, suffix=temp_table_suffix) %}

{{ run_hooks(pre_hooks, inside_transaction=False) }}

@@ -28,7 +29,7 @@
{% do adapter.drop_relation(existing_relation) %}
{% set build_sql = create_table_as(False, target_relation, sql) %}
{% elif partitioned_by is not none and strategy == 'insert_overwrite' %}
{% set tmp_relation = make_temp_relation(target_relation, suffix=temp_table_suffix) %}
{% if tmp_relation is not none %}
{% do adapter.drop_relation(tmp_relation) %}
{% endif %}
@@ -37,7 +38,7 @@
{% set build_sql = incremental_insert(tmp_relation, target_relation) %}
{% do to_drop.append(tmp_relation) %}
{% else %}
{% set tmp_relation = make_temp_relation(target_relation, suffix=temp_table_suffix) %}
{% if tmp_relation is not none %}
{% do adapter.drop_relation(tmp_relation) %}
{% endif %}
@@ -12,6 +12,8 @@
with (
{%- if external_location is not none and not temporary %}
external_location='{{ external_location }}',
{%- elif adapter.has_s3_data_dir() -%}
external_location='{{ adapter.s3_table_location(relation.schema, relation.identifier) }}',
{%- endif %}
{%- if partitioned_by is not none %}
partitioned_by=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }},