Support partitioning hints for athena iceberg
steinitzu committed May 23, 2024
1 parent d0fdfb4 commit ef25591
Showing 7 changed files with 301 additions and 9 deletions.
2 changes: 2 additions & 0 deletions dlt/destinations/adapters.py
@@ -5,11 +5,13 @@
 from dlt.destinations.impl.bigquery import bigquery_adapter
 from dlt.destinations.impl.synapse import synapse_adapter
 from dlt.destinations.impl.clickhouse import clickhouse_adapter
+from dlt.destinations.impl.athena import athena_adapter
 
 __all__ = [
     "weaviate_adapter",
     "qdrant_adapter",
     "bigquery_adapter",
     "synapse_adapter",
     "clickhouse_adapter",
+    "athena_adapter",
 ]
30 changes: 21 additions & 9 deletions dlt/destinations/impl/athena/athena.py
Expand Up @@ -69,6 +69,7 @@
from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration
from dlt.destinations.type_mapping import TypeMapper
from dlt.destinations import path_utils
from dlt.destinations.impl.athena.athena_adapter import PARTITION_HINT


class AthenaTypeMapper(TypeMapper):
@@ -401,9 +402,12 @@ def _from_db_type(
         return self.type_mapper.from_db_type(hive_t, precision, scale)
 
     def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
-        return (
-            f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
-        )
+        return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
+
+    def _iceberg_partition_clause(self, partition_hints: Optional[List[str]]) -> str:
+        if not partition_hints:
+            return ""
+        return f"PARTITIONED BY ({', '.join(partition_hints)})"
 
     def _get_table_update_sql(
         self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
@@ -431,20 +435,28 @@ def _get_table_update_sql(
             sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""")
         else:
             if is_iceberg:
-                sql.append(f"""CREATE TABLE {qualified_table_name}
+                partition_clause = self._iceberg_partition_clause(table.get(PARTITION_HINT))
+                sql.append(
+                    f"""CREATE TABLE {qualified_table_name}
                 ({columns})
+                {partition_clause}
                 LOCATION '{location.rstrip('/')}'
-                TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""")
+                TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');"""
+                )
             elif table_format == "jsonl":
-                sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name}
+                sql.append(
+                    f"""CREATE EXTERNAL TABLE {qualified_table_name}
                 ({columns})
                 ROW FORMAT SERDE 'org.openx.data.jsonserde.JsonSerDe'
-                LOCATION '{location}';""")
+                LOCATION '{location}';"""
+                )
             else:
-                sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name}
+                sql.append(
+                    f"""CREATE EXTERNAL TABLE {qualified_table_name}
                 ({columns})
                 STORED AS PARQUET
-                LOCATION '{location}';""")
+                LOCATION '{location}';"""
+                )
         return sql
 
     def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
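
For illustration, a minimal sketch of the clause and DDL shape the code above generates, assuming partition hints `["category", "month(created_at)"]` and placeholder table name, columns, and location:

```py
# Sketch: the clause built by _iceberg_partition_clause for the assumed hints
partition_hints = ["category", "month(created_at)"]
partition_clause = f"PARTITIONED BY ({', '.join(partition_hints)})"
assert partition_clause == "PARTITIONED BY (category, month(created_at))"

# The CREATE TABLE statement then takes this shape (identifiers assumed):
# CREATE TABLE "my_db"."partitioned_table"
# (id bigint, category string, created_at date)
# PARTITIONED BY (category, month(created_at))
# LOCATION 's3://my-bucket/my_table'
# TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');
```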
90 changes: 90 additions & 0 deletions dlt/destinations/impl/athena/athena_adapter.py
@@ -0,0 +1,90 @@
from typing import Any, Optional, Dict, Protocol, Sequence, Union, Final

from dateutil import parser

from dlt.common.pendulum import timezone
from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TColumnSchema
from dlt.destinations.utils import ensure_resource
from dlt.extract import DltResource
from dlt.extract.items import TTableHintTemplate


PARTITION_HINT: Final[str] = "x-athena-partition"
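# Table hint key under which athena_adapter stores the partition transform strings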


class athena_partition:
    """Helper class to generate iceberg partition transform strings.
    E.g. `athena_partition.bucket(16, "id")` will return `bucket(16, id)`.
    """

    @staticmethod
    def year(column_name: str) -> str:
        """Partition by year part of a date or timestamp column."""
        return f"year({column_name})"

    @staticmethod
    def month(column_name: str) -> str:
        """Partition by month part of a date or timestamp column."""
        return f"month({column_name})"

    @staticmethod
    def day(column_name: str) -> str:
        """Partition by day part of a date or timestamp column."""
        return f"day({column_name})"

    @staticmethod
    def hour(column_name: str) -> str:
        """Partition by hour part of a date or timestamp column."""
        return f"hour({column_name})"

    @staticmethod
    def bucket(n: int, column_name: str) -> str:
        """Partition by hashed value to n buckets."""
        return f"bucket({n}, {column_name})"

    @staticmethod
    def truncate(length: int, column_name: str) -> str:
        """Partition by value truncated to length."""
        return f"truncate({length}, {column_name})"


def athena_adapter(
    data: Any,
    partition: Optional[Union[str, Sequence[str]]] = None,
) -> DltResource:
    """
    Prepares data for loading into Athena.

    Args:
        data: The data to be transformed.
            This can be raw data or an instance of DltResource.
            If raw data is provided, the function will wrap it into a `DltResource` object.
        partition: Column name(s) or partition transform string(s) to partition the table by.

    Returns:
        A `DltResource` object that is ready to be loaded into Athena.

    Raises:
        ValueError: If any hint is invalid or none are specified.

    Examples:
        >>> data = [{"name": "Marcel", "department": "Engineering", "date_hired": "2024-01-30"}]
        >>> athena_adapter(data, partition=["department", athena_partition.year("date_hired"), athena_partition.bucket(8, "name")])
        [DltResource with hints applied]
    """
    resource = ensure_resource(data)
    additional_table_hints: Dict[str, TTableHintTemplate[Any]] = {}

    if partition:
        if isinstance(partition, str):
            partition = [partition]

        # Note: PARTITIONED BY clause identifiers are not allowed to be quoted. They are added as-is.
        additional_table_hints[PARTITION_HINT] = list(partition)

    if additional_table_hints:
        resource.apply_hints(additional_table_hints=additional_table_hints)
    else:
        raise ValueError("A value for `partition` must be specified.")
    return resource
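
For illustration, the transform strings `athena_partition` produces; note that column identifiers pass through unquoted:

```py
# Sketch: transform strings generated by the athena_partition helpers
assert athena_partition.month("created_at") == "month(created_at)"
assert athena_partition.bucket(16, "id") == "bucket(16, id)"
assert athena_partition.truncate(2, "name") == "truncate(2, name)"
```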
49 changes: 49 additions & 0 deletions docs/website/docs/dlt-ecosystem/destinations/athena.md
@@ -161,5 +161,54 @@ aws_data_catalog="awsdatacatalog"
You can choose the following file formats:
* [parquet](../file-formats/parquet.md) is used by default


## Athena adapter

You can use the `athena_adapter` to add partitioning to Athena tables. This is currently only supported for Iceberg tables.

Here is how you use it:

```py
from datetime import date

import dlt
from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter

data_items = [
    (1, "A", date(2021, 1, 1)),
    (2, "A", date(2021, 1, 2)),
    (3, "A", date(2021, 1, 3)),
    (4, "A", date(2021, 2, 1)),
    (5, "A", date(2021, 2, 2)),
    (6, "B", date(2021, 1, 1)),
    (7, "B", date(2021, 1, 2)),
    (8, "B", date(2021, 1, 3)),
    (9, "B", date(2021, 2, 1)),
    (10, "B", date(2021, 3, 2)),
]

@dlt.resource(table_format="iceberg")
def partitioned_data():
    yield [{"id": i, "category": c, "created_at": d} for i, c, d in data_items]


# Add partitioning hints to the table
athena_adapter(
    partitioned_data,
    partition=[
        # Partition per category and month
        "category",
        athena_partition.month("created_at"),
    ],
)


pipeline = dlt.pipeline("athena_example", destination="athena")
pipeline.run(partitioned_data)
```

See the AWS Athena documentation for more information on [available partitioning options](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html#querying-iceberg-creating-tables-query-editor).
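
After the pipeline runs, you can check the resulting partitions through the Iceberg `$partitions` metadata table. A minimal sketch, assuming the pipeline above has run:

```py
# Inspect partition metadata for the table created above
with pipeline.sql_client() as client:
    partitions_table = client.make_qualified_table_name("partitioned_data$partitions")
    rows = client.execute_sql(f"SELECT partition FROM {partitions_table}")
    print(rows)
```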


<!--@@@DLT_TUBA athena-->

66 changes: 66 additions & 0 deletions tests/load/athena_iceberg/test_athena_adapter.py
@@ -0,0 +1,66 @@
import pytest

import dlt
from dlt.destinations import filesystem
from dlt.destinations.impl.athena.athena_adapter import athena_adapter, athena_partition
from tests.load.utils import destinations_configs, DestinationTestConfiguration


def test_iceberg_partition_hints():
    """Create a table with athena partition hints and check that the SQL is generated correctly."""

    @dlt.resource(table_format="iceberg")
    def partitioned_table():
        yield {
            "product_id": 1,
            "name": "product 1",
            "created_at": "2021-01-01T00:00:00Z",
            "category": "category 1",
            "price": 100.0,
            "quantity": 10,
        }

    @dlt.resource(table_format="iceberg")
    def not_partitioned_table():
        yield {"a": 1, "b": 2}

    athena_adapter(
        partitioned_table,
        partition=[
            "category",
            athena_partition.month("created_at"),
            athena_partition.bucket(10, "product_id"),
            athena_partition.truncate(2, "name"),
        ],
    )

    pipeline = dlt.pipeline(
        "athena_test",
        destination="athena",
        staging=filesystem("s3://not-a-real-bucket"),
        full_refresh=True,
    )

    pipeline.extract([partitioned_table, not_partitioned_table])
    pipeline.normalize()

    with pipeline._sql_job_client(pipeline.default_schema) as client:
        sql_partitioned = client._get_table_update_sql(
            "partitioned_table",
            list(pipeline.default_schema.tables["partitioned_table"]["columns"].values()),
            False,
        )[0]
        sql_not_partitioned = client._get_table_update_sql(
            "not_partitioned_table",
            list(pipeline.default_schema.tables["not_partitioned_table"]["columns"].values()),
            False,
        )[0]

    # Partition clause is generated with original order
    expected_clause = (
        "PARTITIONED BY (category, month(created_at), bucket(10, product_id), truncate(2, name))"
    )
    assert expected_clause in sql_partitioned

    # No partition clause otherwise
    assert "PARTITIONED BY" not in sql_not_partitioned
67 changes: 67 additions & 0 deletions tests/load/pipeline/test_athena.py
@@ -9,6 +9,8 @@
 from tests.pipeline.utils import assert_load_info, load_table_counts
 from tests.pipeline.utils import load_table_counts
 from dlt.destinations.exceptions import CantExtractTablePrefix
+from dlt.destinations.impl.athena.athena_adapter import athena_partition, athena_adapter
+from dlt.destinations.fs_client import FSClientBase
 
 from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration
 from tests.load.utils import (
@@ -231,3 +233,68 @@ def test_athena_file_layouts(destination_config: DestinationTestConfiguration, l
        pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()]
    )
    assert table_counts == {"items1": 3, "items2": 7}


@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True, subset=["athena"], force_iceberg=True),
    ids=lambda x: x.name,
)
def test_athena_partitioned_iceberg_table(destination_config: DestinationTestConfiguration):
    """Load an iceberg table with partition hints and verify partitions are created correctly."""
    pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True)

    data_items = [
        (1, "A", datetime.date.fromisoformat("2021-01-01")),
        (2, "A", datetime.date.fromisoformat("2021-01-02")),
        (3, "A", datetime.date.fromisoformat("2021-01-03")),
        (4, "A", datetime.date.fromisoformat("2021-02-01")),
        (5, "A", datetime.date.fromisoformat("2021-02-02")),
        (6, "B", datetime.date.fromisoformat("2021-01-01")),
        (7, "B", datetime.date.fromisoformat("2021-01-02")),
        (8, "B", datetime.date.fromisoformat("2021-01-03")),
        (9, "B", datetime.date.fromisoformat("2021-02-01")),
        (10, "B", datetime.date.fromisoformat("2021-03-02")),
    ]

    @dlt.resource(table_format="iceberg")
    def partitioned_table():
        yield [{"id": i, "category": c, "created_at": d} for i, c, d in data_items]

    athena_adapter(
        partitioned_table,
        partition=[
            "category",
            athena_partition.month("created_at"),
        ],
    )

    info = pipeline.run(partitioned_table)
    assert_load_info(info)

    # Get partitions from metadata
    with pipeline.sql_client() as sql_client:
        tbl_name = sql_client.make_qualified_table_name("partitioned_table$partitions")
        rows = sql_client.execute_sql(f"SELECT partition FROM {tbl_name}")
        partition_keys = {r[0] for r in rows}

        data_rows = sql_client.execute_sql(
            f"SELECT id, category, created_at FROM {sql_client.make_qualified_table_name('partitioned_table')}"
        )

    # All data is in table
    assert len(data_rows) == len(data_items)
    assert set(data_rows) == set(data_items)

    # Compare with expected partitions
    # Months are number of months since epoch
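    # e.g. 2021-01: (2021 - 1970) * 12 + (1 - 1) = 612; 2021-02: 613; 2021-03: 614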
    expected_partitions = {
        "{category=A, created_at_month=612}",
        "{category=A, created_at_month=613}",
        "{category=B, created_at_month=612}",
        "{category=B, created_at_month=613}",
        "{category=B, created_at_month=614}",
    }

    assert partition_keys == expected_partitions
6 changes: 6 additions & 0 deletions tests/load/utils.py
@@ -180,6 +180,7 @@ def destinations_configs(
     file_format: Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]] = None,
     supports_merge: Optional[bool] = None,
     supports_dbt: Optional[bool] = None,
+    force_iceberg: Optional[bool] = None,
 ) -> List[DestinationTestConfiguration]:
     # sanity check
     for item in subset:
@@ -495,6 +496,11 @@
         conf for conf in destination_configs if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS
     ]
 
+    if force_iceberg is not None:
+        destination_configs = [
+            conf for conf in destination_configs if conf.force_iceberg is force_iceberg
+        ]
+
     return destination_configs
