From 89d18308a72bdb636d5e9b33ff2c4680fe6c8605 Mon Sep 17 00:00:00 2001
From: Gordon Ball <gordon.ball@northvolt.com>
Date: Fri, 17 Dec 2021 12:16:58 +0100
Subject: [PATCH] Support s3_data_dir and s3_data_naming

---
 README.md                                     | 37 +++++++-----
 dbt/adapters/athena/connections.py            |  2 +
 dbt/adapters/athena/impl.py                   | 59 ++++++++++++++++++-
 .../models/table/create_table_as.sql          |  2 +
 .../macros/materializations/seeds/helpers.sql |  2 +-
 5 files changed, 84 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 42f32459..b043dd01 100644
--- a/README.md
+++ b/README.md
@@ -41,16 +41,19 @@ stored login info. You can configure the AWS profile name to use via `aws_profil
 
 A dbt profile can be configured to run against AWS Athena using the following configuration:
 
-| Option          | Description                                                                     | Required?  | Example             |
-|---------------- |-------------------------------------------------------------------------------- |----------- |-------------------- |
-| s3_staging_dir  | S3 location to store Athena query results and metadata                          | Required   | `s3://bucket/dbt/`  |
-| region_name     | AWS region of your Athena instance                                              | Required   | `eu-west-1`         |
-| schema          | Specify the schema (Athena database) to build models into (lowercase **only**)  | Required   | `dbt`               |
-| database        | Specify the database (Data catalog) to build models into (lowercase **only**)   | Required   | `awsdatacatalog`    |
-| poll_interval   | Interval in seconds to use for polling the status of query results in Athena    | Optional   | `5`                 |
-| aws_profile_name| Profile to use from your AWS shared credentials file.                           | Optional   | `my-profile`        |
-| work_group| Identifier of Athena workgroup   | Optional   | `my-custom-workgroup`        |
-| num_retries| Number of times to retry a failing query | Optional  | `3`  | `5`
+| Option          | Description                                                                     | Required?  | Example               |
+|---------------- |-------------------------------------------------------------------------------- |----------- |---------------------- |
+| s3_staging_dir  | S3 location to store Athena query results and metadata                          | Required   | `s3://bucket/dbt/`    |
+| region_name     | AWS region of your Athena instance                                              | Required   | `eu-west-1`           |
+| schema          | Specify the schema (Athena database) to build models into (lowercase **only**)  | Required   | `dbt`                 |
+| database        | Specify the database (Data catalog) to build models into (lowercase **only**)   | Required   | `awsdatacatalog`      |
+| poll_interval   | Interval in seconds to use for polling the status of query results in Athena    | Optional   | `5`                   |
+| aws_profile_name| Profile to use from your AWS shared credentials file.                           | Optional   | `my-profile`          |
+| work_group      | Identifier of Athena workgroup                                                  | Optional   | `my-custom-workgroup` |
+| num_retries     | Number of times to retry a failing query                                        | Optional   | `3`                   |
+| s3_data_dir     | Prefix for storing tables, if different from the connection's `s3_staging_dir`  | Optional   | `s3://bucket2/dbt/`   |
+| s3_data_naming  | How to generate table paths in `s3_data_dir`: `uuid` or `schema_table`         | Optional   | `uuid`                |
+
 
 **Example profiles.yml entry:**
 ```yaml
@@ -78,9 +81,7 @@ _Additional information_
 #### Table Configuration
 
 * `external_location` (`default=none`)
-  * The location where Athena saves your table in Amazon S3
-  * If `none` then it will default to `{s3_staging_dir}/tables`
-  * If you are using a static value, when your table/partition is recreated underlying data will be cleaned up and overwritten by new data
+  * If set, the full S3 path in which the table will be saved.
 * `partitioned_by` (`default=none`)
   * An array list of columns by which the table will be partitioned
   * Limited to creation of 100 partitions (_currently_)
@@ -93,7 +94,15 @@ _Additional information_
   * Supports `ORC`, `PARQUET`, `AVRO`, `JSON`, or `TEXTFILE`
 * `field_delimiter` (`default=none`)
   * Custom field delimiter, for when format is set to `TEXTFILE`
-  
+
+The location in which a table is saved is determined by:
+
+1. If `external_location` is defined, that value is used.
+2. If `s3_data_dir` is defined, the path is determined by that and `s3_data_naming`:
+   + `s3_data_naming=uuid`: `{s3_data_dir}/{uuid4()}/`
+   + `s3_data_naming=schema_table`: `{s3_data_dir}/{schema}/{table}/`
+3. Otherwise, the default location for a CTAS query is used, which will depend on how your workgroup is configured.
+
 More information: [CREATE TABLE AS][create-table-as]
 
 [run_started_at]: https://docs.getdbt.com/reference/dbt-jinja-functions/run_started_at
diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py
index cc30862c..49811212 100644
--- a/dbt/adapters/athena/connections.py
+++ b/dbt/adapters/athena/connections.py
@@ -40,6 +40,8 @@ class AthenaCredentials(Credentials):
     poll_interval: float = 1.0
     _ALIASES = {"catalog": "database"}
     num_retries: Optional[int] = 5
+    s3_data_dir: Optional[str] = None
+    s3_data_naming: Optional[str] = "uuid"
 
     @property
     def type(self) -> str:
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index f815d70a..9530c49e 100644
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -38,11 +38,64 @@ def convert_datetime_type(
         return "timestamp"
 
     @available
-    def s3_uuid_table_location(self):
+    def s3_table_prefix(self) -> str:
+        """
+        Returns the root location for storing tables in S3.
+
+        This is `s3_data_dir`, if set, and `s3_staging_dir/tables/` if not.
+
+        We generate a value here even if `s3_data_dir` is not set,
+        since creating a seed table requires a non-default location.
+        """
         conn = self.connections.get_thread_connection()
-        client = conn.handle
+        creds = conn.credentials
+        if creds.s3_data_dir is not None:
+            return creds.s3_data_dir
+        else:
+            return f"{creds.s3_staging_dir}tables/"
+
+    @available
+    def s3_uuid_table_location(self) -> str:
+        """
+        Returns a random location for storing a table, using a UUID as
+        the final directory part
+        """
+        return f"{self.s3_table_prefix()}{str(uuid4())}/"
+
+
+    @available
+    def s3_schema_table_location(self, schema_name: str, table_name: str) -> str:
+        """
+        Returns a fixed location for storing a table determined by the
+        (athena) schema and table name
+        """
+        return f"{self.s3_table_prefix()}{schema_name}/{table_name}/"
+
+    @available
+    def s3_table_location(self, schema_name: str, table_name: str) -> str:
+        """
+        Returns either a UUID or schema/table prefix for storing a table,
+        depending on the value of s3_data_naming
+        """
+        conn = self.connections.get_thread_connection()
+        creds = conn.credentials
+        if creds.s3_data_naming == "schema_table":
+            return self.s3_schema_table_location(schema_name, table_name)
+        elif creds.s3_data_naming == "uuid":
+            return self.s3_uuid_table_location()
+        else:
+            raise ValueError(f"Unknown value for s3_data_naming: {creds.s3_data_naming}")
+
+    @available
+    def has_s3_data_dir(self) -> bool:
+        """
+        Returns true if the user has specified `s3_data_dir`, and
+        we should therefore set `external_location` explicitly
+        """
+        conn = self.connections.get_thread_connection()
+        creds = conn.credentials
+        return creds.s3_data_dir is not None
 
-        return f"{client.s3_staging_dir}tables/{str(uuid4())}/"
 
     @available
     def clean_up_partitions(
diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql
index 504ba148..af398c9a 100644
--- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql
+++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql
@@ -12,6 +12,8 @@
     with (
       {%- if external_location is not none and not temporary %}
         external_location='{{ external_location }}',
+      {%- elif adapter.has_s3_data_dir() -%}
+        external_location='{{ adapter.s3_table_location(relation.schema, relation.identifier) }}',
       {%- endif %}
       {%- if partitioned_by is not none %}
         partitioned_by=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }},
diff --git a/dbt/include/athena/macros/materializations/seeds/helpers.sql b/dbt/include/athena/macros/materializations/seeds/helpers.sql
index bbc0e0e0..057d5d0d 100644
--- a/dbt/include/athena/macros/materializations/seeds/helpers.sql
+++ b/dbt/include/athena/macros/materializations/seeds/helpers.sql
@@ -21,7 +21,7 @@
         {%- endfor -%}
     )
     stored as parquet
-    location '{{ adapter.s3_uuid_table_location() }}'
+    location '{{ adapter.s3_table_location(model["schema"], model["alias"]) }}'
     tblproperties ('classification'='parquet')
   {% endset %}