first iceberg prototype
sh-rp committed Sep 28, 2023
1 parent 2c5043c commit 0e25102
Showing 9 changed files with 133 additions and 37 deletions.
41 changes: 31 additions & 10 deletions dlt/destinations/athena/athena.py
@@ -16,21 +16,21 @@
from dlt.common.utils import without_none
from dlt.common.data_types import TDataType
from dlt.common.schema import TColumnSchema, Schema
-from dlt.common.schema.typing import TTableSchema, TColumnType
+from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition
from dlt.common.schema.utils import table_schema_has_type
from dlt.common.destination import DestinationCapabilitiesContext
-from dlt.common.destination.reference import LoadJob
-from dlt.common.destination.reference import TLoadJobState
+from dlt.common.destination.reference import LoadJob, FollowupJob
+from dlt.common.destination.reference import TLoadJobState, NewLoadJob
from dlt.common.storages import FileStorage
from dlt.common.data_writers.escape import escape_bigquery_identifier

from dlt.destinations.sql_jobs import SqlStagingCopyJob

from dlt.destinations.typing import DBApi, DBTransaction
from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation, LoadJobTerminalException
from dlt.destinations.athena import capabilities
from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl, raise_database_error, raise_open_connection_error
from dlt.destinations.typing import DBApiCursor
-from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo
+from dlt.destinations.job_client_impl import SqlJobClientWithStaging
from dlt.destinations.athena.configuration import AthenaClientConfiguration
from dlt.destinations.type_mapping import TypeMapper
from dlt.destinations import path_utils
@@ -121,7 +121,7 @@ def __init__(self) -> None:
DLTAthenaFormatter._INSTANCE = self


-class DoNothingJob(LoadJob):
+class DoNothingJob(LoadJob, FollowupJob):
"""The most lazy class of dlt"""

def __init__(self, file_path: str) -> None:
@@ -135,6 +135,7 @@ def exception(self) -> str:
# this part of code should be never reached
raise NotImplementedError()


class AthenaSQLClient(SqlClientBase[Connection]):

capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()
@@ -274,9 +275,9 @@ def has_dataset(self) -> bool:
query = f"""SHOW DATABASES LIKE {self.fully_qualified_dataset_name()};"""
rows = self.execute_sql(query)
return len(rows) > 0



-class AthenaClient(SqlJobClientBase):
+class AthenaClient(SqlJobClientWithStaging):

capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

@@ -307,12 +308,22 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str:

def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]:

+create_only_iceberg_tables = self.config.iceberg_bucket_url is not None and not self.in_staging_mode
+
bucket = self.config.staging_config.bucket_url
-dataset = self.sql_client.dataset_name
+if create_only_iceberg_tables:
+    bucket = self.config.iceberg_bucket_url
+
+print(table_name)
+print(bucket)
+
+# TODO: we need to strip the staging layout from the table name, find a better way!
+dataset = self.sql_client.dataset_name.replace("_staging", "")
sql: List[str] = []

# for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries
-is_iceberg = self.schema.tables[table_name].get("write_disposition", None) == "skip"
+# or if we are in iceberg mode, we create iceberg tables for all tables
+is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or create_only_iceberg_tables
columns = ", ".join([self._get_column_def_sql(c) for c in new_columns])

# this will fail if the table prefix is not properly defined
@@ -348,6 +359,16 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
job = DoNothingJob(file_path)
return job

+def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema], replace: bool) -> NewLoadJob:
+    """update destination tables from staging tables"""
+    return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": replace})
+
+def get_stage_dispositions(self) -> List[TWriteDisposition]:
+    # in iceberg mode, we always use staging tables
+    if self.config.iceberg_bucket_url is not None:
+        return ["append", "replace", "merge"]
+    return []

@staticmethod
def is_dbapi_exception(ex: Exception) -> bool:
return isinstance(ex, Error)
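
The hunk above only computes the `is_iceberg` flag and picks a bucket; the CREATE statement itself lives outside the visible context. As a rough sketch of what that branch plausibly emits, assuming Athena's documented Iceberg DDL; `_make_create_table_sql` is a hypothetical helper, not part of this commit:

    # hypothetical helper, not from this commit: shows how an is_iceberg flag
    # would switch between Hive-style external tables and Athena Iceberg DDL
    def _make_create_table_sql(qualified_name: str, columns: str, location: str, is_iceberg: bool) -> str:
        if is_iceberg:
            # Iceberg tables support DELETE/UPDATE, which is why the commit
            # creates them for the dlt system tables (write_disposition == "skip")
            return (
                f"CREATE TABLE {qualified_name} ({columns}) "
                f"LOCATION '{location}' "
                f"TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet')"
            )
        # regular data tables remain external parquet tables on the staging bucket
        return (
            f"CREATE EXTERNAL TABLE {qualified_name} ({columns}) "
            f"STORED AS PARQUET LOCATION '{location}'"
        )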
1 change: 1 addition & 0 deletions dlt/destinations/athena/configuration.py
@@ -9,6 +9,7 @@
class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration):
destination_name: Final[str] = "athena" # type: ignore[misc]
query_result_bucket: str = None
+iceberg_bucket_url: Optional[str] = None
credentials: AwsCredentials = None
athena_work_group: Optional[str] = None
aws_data_catalog: Optional[str] = "awsdatacatalog"
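With the new field in place, the Iceberg bucket is supplied like any other dlt configuration value. A minimal sketch using environment variables, mirroring the test at the end of this commit (bucket URLs are placeholders):

    import os

    # dlt resolves DESTINATION__ATHENA__ICEBERG_BUCKET_URL into
    # AthenaClientConfiguration.iceberg_bucket_url at pipeline startup
    os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://my-bucket"
    os.environ["DESTINATION__ATHENA__ICEBERG_BUCKET_URL"] = "s3://my-bucket/iceberg"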
4 changes: 2 additions & 2 deletions dlt/destinations/bigquery/bigquery.py
@@ -19,7 +19,7 @@
from dlt.destinations.bigquery import capabilities
from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration
from dlt.destinations.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS
-from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob, SqlJobParams
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
from dlt.destinations.type_mapping import TypeMapper
@@ -138,7 +138,7 @@ def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: st
class BigqueryStagingCopyJob(SqlStagingCopyJob):

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
33 changes: 24 additions & 9 deletions dlt/destinations/job_client_impl.py
@@ -148,26 +148,34 @@ def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition]
def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
return SqlMergeJob.from_table_chain(table_chain, self.sql_client)

-def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
+def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema], replace: bool) -> NewLoadJob:
    """update destination tables from staging tables"""
-    return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)
+    if not replace:
+        return None
+    return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})

def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
"""optimized replace strategy, defaults to _create_staging_copy_job for the basic client
for some destinations there are much faster destination updates at the cost of
dropping tables possible"""
-return self._create_staging_copy_job(table_chain)
+return self._create_staging_copy_job(table_chain, True)

def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]:
"""Creates a list of followup jobs for merge write disposition and staging replace strategies"""
jobs = super().create_table_chain_completed_followup_jobs(table_chain)
write_disposition = table_chain[0]["write_disposition"]
if write_disposition == "merge":
jobs.append(self._create_merge_job(table_chain))
if write_disposition == "append":
if job := self._create_staging_copy_job(table_chain, False):
jobs.append(job)
elif write_disposition == "merge":
if job := self._create_merge_job(table_chain):
jobs.append(job)
elif write_disposition == "replace" and self.config.replace_strategy == "insert-from-staging":
jobs.append(self._create_staging_copy_job(table_chain))
if job := self._create_staging_copy_job(table_chain, True):
jobs.append(job)
elif write_disposition == "replace" and self.config.replace_strategy == "staging-optimized":
jobs.append(self._create_optimized_replace_job(table_chain))
if job := self._create_optimized_replace_job(table_chain):
jobs.append(job)
return jobs

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
@@ -431,10 +439,17 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None:


class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset):

+in_staging_mode: bool = False
+
@contextlib.contextmanager
def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]:
-    with self.sql_client.with_staging_dataset(True):
-        yield self
+    try:
+        with self.sql_client.with_staging_dataset(True):
+            self.in_staging_mode = True
+            yield self
+    finally:
+        self.in_staging_mode = False

def get_stage_dispositions(self) -> List[TWriteDisposition]:
"""Returns a list of dispositions that require staging tables to be populated"""
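The context manager now records whether the client is operating on the staging dataset and guarantees the flag is reset even if the body raises. A runnable toy stand-in showing that contract (`_StubStagingClient` is illustrative only, not dlt code):

    import contextlib
    from typing import Iterator

    class _StubStagingClient:
        """Toy stand-in for SqlJobClientWithStaging, demonstrating the flag's contract."""
        in_staging_mode: bool = False

        @contextlib.contextmanager
        def with_staging_dataset(self) -> Iterator["_StubStagingClient"]:
            try:
                self.in_staging_mode = True
                yield self
            finally:
                self.in_staging_mode = False

    client = _StubStagingClient()
    with client.with_staging_dataset():
        # DDL generated here can branch on the flag, which is exactly what
        # AthenaClient._get_table_update_sql does via create_only_iceberg_tables
        assert client.in_staging_mode
    assert not client.in_staging_mode  # reset even on exceptions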
4 changes: 2 additions & 2 deletions dlt/destinations/mssql/mssql.py
@@ -8,7 +8,7 @@
from dlt.common.schema.typing import TTableSchema, TColumnType
from dlt.common.utils import uniq_id

-from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob
+from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams

from dlt.destinations.insert_job_client import InsertValuesJobClient

@@ -83,7 +83,7 @@ def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[i
class MsSqlStagingCopyJob(SqlStagingCopyJob):

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
4 changes: 2 additions & 2 deletions dlt/destinations/postgres/postgres.py
@@ -7,7 +7,7 @@
from dlt.common.schema import TColumnSchema, TColumnHint, Schema
from dlt.common.schema.typing import TTableSchema, TColumnType

-from dlt.destinations.sql_jobs import SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams

from dlt.destinations.insert_job_client import InsertValuesJobClient

@@ -79,7 +79,7 @@ def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Opt
class PostgresStagingCopyJob(SqlStagingCopyJob):

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
7 changes: 3 additions & 4 deletions dlt/destinations/snowflake/snowflake.py
@@ -17,7 +17,7 @@
from dlt.destinations.snowflake import capabilities
from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration
from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient
-from dlt.destinations.sql_jobs import SqlStagingCopyJob
+from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams
from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient
from dlt.destinations.job_impl import NewReferenceJob
from dlt.destinations.sql_client import SqlClientBase
@@ -157,13 +157,12 @@ def exception(self) -> str:
class SnowflakeStagingCopyJob(SqlStagingCopyJob):

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
table_name = sql_client.make_qualified_table_name(table["name"])
-# drop destination table
sql.append(f"DROP TABLE IF EXISTS {table_name};")
# recreate destination table with data cloned from staging table
sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};")
@@ -206,7 +205,7 @@ def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str
return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)]

def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob:
-return SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client)
+return SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})

def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]:
sql = super()._get_table_update_sql(table_name, new_columns, generate_alter)
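Snowflake keeps its staging-optimized override: instead of truncate-and-insert, it drops the destination table and recreates it with CLONE, a metadata-only operation in Snowflake. For a table named `items`, the generate_sql above would emit statements along these lines (dataset names are made up):

    # reconstructed from the f-strings in the diff above; names are placeholders
    sql = [
        "DROP TABLE IF EXISTS my_dataset.items;",
        "CREATE TABLE my_dataset.items CLONE my_dataset_staging.items;",
    ]

Note that the override accepts the new params argument but never reads it: a clone is inherently a full replace, so the {"replace": True} passed by _create_optimized_replace_job mainly keeps the signature uniform.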
24 changes: 16 additions & 8 deletions dlt/destinations/sql_jobs.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Sequence, Tuple, cast
+from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional

import yaml
from dlt.common.runtime.logger import pretty_format_exception
@@ -11,24 +11,30 @@
from dlt.destinations.job_impl import NewLoadJobImpl
from dlt.destinations.sql_client import SqlClientBase

+class SqlJobParams(TypedDict):
+    replace: Optional[bool]
+
+DEFAULTS: SqlJobParams = {
+    "replace": False
+}

class SqlBaseJob(NewLoadJobImpl):
"""Sql base job for jobs that rely on the whole tablechain"""
failed_text: str = ""

@classmethod
-def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> NewLoadJobImpl:
+def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> NewLoadJobImpl:
"""Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader.
The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list).
"""

+params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) # type: ignore
top_table = table_chain[0]
file_info = ParsedLoadJobFileName(top_table["name"], uniq_id()[:10], 0, "sql")
try:
# Remove line breaks from multiline statements and write one SQL statement per line in output file
# to support clients that need to execute one statement at a time (i.e. snowflake)
-sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client)]
+sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client, params)]
job = cls(file_info.job_id(), "running")
job._save_text_file("\n".join(sql))
except Exception:
@@ -39,7 +45,7 @@ def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlCl
return job

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
pass


@@ -48,14 +54,16 @@ class SqlStagingCopyJob(SqlBaseJob):
failed_text: str = "Tried to generate a staging copy sql job for the following tables:"

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
sql: List[str] = []
for table in table_chain:
with sql_client.with_staging_dataset(staging=True):
staging_table_name = sql_client.make_qualified_table_name(table["name"])
table_name = sql_client.make_qualified_table_name(table["name"])
columns = ", ".join(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(table, "name")))
-sql.append(sql_client._truncate_table_sql(table_name))
+if params["replace"]:
+    sql.append(sql_client._truncate_table_sql(table_name))
+print(f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};")
sql.append(f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};")
return sql

@@ -64,7 +72,7 @@ class SqlMergeJob(SqlBaseJob):
failed_text: str = "Tried to generate a merge sql job for the following tables:"

@classmethod
-def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]:
+def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]:
"""Generates a list of sql statements that merge the data in staging dataset with the data in destination dataset.
The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list).
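The net effect of SqlJobParams is that callers may omit the dict entirely and get replace=False (append semantics: no truncate before the INSERT ... SELECT), while the replace strategies pass {"replace": True} explicitly. A small self-contained check of the merge rule used in from_table_chain:

    from typing import Optional, TypedDict, cast

    class SqlJobParams(TypedDict):
        replace: Optional[bool]

    DEFAULTS: SqlJobParams = {"replace": False}

    def merge_params(params: Optional[SqlJobParams]) -> SqlJobParams:
        # same merge as SqlBaseJob.from_table_chain: caller-supplied keys win
        return cast(SqlJobParams, {**DEFAULTS, **(params or {})})

    assert merge_params(None) == {"replace": False}              # append: keep existing rows
    assert merge_params({"replace": True}) == {"replace": True}  # replace: truncate first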
52 changes: 52 additions & 0 deletions tests/load/test_iceberg.py
@@ -0,0 +1,52 @@
"""
Temporary test file for iceberg
"""

import pytest
import os
import datetime # noqa: I251
from typing import Iterator, Any

import dlt
from dlt.common import pendulum
from dlt.common.utils import uniq_id
from tests.load.pipeline.utils import load_table_counts
from tests.cases import table_update_and_row, assert_all_data_types_row
from tests.pipeline.utils import assert_load_info

from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration

def test_iceberg() -> None:

os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = "s3://dlt-ci-test-bucket"
os.environ['DESTINATION__ATHENA__ICEBERG_BUCKET_URL'] = "s3://dlt-ci-test-bucket/iceberg"

pipeline = dlt.pipeline(pipeline_name="aaathena", destination="athena", staging="filesystem", full_refresh=True)

@dlt.resource(name="items", write_disposition="append")
def items():
yield {
"id": 1,
"name": "item",
"sub_items": [{
"id": 101,
"name": "sub item 101"
},{
"id": 101,
"name": "sub item 102"
}]
}

print(pipeline.run(items))

# see if we have athena tables with items
table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ])
assert table_counts["items"] == 1
assert table_counts["items__sub_items"] == 2
assert table_counts["_dlt_loads"] == 1

pipeline.run(items)
table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ])
assert table_counts["items"] == 2
assert table_counts["items__sub_items"] == 4
assert table_counts["_dlt_loads"] == 2
