diff --git a/dbt/adapters/glue/column.py b/dbt/adapters/glue/column.py new file mode 100644 index 00000000..222daea6 --- /dev/null +++ b/dbt/adapters/glue/column.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import ClassVar, Dict + +from dbt.adapters.base.column import Column + + +@dataclass +class GlueColumn(Column): + # Overwriting dafult string types to support glue + # TODO: Convert to supported glue types as needed + # Please ref: https://github.com/dbt-athena/dbt-athena/blob/main/dbt/adapters/athena/column.py + TYPE_LABELS: ClassVar[Dict[str, str]] = { + "STRING": "STRING", + "TEXT": "STRING", + "VARCHAR": "STRING" + } diff --git a/dbt/adapters/glue/connections.py b/dbt/adapters/glue/connections.py index 1399a509..57867ead 100644 --- a/dbt/adapters/glue/connections.py +++ b/dbt/adapters/glue/connections.py @@ -29,6 +29,15 @@ class GlueConnectionManager(SQLConnectionManager): TYPE = "glue" GLUE_CONNECTIONS_BY_KEY: Dict[str, GlueConnection] = {} + @classmethod + def data_type_code_to_name(cls, type_code: str) -> str: + """ + Get the string representation of the data type from the metadata. Dbt performs a + query to retrieve the types of the columns in the SQL query. Then these types are compared + to the types in the contract config, simplified because they need to match what is returned + by metadata (we are only interested in the broader type, without subtypes nor granularity). + """ + return type_code.split("(")[0].split("<")[0].upper() @classmethod def open(cls, connection): @@ -103,7 +112,7 @@ def get_result_from_cursor(cls, cursor: GlueCursor, limit: Optional[int]) -> aga data: List[Any] = [] column_names: List[str] = [] if cursor.description is not None: - column_names = [col[0] for col in cursor.description()] + column_names = [col[0] for col in cursor.description] if limit: rows = cursor.fetchmany(limit) else: diff --git a/dbt/adapters/glue/gluedbapi/cursor.py b/dbt/adapters/glue/gluedbapi/cursor.py index aff0c8bd..942cbb98 100644 --- a/dbt/adapters/glue/gluedbapi/cursor.py +++ b/dbt/adapters/glue/gluedbapi/cursor.py @@ -206,6 +206,7 @@ def __next__(self): raise StopIteration return item + @property def description(self): logger.debug("GlueCursor description called") if self.response: diff --git a/dbt/adapters/glue/impl.py b/dbt/adapters/glue/impl.py index f7fd6189..769a6ed0 100644 --- a/dbt/adapters/glue/impl.py +++ b/dbt/adapters/glue/impl.py @@ -11,9 +11,9 @@ from dbt.adapters.base import available from dbt.adapters.base.relation import BaseRelation -from dbt.adapters.base.column import Column from dbt.adapters.sql import SQLAdapter from dbt.adapters.glue import GlueConnectionManager +from dbt.adapters.glue.column import GlueColumn from dbt.adapters.glue.gluedbapi import GlueConnection from dbt.adapters.glue.relation import SparkRelation from dbt.adapters.glue.lakeformation import ( @@ -33,6 +33,7 @@ class GlueAdapter(SQLAdapter): ConnectionManager = GlueConnectionManager Relation = SparkRelation + Column = GlueColumn relation_type_map = {'EXTERNAL_TABLE': 'table', 'MANAGED_TABLE': 'table', @@ -234,7 +235,7 @@ def get_columns_in_relation(self, relation: BaseRelation): records = self.fetch_all_response(response) for record in records: - column = Column(column=record[0], dtype=record[1]) + column = self.Column(column=record[0], dtype=record[1]) if record[0][:1] != "#": if column not in columns: columns.append(column) diff --git a/dbt/include/glue/macros/adapters.sql b/dbt/include/glue/macros/adapters.sql index 99e20a48..eecf822e 100644 --- a/dbt/include/glue/macros/adapters.sql +++ b/dbt/include/glue/macros/adapters.sql @@ -66,6 +66,13 @@ {{ create_temporary_view(relation, sql) }} {%- else -%} create table {{ relation }} + {% set contract_config = config.get('contract') %} + {% if contract_config.enforced %} + {{ get_assert_columns_equivalent(sql) }} + {#-- This does not enforce contstraints and needs to be a TODO #} + {#-- We'll need to change up the query because with CREATE TABLE AS SELECT, #} + {#-- you do not specify the columns #} + {% endif %} {{ glue__file_format_clause() }} {{ partition_cols(label="partitioned by") }} {{ clustered_cols(label="clustered by") }} @@ -124,6 +131,10 @@ {% endmacro %} {% macro glue__create_view_as(relation, sql) -%} + {%- set contract_config = config.get('contract') -%} + {%- if contract_config.enforced -%} + {{ get_assert_columns_equivalent(sql) }} + {%- endif -%} DROP VIEW IF EXISTS {{ relation }} dbt_next_query create view {{ relation }}