Commit

allow cast on seeds
Jaume Sanjuan committed Oct 4, 2024
1 parent abf8c52 commit 17494cb
Showing 3 changed files with 147 additions and 6 deletions.
5 changes: 3 additions & 2 deletions dbt/adapters/glue/credentials.py
@@ -35,7 +35,7 @@ class GlueCredentials(Credentials):
datalake_formats: Optional[str] = None
enable_session_per_model: Optional[bool] = False
use_arrow: Optional[bool] = False

enable_spark_seed_casting: Optional[bool] = False

@property
def type(self):
@@ -93,5 +93,6 @@ def _connection_keys(self):
'glue_session_reuse',
'datalake_formats',
'enable_session_per_model',
'use_arrow'
'use_arrow',
'enable_spark_seed_casting',
]
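
For anyone enabling the feature, here is a minimal sketch of a Glue target with the new flag turned on, written as the equivalent Python mapping. Only enable_spark_seed_casting comes from this commit; every other key and value is an illustrative assumption rather than something this diff defines.

# Hypothetical profile target for the Glue adapter, shown as a Python dict.
# Only `enable_spark_seed_casting` is introduced by this commit; the other
# keys and values are placeholders and may not match a real project.
glue_target = {
    "type": "glue",                      # assumed adapter type name
    "role_arn": "arn:aws:iam::123456789012:role/dbt-glue-role",  # assumed
    "region": "eu-west-1",               # assumed
    "schema": "analytics",               # assumed
    "enable_spark_seed_casting": True,   # new optional flag, defaults to False
}
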
77 changes: 74 additions & 3 deletions dbt/adapters/glue/impl.py
@@ -27,11 +27,52 @@
from dbt_common.exceptions import DbtDatabaseError, CompilationError
from dbt.adapters.base.impl import catch_as_completed
from dbt_common.utils import executor
from dbt_common.clients import agate_helper
from dbt.adapters.events.logging import AdapterLogger

logger = AdapterLogger("Glue")


class ColumnCsvMappingStrategy:
_schema_mappings = {
agate_helper.ISODateTime: 'string',
agate_helper.Number: 'double',
agate_helper.Integer: 'int',
agate.data_types.Boolean: 'boolean',
agate.data_types.Date: 'string',
agate.data_types.DateTime: 'string',
agate.data_types.Text: 'string',
}

_cast_mappings = {
agate_helper.ISODateTime: 'timestamp',
agate.data_types.Date: 'date',
agate.data_types.DateTime: 'timestamp',
}

def __init__(self, column_name, agate_type, specified_type):
self.column_name = column_name
self.agate_type = agate_type
self.specified_type = specified_type

def as_schema_value(self):
return ColumnCsvMappingStrategy._schema_mappings.get(self.agate_type)

def as_cast_value(self):
return (
self.specified_type if self.specified_type else ColumnCsvMappingStrategy._cast_mappings.get(self.agate_type)
)

@classmethod
def from_model(cls, model, agate_table):
return [
ColumnCsvMappingStrategy(
column.name, type(column.data_type), model.get("config", {}).get("column_types", {}).get(column.name)
)
for column in agate_table.columns
]


class GlueAdapter(SQLAdapter):
ConnectionManager = GlueConnectionManager
Relation = SparkRelation
@@ -535,7 +576,7 @@ def create_csv_table(self, model, agate_table):
mode = "False"

csv_chunks = self._split_csv_records_into_chunks(json.loads(f.getvalue()))
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode)
statements = self._map_csv_chunks_to_code(csv_chunks, session, model, mode, ColumnCsvMappingStrategy.from_model(model, agate_table))
try:
cursor = session.cursor()
for statement in statements:
@@ -545,7 +586,14 @@ def create_csv_table(self, model, agate_table):
except Exception as e:
logger.error(e)

def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
def _map_csv_chunks_to_code(
self,
csv_chunks: List[List[dict]],
session: GlueConnection,
model,
mode,
column_mappings: List[ColumnCsvMappingStrategy],
):
statements = []
for i, csv_chunk in enumerate(csv_chunks):
is_first = i == 0
@@ -564,8 +612,31 @@ def _map_csv_chunks_to_code(self, csv_chunks: List[List[dict]], session: GlueConnection, model, mode):
SqlWrapper2.execute("""select 1""")
'''
else:
code += f'''
if session.credentials.enable_spark_seed_casting:
csv_schema = ", ".join(
[f"{mapping.column_name}: {mapping.as_schema_value()}" for mapping in column_mappings]
)

cast_code = ".".join(
[
"df",
*[
f'withColumn("{mapping.column_name}", df.{mapping.column_name}.cast("{cast_value}"))'
for mapping in column_mappings
if (cast_value := mapping.as_cast_value())
],
]
)

code += f"""
df = spark.createDataFrame(csv, "{csv_schema}")
df = {cast_code}
"""
else:
code += """
df = spark.createDataFrame(csv)
"""
code += f'''
table_name = '{model["schema"]}.{model["name"]}'
if (spark.sql("show tables in {model["schema"]}").where("tableName == lower('{model["name"]}')").count() > 0):
df.write\
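
To make the generated statement concrete, here is a minimal sketch of what the casting branch produces for a hypothetical two-column seed chunk. A local SparkSession stands in for the Glue interactive session, the column names and values are invented, and the follow-up write into the target table is omitted.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical seed chunk: an Integer column and a Date column.
csv = [{"id": 1, "created_at": "2024-01-01"}]

# Step 1: createDataFrame with the schema string built from _schema_mappings,
# so every value is first loaded as a lossless int/double/string.
df = spark.createDataFrame(csv, "id: int, created_at: string")

# Step 2: apply the cast chain built from _cast_mappings (or from the seed's
# column_types config), e.g. Date -> date, DateTime/ISODateTime -> timestamp.
df = df.withColumn("created_at", df.created_at.cast("date"))

df.printSchema()
# root
#  |-- id: integer (nullable = true)
#  |-- created_at: date (nullable = true)
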
71 changes: 70 additions & 1 deletion tests/unit/test_adapter.py
@@ -2,7 +2,9 @@
import unittest
from unittest import mock
from unittest.mock import Mock
import pytest
from multiprocessing import get_context
import agate.data_types
from botocore.client import BaseClient
from moto import mock_aws

@@ -13,6 +15,8 @@
from dbt.adapters.glue import GlueAdapter
from dbt.adapters.glue.gluedbapi import GlueConnection
from dbt.adapters.glue.relation import SparkRelation
from dbt.adapters.glue.impl import ColumnCsvMappingStrategy
from dbt_common.clients import agate_helper
from tests.util import config_from_parts_or_dicts
from .util import MockAWSService

@@ -99,4 +103,69 @@ def test_create_csv_table_slices_big_datasets(self):
adapter.create_csv_table(model, test_table)

# test table is between 120000 and 180000 characters so it should be split three times (max chunk is 60000)
self.assertEqual(session_mock.cursor().execute.call_count, 3)

def test_create_csv_table_provides_schema_and_casts_when_spark_seed_cast_is_enabled(self):
config = self._get_config()
config.credentials.enable_spark_seed_casting = True
adapter = GlueAdapter(config, get_context("spawn"))
csv_chunks = [{'test_column': '1.2345'}]
model = {"name": "mock_model", "schema": "mock_schema", "config": {"column_types": {"test_column": "double"}}}
column_mappings = [ColumnCsvMappingStrategy('test_column', agate.data_types.Text, 'double')]
code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, 'True', column_mappings)
self.assertIn('spark.createDataFrame(csv, "test_column: string")', code[0])
self.assertIn('df = df.withColumn("test_column", df.test_column.cast("double"))', code[0])

def test_create_csv_table_doesnt_provide_schema_when_spark_seed_cast_is_disabled(self):
config = self._get_config()
config.credentials.enable_spark_seed_casting = False
adapter = GlueAdapter(config, get_context("spawn"))
csv_chunks = [{'test_column': '1.2345'}]
model = {"name": "mock_model", "schema": "mock_schema"}
column_mappings = [ColumnCsvMappingStrategy('test_column', agate.data_types.Text, 'double')]
code = adapter._map_csv_chunks_to_code(csv_chunks, config, model, 'True', column_mappings)
self.assertIn('spark.createDataFrame(csv)', code[0])

class TestCsvMappingStrategy:
@pytest.mark.parametrize(
'agate_type,specified_type,expected_schema_type,expected_cast_type',
[
(agate_helper.ISODateTime, None, 'string', 'timestamp'),
(agate_helper.Number, None, 'double', None),
(agate_helper.Integer, None, 'int', None),
(agate.data_types.Boolean, None, 'boolean', None),
(agate.data_types.Date, None, 'string', 'date'),
(agate.data_types.DateTime, None, 'string', 'timestamp'),
(agate.data_types.Text, None, 'string', None),
(agate.data_types.Text, 'double', 'string', 'double'),
],
ids=[
'test isodatetime cast',
'test number cast',
'test integer cast',
'test boolean cast',
'test date cast',
'test datetime cast',
'test text cast',
'test specified cast',
]
)
def test_mapping_strategy_provides_proper_mappings(self, agate_type, specified_type, expected_schema_type, expected_cast_type):
column_mapping = ColumnCsvMappingStrategy('test_column', agate_type, specified_type)
assert column_mapping.as_schema_value() == expected_schema_type
assert column_mapping.as_cast_value() == expected_cast_type

def test_from_model_builds_column_mappings(self):
expected_column_names = ['col_int', 'col_str', 'col_date', 'col_specific']
expected_agate_types = [agate_helper.Integer, agate.data_types.Text, agate.data_types.Date, agate.data_types.Text]
expected_specified_types = [None, None, None, 'double']
agate_table = agate.Table(
[(111, 'str_val', '2024-01-01', '1.234')],
column_names=expected_column_names,
column_types=[data_type() for data_type in expected_agate_types]
)
model = {"name": "mock_model", "config": {"column_types": {"col_specific": "double"}}}
mappings = ColumnCsvMappingStrategy.from_model(model, agate_table)
assert expected_column_names == [mapping.column_name for mapping in mappings]
assert expected_agate_types == [mapping.agate_type for mapping in mappings]
assert expected_specified_types == [mapping.specified_type for mapping in mappings]
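
As a standalone illustration of the mapping rules the parametrized test above pins down, under invented column names and values:

import agate
from dbt_common.clients import agate_helper
from dbt.adapters.glue.impl import ColumnCsvMappingStrategy

# Hypothetical seed: an integer column, a date column, and a text column that
# the seed's column_types config overrides to double.
table = agate.Table(
    [(1, "2024-01-01", "1.5")],
    column_names=["id", "created_at", "amount"],
    column_types=[agate_helper.Integer(), agate.data_types.Date(), agate.data_types.Text()],
)
model = {"name": "my_seed", "config": {"column_types": {"amount": "double"}}}

for mapping in ColumnCsvMappingStrategy.from_model(model, table):
    print(mapping.column_name, mapping.as_schema_value(), mapping.as_cast_value())
# id int None            -> loaded as int, never cast
# created_at string date -> loaded as string, then cast to date
# amount string double   -> loaded as string, then cast per column_types
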
