From 0e251028e54f51137ee4ecb39d5d1c08b94e8a10 Mon Sep 17 00:00:00 2001
From: Dave
Date: Thu, 28 Sep 2023 16:39:30 +0200
Subject: [PATCH] first iceberg prototype

---
 dlt/destinations/athena/athena.py        | 41 ++++++++++++++-----
 dlt/destinations/athena/configuration.py |  1 +
 dlt/destinations/bigquery/bigquery.py    |  4 +-
 dlt/destinations/job_client_impl.py      | 33 +++++++++++----
 dlt/destinations/mssql/mssql.py          |  4 +-
 dlt/destinations/postgres/postgres.py    |  4 +-
 dlt/destinations/snowflake/snowflake.py  |  7 ++--
 dlt/destinations/sql_jobs.py             | 24 +++++++----
 tests/load/test_iceberg.py               | 52 ++++++++++++++++++++++++
 9 files changed, 133 insertions(+), 37 deletions(-)
 create mode 100644 tests/load/test_iceberg.py

diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py
index ed8364aa3a..69175faabb 100644
--- a/dlt/destinations/athena/athena.py
+++ b/dlt/destinations/athena/athena.py
@@ -16,21 +16,21 @@
 from dlt.common.utils import without_none
 from dlt.common.data_types import TDataType
 from dlt.common.schema import TColumnSchema, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition
 from dlt.common.schema.utils import table_schema_has_type
 from dlt.common.destination import DestinationCapabilitiesContext
-from dlt.common.destination.reference import LoadJob
-from dlt.common.destination.reference import TLoadJobState
+from dlt.common.destination.reference import LoadJob, FollowupJob
+from dlt.common.destination.reference import TLoadJobState, NewLoadJob
 from dlt.common.storages import FileStorage
 from dlt.common.data_writers.escape import escape_bigquery_identifier
 
-
+from dlt.destinations.sql_jobs import SqlStagingCopyJob
 from dlt.destinations.typing import DBApi, DBTransaction
 from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation, LoadJobTerminalException
 from dlt.destinations.athena import capabilities
 from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl, raise_database_error, raise_open_connection_error
 from dlt.destinations.typing import DBApiCursor
-from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo
+from dlt.destinations.job_client_impl import SqlJobClientWithStaging
 from dlt.destinations.athena.configuration import AthenaClientConfiguration
 from dlt.destinations.type_mapping import TypeMapper
 from dlt.destinations import path_utils
@@ -121,7 +121,7 @@ def __init__(self) -> None:
         DLTAthenaFormatter._INSTANCE = self
 
 
-class DoNothingJob(LoadJob):
+class DoNothingJob(LoadJob, FollowupJob):
     """The most lazy class of dlt"""
 
     def __init__(self, file_path: str) -> None:
@@ -135,6 +135,7 @@ def exception(self) -> str:
         # this part of code should be never reached
         raise NotImplementedError()
 
+
 class AthenaSQLClient(SqlClientBase[Connection]):
 
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
@@ -274,9 +275,9 @@ def has_dataset(self) -> bool:
         query = f"""SHOW DATABASES LIKE {self.fully_qualified_dataset_name()};"""
         rows = self.execute_sql(query)
         return len(rows) > 0
+
 
-
-class AthenaClient(SqlJobClientBase):
+class AthenaClient(SqlJobClientWithStaging):
 
     capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
 
@@ -307,12 +308,22 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:
 
 
     def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:
+        create_only_iceberg_tables = self.config.iceberg_bucket_url is not None and not self.in_staging_mode
+
         bucket = self.config.staging_config.bucket_url
-        dataset = self.sql_client.dataset_name
+        if create_only_iceberg_tables:
+            bucket = self.config.iceberg_bucket_url
+
+        print(table_name)
+        print(bucket)
+
+        # TODO: we need to strip the staging layout from the table name, find a better way!
+        dataset = self.sql_client.dataset_name.replace("_staging", "")
         sql: List[str] = []
 
         # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries
-        is_iceberg = self.schema.tables[table_name].get("write_disposition", None) == "skip"
+        # or if we are in iceberg mode, we create iceberg tables for all tables
+        is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or create_only_iceberg_tables
         columns = ", ".join([self._get_column_def_sql(c) for c in new_columns])
 
         # this will fail if the table prefix is not properly defined
@@ -348,6 +359,16 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
             job = DoNothingJob(file_path)
         return job
 
+    def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema], replace: bool) -> NewLoadJob:
+        """update destination tables from staging tables"""
+        return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": replace})
+
+    def get_stage_dispositions(self) -> List[TWriteDisposition]:
+        # in iceberg mode, we always use staging tables
+        if self.config.iceberg_bucket_url is not None:
+            return ["append", "replace", "merge"]
+        return []
+
     @staticmethod
     def is_dbapi_exception(ex: Exception) -> bool:
         return isinstance(ex, Error)
diff --git a/dlt/destinations/athena/configuration.py b/dlt/destinations/athena/configuration.py
index f6e6fa3b51..3c175cba9b 100644
--- a/dlt/destinations/athena/configuration.py
+++ b/dlt/destinations/athena/configuration.py
@@ -9,6 +9,7 @@
 class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration):
     destination_name: Final[str] = "athena"  # type: ignore[misc]
     query_result_bucket: str = None
+    iceberg_bucket_url: Optional[str] = None
     credentials: AwsCredentials = None
     athena_work_group: Optional[str] = None
     aws_data_catalog: Optional[str] = "awsdatacatalog"
diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py
index 473fee2113..387e450184 100644
--- a/dlt/destinations/bigquery/bigquery.py
+++ b/dlt/destinations/bigquery/bigquery.py
@@ -19,7 +19,7 @@
 from dlt.destinations.bigquery import capabilities
 from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration
 from dlt.destinations.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS
-from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob, SqlJobParams
 from dlt.destinations.job_impl import NewReferenceJob
 from dlt.destinations.sql_client import SqlClientBase
 from dlt.destinations.type_mapping import TypeMapper
@@ -138,7 +138,7 @@ def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: st
 
 class BigqueryStagingCopyJob(SqlStagingCopyJob):
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         sql: List[str] = []
         for table in table_chain:
             with sql_client.with_staging_dataset(staging=True):
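
# A condensed, illustrative sketch of the table-type decision that the Athena change above
# makes inside `_get_table_update_sql`. The standalone function below is hypothetical (the
# patch keeps this logic inline); the rule itself mirrors the patch: dlt system tables
# (write_disposition == "skip") always become empty iceberg tables so DELETE/UPDATE work,
# and once `iceberg_bucket_url` is configured and we are not creating the staging tables,
# every table is created as an iceberg table in that bucket.
from typing import Optional, Tuple


def choose_table_layout(
    write_disposition: str,
    staging_bucket_url: str,
    iceberg_bucket_url: Optional[str],
    in_staging_mode: bool,
) -> Tuple[bool, str]:
    """Return (is_iceberg, bucket) the way the patched Athena client decides it."""
    create_only_iceberg_tables = iceberg_bucket_url is not None and not in_staging_mode
    bucket = iceberg_bucket_url if create_only_iceberg_tables else staging_bucket_url
    is_iceberg = write_disposition == "skip" or create_only_iceberg_tables
    return is_iceberg, bucket


# e.g. a plain table while loading into staging vs. the final iceberg table
assert choose_table_layout("append", "s3://stage", "s3://lake", True) == (False, "s3://stage")
assert choose_table_layout("append", "s3://stage", "s3://lake", False) == (True, "s3://lake")
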
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index c082eefb93..90cdf1ffcd 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -148,26 +148,34 @@ def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition]
 
     def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
         return SqlMergeJob.from_table_chain(table_chain, self.sql_client)
 
-    def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
+    def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema], replace: bool) -> NewLoadJob:
         """update destination tables from staging tables"""
-        return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)
+        if not replace:
+            return None
+        return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})
 
     def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
         """optimized replace strategy, defaults to _create_staging_copy_job for the basic client
         for some destinations there are much faster destination updates at the cost of dropping tables possible"""
-        return self._create_staging_copy_job(table_chain)
+        return self._create_staging_copy_job(table_chain, True)
 
     def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
         """Creates a list of followup jobs for merge write disposition and staging replace strategies"""
         jobs = super().create_table_chain_completed_followup_jobs(table_chain)
         write_disposition = table_chain[0]["write_disposition"]
-        if write_disposition == "merge":
-            jobs.append(self._create_merge_job(table_chain))
+        if write_disposition == "append":
+            if job := self._create_staging_copy_job(table_chain, False):
+                jobs.append(job)
+        elif write_disposition == "merge":
+            if job := self._create_merge_job(table_chain):
+                jobs.append(job)
         elif write_disposition == "replace" and self.config.replace_strategy == "insert-from-staging":
-            jobs.append(self._create_staging_copy_job(table_chain))
+            if job := self._create_staging_copy_job(table_chain, True):
+                jobs.append(job)
         elif write_disposition == "replace" and self.config.replace_strategy == "staging-optimized":
-            jobs.append(self._create_optimized_replace_job(table_chain))
+            if job := self._create_optimized_replace_job(table_chain):
+                jobs.append(job)
         return jobs
 
     def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
@@ -431,10 +439,17 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None:
 
 
 class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset):
+
+    in_staging_mode: bool = False
+
     @contextlib.contextmanager
     def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]:
-        with self.sql_client.with_staging_dataset(True):
-            yield self
+        try:
+            with self.sql_client.with_staging_dataset(True):
+                self.in_staging_mode = True
+                yield self
+        finally:
+            self.in_staging_mode = False
 
     def get_stage_dispositions(self) -> List[TWriteDisposition]:
         """Returns a list of dispositions that require staging tables to be populated"""
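
# An illustrative, self-contained sketch of the `in_staging_mode` toggle added to
# SqlJobClientWithStaging above: the flag is flipped inside try/finally so it is reset even
# when the body of the `with` block raises. `FakeSqlClient` is a hypothetical stand-in used
# only to make the example runnable; the real class delegates to its sql_client.
import contextlib
from typing import Iterator


class FakeSqlClient:
    @contextlib.contextmanager
    def with_staging_dataset(self, staging: bool) -> Iterator[None]:
        # the real client would switch the dataset name to the staging dataset here
        yield


class StagingAwareClient:
    in_staging_mode: bool = False

    def __init__(self) -> None:
        self.sql_client = FakeSqlClient()

    @contextlib.contextmanager
    def with_staging_dataset(self) -> Iterator["StagingAwareClient"]:
        try:
            with self.sql_client.with_staging_dataset(True):
                self.in_staging_mode = True
                yield self
        finally:
            self.in_staging_mode = False


client = StagingAwareClient()
with client.with_staging_dataset():
    assert client.in_staging_mode  # DDL generated here targets the staging dataset
assert not client.in_staging_mode  # reset afterwards, even if the block had raised
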
diff --git a/dlt/destinations/mssql/mssql.py b/dlt/destinations/mssql/mssql.py
index 5ed3b706b8..67b51f885c 100644
--- a/dlt/destinations/mssql/mssql.py
+++ b/dlt/destinations/mssql/mssql.py
@@ -8,7 +8,7 @@
 from dlt.common.schema.typing import TTableSchema, TColumnType
 from dlt.common.utils import uniq_id
 
-from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob
+from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams
 
 from dlt.destinations.insert_job_client import InsertValuesJobClient
 
@@ -83,7 +83,7 @@ def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[i
 
 class MsSqlStagingCopyJob(SqlStagingCopyJob):
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         sql: List[str] = []
         for table in table_chain:
             with sql_client.with_staging_dataset(staging=True):
diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py
index ead5ab6639..b6c716754f 100644
--- a/dlt/destinations/postgres/postgres.py
+++ b/dlt/destinations/postgres/postgres.py
@@ -7,7 +7,7 @@
 from dlt.common.schema import TColumnSchema, TColumnHint, Schema
 from dlt.common.schema.typing import TTableSchema, TColumnType
 
-from dlt.destinations.sql_jobs import SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
 
 from dlt.destinations.insert_job_client import InsertValuesJobClient
 
@@ -79,7 +79,7 @@ def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Opt
 
 class PostgresStagingCopyJob(SqlStagingCopyJob):
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         sql: List[str] = []
         for table in table_chain:
             with sql_client.with_staging_dataset(staging=True):
diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py
index b9046cde75..69432bc696 100644
--- a/dlt/destinations/snowflake/snowflake.py
+++ b/dlt/destinations/snowflake/snowflake.py
@@ -17,7 +17,7 @@
 from dlt.destinations.snowflake import capabilities
 from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration
 from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient
-from dlt.destinations.sql_jobs import SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
 from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient
 from dlt.destinations.job_impl import NewReferenceJob
 from dlt.destinations.sql_client import SqlClientBase
@@ -157,13 +157,12 @@ def exception(self) -> str:
 
 class SnowflakeStagingCopyJob(SqlStagingCopyJob):
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         sql: List[str] = []
         for table in table_chain:
             with sql_client.with_staging_dataset(staging=True):
                 staging_table_name = sql_client.make_qualified_table_name(table["name"])
             table_name = sql_client.make_qualified_table_name(table["name"])
-            # drop destination table
             sql.append(f"DROP TABLE IF EXISTS {table_name};")
             # recreate destination table with data cloned from staging table
             sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
@@ -206,7 +205,7 @@ def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str
         return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]
 
     def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
-        return SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client)
+        return SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})
 
     def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]:
         sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
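
# The Snowflake copy job above drops and re-clones the destination table whenever the
# staging-optimized replace strategy runs (it is invoked with {"replace": True}). A tiny,
# standalone illustration of the two statements it emits per table; identifier quoting is
# simplified here, the real job builds names via make_qualified_table_name.
from typing import List


def snowflake_clone_statements(table_name: str, staging_table_name: str) -> List[str]:
    return [
        f"DROP TABLE IF EXISTS {table_name};",
        f"CREATE TABLE {table_name} CLONE {staging_table_name};",
    ]


print(snowflake_clone_statements("my_dataset.items", "my_dataset_staging.items"))
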
diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py
index c1137ee9ad..784c0e3a05 100644
--- a/dlt/destinations/sql_jobs.py
+++ b/dlt/destinations/sql_jobs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Sequence, Tuple, cast
+from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional
 import yaml
 
 from dlt.common.runtime.logger import pretty_format_exception
@@ -11,24 +11,30 @@
 from dlt.destinations.job_impl import NewLoadJobImpl
 from dlt.destinations.sql_client import SqlClientBase
 
+class SqlJobParams(TypedDict):
+    replace: Optional[bool]
+
+DEFAULTS: SqlJobParams = {
+    "replace": False
+}
 
 class SqlBaseJob(NewLoadJobImpl):
     """Sql base job for jobs that rely on the whole tablechain"""
     failed_text: str = ""
 
     @classmethod
-    def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> NewLoadJobImpl:
+    def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> NewLoadJobImpl:
         """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader.
 
         The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list).
         """
-
+        params = cast(SqlJobParams, {**DEFAULTS, **(params or {})})  # type: ignore
         top_table = table_chain[0]
         file_info = ParsedLoadJobFileName(top_table["name"], uniq_id()[:10], 0, "sql")
         try:
             # Remove line breaks from multiline statements and write one SQL statement per line in output file
             # to support clients that need to execute one statement at a time (i.e. snowflake)
-            sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client)]
+            sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client, params)]
             job = cls(file_info.job_id(), "running")
             job._save_text_file("\n".join(sql))
         except Exception:
@@ -39,7 +45,7 @@ def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlCl
         return job
 
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         pass
 
 
@@ -48,14 +54,16 @@
 class SqlStagingCopyJob(SqlBaseJob):
     failed_text: str = "Tried to generate a staging copy sql job for the following tables:"
 
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         sql: List[str] = []
         for table in table_chain:
             with sql_client.with_staging_dataset(staging=True):
                 staging_table_name = sql_client.make_qualified_table_name(table["name"])
             table_name = sql_client.make_qualified_table_name(table["name"])
             columns = ", ".join(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(table, "name")))
-            sql.append(sql_client._truncate_table_sql(table_name))
+            if params["replace"]:
+                sql.append(sql_client._truncate_table_sql(table_name))
+            print(f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};")
             sql.append(f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};")
         return sql
 
@@ -64,7 +72,7 @@
 class SqlMergeJob(SqlBaseJob):
     failed_text: str = "Tried to generate a merge sql job for the following tables:"
 
     @classmethod
-    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+    def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
         """Generates a list of sql statements that merge the data in staging dataset with the data in destination dataset.
         The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list).
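
# The params handling added to SqlBaseJob.from_table_chain above merges caller-supplied
# options over DEFAULTS, so staging copy jobs only truncate the destination table when
# "replace" is requested. A small self-contained illustration of that merge (the TypedDict
# and DEFAULTS are copied from the patch; resolve_params is a hypothetical helper name):
from typing import Optional, TypedDict, cast


class SqlJobParams(TypedDict):
    replace: Optional[bool]


DEFAULTS: SqlJobParams = {"replace": False}


def resolve_params(params: Optional[SqlJobParams] = None) -> SqlJobParams:
    # caller-provided keys override the defaults; None falls back to DEFAULTS entirely
    return cast(SqlJobParams, {**DEFAULTS, **(params or {})})


assert resolve_params() == {"replace": False}                  # append via staging: no truncate
assert resolve_params({"replace": True}) == {"replace": True}  # replace: truncate first
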
diff --git a/tests/load/test_iceberg.py b/tests/load/test_iceberg.py
new file mode 100644
index 0000000000..35f4996ecd
--- /dev/null
+++ b/tests/load/test_iceberg.py
@@ -0,0 +1,52 @@
+"""
+Temporary test file for iceberg
+"""
+
+import pytest
+import os
+import datetime  # noqa: I251
+from typing import Iterator, Any
+
+import dlt
+from dlt.common import pendulum
+from dlt.common.utils import uniq_id
+from tests.load.pipeline.utils import load_table_counts
+from tests.cases import table_update_and_row, assert_all_data_types_row
+from tests.pipeline.utils import assert_load_info
+
+from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration
+
+def test_iceberg() -> None:
+
+    os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = "s3://dlt-ci-test-bucket"
+    os.environ['DESTINATION__ATHENA__ICEBERG_BUCKET_URL'] = "s3://dlt-ci-test-bucket/iceberg"
+
+    pipeline = dlt.pipeline(pipeline_name="aaathena", destination="athena", staging="filesystem", full_refresh=True)
+
+    @dlt.resource(name="items", write_disposition="append")
+    def items():
+        yield {
+            "id": 1,
+            "name": "item",
+            "sub_items": [{
+                "id": 101,
+                "name": "sub item 101"
+            },{
+                "id": 101,
+                "name": "sub item 102"
+            }]
+        }
+
+    print(pipeline.run(items))
+
+    # see if we have athena tables with items
+    table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ])
+    assert table_counts["items"] == 1
+    assert table_counts["items__sub_items"] == 2
+    assert table_counts["_dlt_loads"] == 1
+
+    pipeline.run(items)
+    table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ])
+    assert table_counts["items"] == 2
+    assert table_counts["items__sub_items"] == 4
+    assert table_counts["_dlt_loads"] == 2
\ No newline at end of file
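
# A minimal way to exercise the prototype outside the test above, assuming AWS credentials
# and Athena are already configured for dlt. The bucket URLs and the pipeline/resource
# names are placeholders; the config keys mirror the ones used in tests/load/test_iceberg.py.
import os

import dlt

os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://your-staging-bucket"
os.environ["DESTINATION__ATHENA__ICEBERG_BUCKET_URL"] = "s3://your-staging-bucket/iceberg"

pipeline = dlt.pipeline(pipeline_name="iceberg_demo", destination="athena", staging="filesystem", full_refresh=True)


@dlt.resource(name="events", write_disposition="append")
def events():
    yield {"id": 1, "payload": "first iceberg row"}


# with ICEBERG_BUCKET_URL set, append/replace/merge all go through staging tables and the
# final tables are created as iceberg tables in the configured iceberg bucket
print(pipeline.run(events))
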