diff --git a/.github/workflows/test_destination_athena.yml b/.github/workflows/test_destination_athena.yml index 16c9caff53..704e66522b 100644 --- a/.github/workflows/test_destination_athena.yml +++ b/.github/workflows/test_destination_athena.yml @@ -21,6 +21,7 @@ env: RUNTIME__DLTHUB_TELEMETRY_SEGMENT_WRITE_KEY: TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB ACTIVE_DESTINATIONS: "[\"athena\"]" ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + EXCLUDED_DESTINATION_CONFIGURATIONS: "[\"athena-parquet-staging-iceberg\"]" jobs: get_docs_changes: diff --git a/.github/workflows/test_destination_athena_iceberg.yml b/.github/workflows/test_destination_athena_iceberg.yml new file mode 100644 index 0000000000..6892a96bf1 --- /dev/null +++ b/.github/workflows/test_destination_athena_iceberg.yml @@ -0,0 +1,94 @@ + +name: test athena iceberg + +on: + pull_request: + branches: + - master + - devel + workflow_dispatch: + +env: + DESTINATION__FILESYSTEM__CREDENTIALS__AWS_ACCESS_KEY_ID: AKIAT4QMVMC4J46G55G4 + DESTINATION__FILESYSTEM__CREDENTIALS__AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + DESTINATION__ATHENA__CREDENTIALS__AWS_ACCESS_KEY_ID: AKIAT4QMVMC4J46G55G4 + DESTINATION__ATHENA__CREDENTIALS__AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + DESTINATION__ATHENA__CREDENTIALS__REGION_NAME: eu-central-1 + DESTINATION__ATHENA__QUERY_RESULT_BUCKET: s3://dlt-athena-output + + RUNTIME__SENTRY_DSN: https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 + RUNTIME__LOG_LEVEL: ERROR + RUNTIME__DLTHUB_TELEMETRY_SEGMENT_WRITE_KEY: TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB + ACTIVE_DESTINATIONS: "[\"athena\"]" + ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" + EXCLUDED_DESTINATION_CONFIGURATIONS: "[\"athena-no-staging\"]" + +jobs: + get_docs_changes: + uses: ./.github/workflows/get_docs_changes.yml + # Tests that require credentials do not run in forks + if: ${{ !github.event.pull_request.head.repo.fork }} + + run_loader: + name: test destination athena iceberg + needs: get_docs_changes + if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + # os: ["ubuntu-latest", "macos-latest", "windows-latest"] + defaults: + run: + shell: bash + runs-on: ${{ matrix.os }} + + steps: + + - name: Check out + uses: actions/checkout@master + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10.x" + + - name: Install Poetry + uses: snok/install-poetry@v1.3.2 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v3 + with: + # path: ${{ steps.pip-cache.outputs.dir }} + path: .venv + key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-athena + + - name: Install dependencies + # if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --no-interaction -E athena + + - run: | + poetry run pytest tests/load + if: runner.os != 'Windows' + name: Run tests Linux/MAC + - run: | + poetry run pytest tests/load + if: runner.os == 'Windows' + name: Run tests Windows + shell: cmd + + matrix_job_required_check: + name: athena iceberg tests + needs: run_loader + runs-on: ubuntu-latest + if: always() + steps: + - name: Check matrix job results + if: contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') + run: | + echo "One or more matrix job tests failed or were cancelled. 
You may need to re-run them." && exit 1 diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 0bf4088b27..13172b41e9 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -4,17 +4,20 @@ from typing import ClassVar, Final, Optional, NamedTuple, Literal, Sequence, Iterable, Type, Protocol, Union, TYPE_CHECKING, cast, List, ContextManager, Dict, Any from contextlib import contextmanager import datetime # noqa: 251 +from copy import deepcopy from dlt.common import logger from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TWriteDisposition from dlt.common.schema.exceptions import InvalidDatasetName +from dlt.common.schema.utils import get_write_disposition, get_table_format from dlt.common.configuration import configspec from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema.utils import is_complete_column +from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.utils import get_module_name @@ -244,9 +247,8 @@ def restore_file_load(self, file_path: str) -> LoadJob: """Finds and restores already started loading job identified by `file_path` if destination supports it.""" pass - def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition]: - # in the base job, all replace strategies are treated the same, see filesystem for example - return ["replace"] + def should_truncate_table_before_load(self, table: TTableSchema) -> bool: + return table["write_disposition"] == "replace" def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: """Creates a list of followup jobs that should be executed after a table chain is completed""" @@ -287,6 +289,21 @@ def _verify_schema(self) -> None: if not is_complete_column(column): logger.warning(f"A column {column_name} in table {table_name} in schema {self.schema.name} is incomplete. It was not bound to the data during normalizations stage and its data type is unknown. Did you add this column manually in code ie. 
as a merge key?") + def get_load_table(self, table_name: str, prepare_for_staging: bool = False) -> TTableSchema: + if table_name not in self.schema.tables: + return None + try: + # make a copy of the schema so modifications do not affect the original document + table = deepcopy(self.schema.tables[table_name]) + # add write disposition if not specified - in child tables + if "write_disposition" not in table: + table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) + if "table_format" not in table: + table["table_format"] = get_table_format(self.schema.tables, table_name) + return table + except KeyError: + raise UnknownTableException(table_name) + class WithStateSync(ABC): @@ -309,15 +326,23 @@ class WithStagingDataset(ABC): """Adds capability to use staging dataset and request it from the loader""" @abstractmethod - def get_stage_dispositions(self) -> List[TWriteDisposition]: - """Returns a list of write dispositions that require staging dataset""" - return [] + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + return False @abstractmethod def with_staging_dataset(self)-> ContextManager["JobClientBase"]: """Executes job client methods on staging dataset""" return self # type: ignore +class SupportsStagingDestination(): + """Adds capability to support a staging destination for the load""" + + def should_load_data_to_staging_dataset_on_staging_destination(self, table: TTableSchema) -> bool: + return False + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + # the default is to truncate the tables on the staging destination... + return True TDestinationReferenceArg = Union["DestinationReference", ModuleType, None, str] diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 2245a77b61..5f638a111d 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -69,3 +69,8 @@ def __init__(self, schema_name: str, init_engine: int, from_engine: int, to_engi self.from_engine = from_engine self.to_engine = to_engine super().__init__(f"No engine upgrade path in schema {schema_name} from {init_engine} to {to_engine}, stopped at {from_engine}") + +class UnknownTableException(SchemaException): + def __init__(self, table_name: str) -> None: + self.table_name = table_name + super().__init__(f"Trying to access unknown table {table_name}.") \ No newline at end of file diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index ae24691e2d..2cc057560c 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -24,6 +24,7 @@ TColumnHint = Literal["not_null", "partition", "cluster", "primary_key", "foreign_key", "sort", "unique", "root_key", "merge_key"] """Known hints of a column used to declare hint regexes.""" TWriteDisposition = Literal["skip", "append", "replace", "merge"] +TTableFormat = Literal["iceberg"] TTypeDetections = Literal["timestamp", "iso_timestamp", "large_integer", "hexbytes_to_text", "wei_to_double"] TTypeDetectionFunc = Callable[[Type[Any], Any], Optional[TDataType]] TColumnNames = Union[str, Sequence[str]] @@ -86,6 +87,7 @@ class TTableSchema(TypedDict, total=False): filters: Optional[TRowFilters] columns: TTableSchemaColumns resource: Optional[str] + table_format: Optional[TTableFormat] class TPartialTableSchema(TTableSchema): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index a378a34db0..32bd4ade1c 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ 
-15,10 +15,10 @@ from dlt.common.validation import TCustomValidator, validate_dict, validate_dict_ignoring_xkeys from dlt.common.schema import detections from dlt.common.schema.typing import (COLUMN_HINTS, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, TColumnName, TPartialTableSchema, TSchemaTables, TSchemaUpdate, - TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, + TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp, TTableFormat, TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition) from dlt.common.schema.exceptions import (CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException, - TablePropertiesConflictException, InvalidSchemaName) + TablePropertiesConflictException, InvalidSchemaName, UnknownTableException) from dlt.common.normalizers.utils import import_normalizers from dlt.common.schema.typing import TAnySchemaColumns @@ -493,18 +493,29 @@ def merge_schema_updates(schema_updates: Sequence[TSchemaUpdate]) -> TSchemaTabl return aggregated_update -def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDisposition: - """Returns write disposition of a table if present. If not, looks up into parent table""" - table = tables[table_name] - w_d = table.get("write_disposition") - if w_d: - return w_d +def get_inherited_table_hint(tables: TSchemaTables, table_name: str, table_hint_name: str, allow_none: bool = False) -> Any: + table = tables.get(table_name, {}) + hint = table.get(table_hint_name) + if hint: + return hint parent = table.get("parent") if parent: - return get_write_disposition(tables, parent) + return get_inherited_table_hint(tables, parent, table_hint_name, allow_none) + + if allow_none: + return None + + raise ValueError(f"No table hint '{table_hint_name}' found in the chain of tables for '{table_name}'.") + + +def get_write_disposition(tables: TSchemaTables, table_name: str) -> TWriteDisposition: + """Returns the write disposition of a table if present.
If not, looks up into parent table""" + return cast(TWriteDisposition, get_inherited_table_hint(tables, table_name, "write_disposition", allow_none=False)) + - raise ValueError(f"No write disposition found in the chain of tables for '{table_name}'.") +def get_table_format(tables: TSchemaTables, table_name: str) -> TTableFormat: + return cast(TTableFormat, get_inherited_table_hint(tables, table_name, "table_format", allow_none=True)) def table_schema_has_type(table: TTableSchema, _typ: TDataType) -> bool: @@ -637,7 +648,8 @@ def new_table( write_disposition: TWriteDisposition = None, columns: Sequence[TColumnSchema] = None, validate_schema: bool = False, - resource: str = None + resource: str = None, + table_format: TTableFormat = None ) -> TTableSchema: table: TTableSchema = { @@ -652,6 +664,8 @@ def new_table( # set write disposition only for root tables table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION table["resource"] = resource or table_name + if table_format: + table["table_format"] = table_format if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 95170ac46c..2f52365787 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -237,8 +237,11 @@ def list_failed_jobs(self, load_id: str) -> Sequence[str]: return self.storage.list_folder_files(self._get_job_folder_path(load_id, LoadStorage.FAILED_JOBS_FOLDER)) def list_jobs_for_table(self, load_id: str, table_name: str) -> Sequence[LoadJobInfo]: + return [job for job in self.list_all_jobs(load_id) if job.job_file_info.table_name == table_name] + + def list_all_jobs(self, load_id: str) -> Sequence[LoadJobInfo]: info = self.get_load_package_info(load_id) - return [job for job in flatten_list_or_items(iter(info.jobs.values())) if job.job_file_info.table_name == table_name] # type: ignore + return [job for job in flatten_list_or_items(iter(info.jobs.values()))] # type: ignore def list_completed_failed_jobs(self, load_id: str) -> Sequence[str]: return self.storage.list_folder_files(self._get_job_folder_completed_path(load_id, LoadStorage.FAILED_JOBS_FOLDER)) diff --git a/dlt/destinations/athena/__init__.py b/dlt/destinations/athena/__init__.py index 531744f6e6..d19a0ffdb7 100644 --- a/dlt/destinations/athena/__init__.py +++ b/dlt/destinations/athena/__init__.py @@ -36,6 +36,7 @@ def capabilities() -> DestinationCapabilitiesContext: caps.alter_add_multi_column = True caps.schema_supports_numeric_precision = False caps.timestamp_precision = 3 + caps.supports_truncate_command = False return caps diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py index ed8364aa3a..44d020c127 100644 --- a/dlt/destinations/athena/athena.py +++ b/dlt/destinations/athena/athena.py @@ -15,22 +15,22 @@ from dlt.common import logger from dlt.common.utils import without_none from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType -from dlt.common.schema.utils import table_schema_has_type +from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema +from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition, TTableFormat +from dlt.common.schema.utils import table_schema_has_type, get_table_format from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import LoadJob -from 
dlt.common.destination.reference import TLoadJobState +from dlt.common.destination.reference import LoadJob, FollowupJob +from dlt.common.destination.reference import TLoadJobState, NewLoadJob, SupportsStagingDestination from dlt.common.storages import FileStorage from dlt.common.data_writers.escape import escape_bigquery_identifier - +from dlt.destinations.sql_jobs import SqlStagingCopyJob from dlt.destinations.typing import DBApi, DBTransaction from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation, LoadJobTerminalException from dlt.destinations.athena import capabilities from dlt.destinations.sql_client import SqlClientBase, DBApiCursorImpl, raise_database_error, raise_open_connection_error from dlt.destinations.typing import DBApiCursor -from dlt.destinations.job_client_impl import SqlJobClientBase, StorageSchemaInfo +from dlt.destinations.job_client_impl import SqlJobClientWithStaging from dlt.destinations.athena.configuration import AthenaClientConfiguration from dlt.destinations.type_mapping import TypeMapper from dlt.destinations import path_utils @@ -69,13 +69,16 @@ class AthenaTypeMapper(TypeMapper): "int": "bigint", } - def to_db_integer_type(self, precision: Optional[int]) -> str: + def __init__(self, capabilities: DestinationCapabilitiesContext): + super().__init__(capabilities) + + def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: if precision is None: return "bigint" if precision <= 8: - return "tinyint" + return "int" if table_format == "iceberg" else "tinyint" elif precision <= 16: - return "smallint" + return "int" if table_format == "iceberg" else "smallint" elif precision <= 32: return "int" return "bigint" @@ -135,6 +138,11 @@ def exception(self) -> str: # this part of code should be never reached raise NotImplementedError() +class DoNothingFollowupJob(DoNothingJob, FollowupJob): + """The second most lazy class of dlt""" + pass + + class AthenaSQLClient(SqlClientBase[Connection]): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -276,7 +284,7 @@ def has_dataset(self) -> bool: return len(rows) > 0 -class AthenaClient(SqlJobClientBase): +class AthenaClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -296,28 +304,33 @@ def __init__(self, schema: Schema, config: AthenaClientConfiguration) -> None: self.type_mapper = AthenaTypeMapper(self.capabilities) def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: - # never truncate tables in athena - super().initialize_storage([]) + # only truncate tables in iceberg mode + truncate_tables = [] + super().initialize_storage(truncate_tables) def _from_db_type(self, hive_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: return self.type_mapper.from_db_type(hive_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema) -> str: - return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c)}" + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]: bucket = self.config.staging_config.bucket_url dataset = self.sql_client.dataset_name + sql: List[str] = [] # for the 
system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries - is_iceberg = self.schema.tables[table_name].get("write_disposition", None) == "skip" - columns = ", ".join([self._get_column_def_sql(c) for c in new_columns]) + # or if we are in iceberg mode, we create iceberg tables for all tables + table = self.get_load_table(table_name, self.in_staging_mode) + is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" + columns = ", ".join([self._get_column_def_sql(c, table.get("table_format")) for c in new_columns]) # this will fail if the table prefix is not properly defined table_prefix = self.table_prefix_layout.format(table_name=table_name) location = f"{bucket}/{dataset}/{table_prefix}" + # use qualified table names qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name) if is_iceberg and not generate_alter: @@ -345,9 +358,52 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> ) job = super().start_file_load(table, file_path, load_id) if not job: - job = DoNothingJob(file_path) + job = DoNothingFollowupJob(file_path) if self._is_iceberg_table(self.get_load_table(table["name"])) else DoNothingJob(file_path) return job + def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])): + return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": False})] + return super()._create_append_followup_jobs(table_chain) + + def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + if self._is_iceberg_table(self.get_load_table(table_chain[0]["name"])): + return [SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})] + return super()._create_replace_followup_jobs(table_chain) + + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + # fall back to append jobs for merge + return self._create_append_followup_jobs(table_chain) + + def _is_iceberg_table(self, table: TTableSchema) -> bool: + table_format = table.get("table_format") + return table_format == "iceberg" + + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + # all iceberg tables need staging + if self._is_iceberg_table(self.get_load_table(table["name"])): + return True + return super().should_load_data_to_staging_dataset(table) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + # on athena we only truncate replace tables that are not iceberg + table = self.get_load_table(table["name"]) + if table["write_disposition"] == "replace" and not self._is_iceberg_table(table): + return True + return False + + def should_load_data_to_staging_dataset_on_staging_destination(self, table: TTableSchema) -> bool: + """iceberg table data goes into staging on staging destination""" + return self._is_iceberg_table(self.get_load_table(table["name"])) + + def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + table = super().get_load_table(table_name, staging) + if self.config.force_iceberg: + table["table_format"] = "iceberg" + if staging and table.get("table_format", None) == "iceberg": + table.pop("table_format") + return table + + @staticmethod + def is_dbapi_exception(ex: Exception) -> bool: + return isinstance(ex, Error) diff --git 
a/dlt/destinations/athena/configuration.py b/dlt/destinations/athena/configuration.py index f6e6fa3b51..5dd1341c34 100644 --- a/dlt/destinations/athena/configuration.py +++ b/dlt/destinations/athena/configuration.py @@ -12,6 +12,8 @@ class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration): credentials: AwsCredentials = None athena_work_group: Optional[str] = None aws_data_catalog: Optional[str] = "awsdatacatalog" + supports_truncate_command: bool = False + force_iceberg: Optional[bool] = False __config_gen_annotations__: ClassVar[List[str]] = ["athena_work_group"] diff --git a/dlt/destinations/bigquery/bigquery.py b/dlt/destinations/bigquery/bigquery.py index 4dffd06c24..9cc7591f57 100644 --- a/dlt/destinations/bigquery/bigquery.py +++ b/dlt/destinations/bigquery/bigquery.py @@ -7,19 +7,20 @@ from dlt.common import json, logger from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob +from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, SupportsStagingDestination from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.schema.exceptions import UnknownTableException from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, LoadJobUnknownTableException +from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate, DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException from dlt.destinations.bigquery import capabilities from dlt.destinations.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS -from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob +from dlt.destinations.sql_jobs import SqlMergeJob, SqlStagingCopyJob, SqlJobParams from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper @@ -134,7 +135,7 @@ def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: st class BigqueryStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -146,7 +147,7 @@ def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClient sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") return sql -class BigQueryClient(SqlJobClientWithStaging): +class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -163,11 +164,13 @@ def __init__(self, schema: Schema, config: BigQueryClientConfiguration) -> None: self.sql_client: BigQuerySqlClient = sql_client # type: ignore self.type_mapper = 
BigQueryTypeMapper(self.capabilities) - def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return BigQueryMergeJob.from_table_chain(table_chain, self.sql_client) + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] - def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return BigqueryStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [BigqueryStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) def restore_file_load(self, file_path: str) -> LoadJob: """Returns a completed SqlLoadJob or restored BigQueryLoadJob @@ -214,7 +217,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> reason = BigQuerySqlClient._get_reason_from_errors(gace) if reason == "notFound": # google.api_core.exceptions.NotFound: 404 - table not found - raise LoadJobUnknownTableException(table["name"], file_path) + raise UnknownTableException(table["name"]) elif reason == "duplicate": # google.api_core.exceptions.Conflict: 409 PUT - already exists return self.restore_file_load(file_path) @@ -243,9 +246,9 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc return sql - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(c["name"]) - return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}" def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: schema_table: TTableSchemaColumns = {} diff --git a/dlt/destinations/duckdb/duck.py b/dlt/destinations/duckdb/duck.py index c40abd56a0..fe4ebac37e 100644 --- a/dlt/destinations/duckdb/duck.py +++ b/dlt/destinations/duckdb/duck.py @@ -5,7 +5,7 @@ from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.destination.reference import LoadJob, FollowupJob, TLoadJobState -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import maybe_context @@ -65,7 +65,7 @@ class DuckDbTypeMapper(TypeMapper): "HUGEINT": "bigint", } - def to_db_integer_type(self, precision: Optional[int]) -> str: + def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: if precision is None: return "BIGINT" # Precision is number of bits @@ -141,7 +141,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> job = DuckDbCopyJob(table["name"], file_path, self.sql_client) return job - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) return 
f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/exceptions.py b/dlt/destinations/exceptions.py index f0fe32f950..5c20f081f1 100644 --- a/dlt/destinations/exceptions.py +++ b/dlt/destinations/exceptions.py @@ -63,12 +63,6 @@ def __init__(self, file_path: str, message: str) -> None: super().__init__(f"Job with id/file name {file_path} encountered unrecoverable problem: {message}") -class LoadJobUnknownTableException(DestinationTerminalException): - def __init__(self, table_name: str, file_name: str) -> None: - self.table_name = table_name - super().__init__(f"Client does not know table {table_name} for load file {file_name}") - - class LoadJobInvalidStateTransitionException(DestinationTerminalException): def __init__(self, from_state: TLoadJobState, to_state: TLoadJobState) -> None: self.from_state = from_state diff --git a/dlt/destinations/filesystem/filesystem.py b/dlt/destinations/filesystem/filesystem.py index 1d9caf036d..49ad36dd16 100644 --- a/dlt/destinations/filesystem/filesystem.py +++ b/dlt/destinations/filesystem/filesystem.py @@ -1,14 +1,15 @@ import posixpath import os from types import TracebackType -from typing import ClassVar, List, Type, Iterable, Set +from typing import ClassVar, List, Type, Iterable, Set, Iterator from fsspec import AbstractFileSystem +from contextlib import contextmanager from dlt.common import logger from dlt.common.schema import Schema, TSchemaTables, TTableSchema from dlt.common.storages import FileStorage, LoadStorage, fsspec_from_config from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob +from dlt.common.destination.reference import NewLoadJob, TLoadJobState, LoadJob, JobClientBase, FollowupJob, WithStagingDataset from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.filesystem import capabilities @@ -68,7 +69,7 @@ def create_followup_jobs(self, next_state: str) -> List[NewLoadJob]: return jobs -class FilesystemClient(JobClientBase): +class FilesystemClient(JobClientBase, WithStagingDataset): """filesystem client storing jobs in memory""" capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -82,16 +83,27 @@ def __init__(self, schema: Schema, config: FilesystemDestinationClientConfigurat # verify files layout. 
we need {table_name} and only allow {schema_name} before it, otherwise tables # cannot be replaced and we cannot initialize folders consistently self.table_prefix_layout = path_utils.get_table_prefix_layout(config.layout) - - @property - def dataset_path(self) -> str: - ds_path = posixpath.join(self.fs_path, self.config.normalize_dataset_name(self.schema)) - return ds_path + self._dataset_path = self.config.normalize_dataset_name(self.schema) def drop_storage(self) -> None: if self.is_storage_initialized(): self.fs_client.rm(self.dataset_path, recursive=True) + @property + def dataset_path(self) -> str: + return posixpath.join(self.fs_path, self._dataset_path) + + + @contextmanager + def with_staging_dataset(self) -> Iterator["FilesystemClient"]: + current_dataset_path = self._dataset_path + try: + self._dataset_path = self.schema.naming.normalize_table_identifier(current_dataset_path + "_staging") + yield self + finally: + # restore previous dataset name + self._dataset_path = current_dataset_path + def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: # clean up existing files for tables selected for truncating if truncate_tables and self.fs_client.isdir(self.dataset_path): @@ -169,3 +181,6 @@ def __enter__(self) -> "FilesystemClient": def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: pass + + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + return False diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index f2fc8da8d1..7dabf278c2 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -12,7 +12,7 @@ from dlt.common import json, pendulum, logger from dlt.common.data_types import TDataType -from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition +from dlt.common.schema.typing import COLUMN_HINTS, TColumnType, TColumnSchemaBase, TTableSchema, TWriteDisposition, TTableFormat from dlt.common.storages import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables from dlt.common.destination.reference import StateInfo, StorageSchemaInfo,WithStateSync, DestinationClientConfiguration, DestinationClientDwhConfiguration, DestinationClientDwhWithStagingConfiguration, NewLoadJob, WithStagingDataset, TLoadJobState, LoadJob, JobClientBase, FollowupJob, CredentialsConfiguration @@ -140,34 +140,31 @@ def maybe_ddl_transaction(self) -> Iterator[None]: else: yield - def get_truncate_destination_table_dispositions(self) -> List[TWriteDisposition]: - if self.config.replace_strategy == "truncate-and-insert": - return ["replace"] - return [] + def should_truncate_table_before_load(self, table: TTableSchema) -> bool: + return table["write_disposition"] == "replace" and self.config.replace_strategy == "truncate-and-insert" - def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return SqlMergeJob.from_table_chain(table_chain, self.sql_client) + def _create_append_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return [] - def _create_staging_copy_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - """update destination tables from staging tables""" - return SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return 
[SqlMergeJob.from_table_chain(table_chain, self.sql_client)] - def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - """optimized replace strategy, defaults to _create_staging_copy_job for the basic client - for some destinations there are much faster destination updates at the cost of - dropping tables possible""" - return self._create_staging_copy_job(table_chain) + def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + jobs: List[NewLoadJob] = [] + if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: + jobs.append(SqlStagingCopyJob.from_table_chain(table_chain, self.sql_client, {"replace": True})) + return jobs def create_table_chain_completed_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" jobs = super().create_table_chain_completed_followup_jobs(table_chain) write_disposition = table_chain[0]["write_disposition"] - if write_disposition == "merge": - jobs.append(self._create_merge_job(table_chain)) - elif write_disposition == "replace" and self.config.replace_strategy == "insert-from-staging": - jobs.append(self._create_staging_copy_job(table_chain)) - elif write_disposition == "replace" and self.config.replace_strategy == "staging-optimized": - jobs.append(self._create_optimized_replace_job(table_chain)) + if write_disposition == "append": + jobs.extend(self._create_append_followup_jobs(table_chain)) + elif write_disposition == "merge": + jobs.extend(self._create_merge_followup_jobs(table_chain)) + elif write_disposition == "replace": + jobs.extend(self._create_replace_followup_jobs(table_chain)) return jobs def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: @@ -314,30 +311,32 @@ def _build_schema_update_sql(self, only_tables: Iterable[str]) -> Tuple[List[str sql += ";" sql_updates.append(sql) # create a schema update for particular table - partial_table = copy(self.schema.get_table(table_name)) + partial_table = copy(self.get_load_table(table_name)) # keep only new columns partial_table["columns"] = {c["name"]: c for c in new_columns} schema_update[table_name] = partial_table return sql_updates, schema_update - def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]: + def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]: """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" - return [f"ADD COLUMN {self._get_column_def_sql(c)}" for c in new_columns] + return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns] def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool) -> List[str]: # build sql canonical_name = self.sql_client.make_qualified_table_name(table_name) + table = self.get_load_table(table_name) + table_format = table.get("table_format") if table else None sql_result: List[str] = [] if not generate_alter: # build CREATE sql = f"CREATE TABLE {canonical_name} (\n" - sql += ",\n".join([self._get_column_def_sql(c) for c in new_columns]) + sql += ",\n".join([self._get_column_def_sql(c, table_format) for c in new_columns]) sql += ")" sql_result.append(sql) else: sql_base = f"ALTER TABLE {canonical_name}\n" - add_column_statements = self._make_add_column_sql(new_columns) + add_column_statements = 
self._make_add_column_sql(new_columns, table_format) if self.capabilities.alter_add_multi_column: column_sql = ",\n" sql_result.append(sql_base + column_sql.join(add_column_statements)) @@ -360,7 +359,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc return sql_result @abstractmethod - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: pass @staticmethod @@ -431,15 +430,22 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): + + in_staging_mode: bool = False + @contextlib.contextmanager def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]: - with self.sql_client.with_staging_dataset(True): - yield self + try: + with self.sql_client.with_staging_dataset(True): + self.in_staging_mode = True + yield self + finally: + self.in_staging_mode = False + + def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + if table["write_disposition"] == "merge": + return True + elif table["write_disposition"] == "replace" and (self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]): + return True + return False - def get_stage_dispositions(self) -> List[TWriteDisposition]: - """Returns a list of dispositions that require staging tables to be populated""" - dispositions: List[TWriteDisposition] = ["merge"] - # if we have anything but the truncate-and-insert replace strategy, we need staging tables - if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: - dispositions.append("replace") - return dispositions diff --git a/dlt/destinations/mssql/mssql.py b/dlt/destinations/mssql/mssql.py index 5ed3b706b8..cd999441ff 100644 --- a/dlt/destinations/mssql/mssql.py +++ b/dlt/destinations/mssql/mssql.py @@ -5,10 +5,10 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.utils import uniq_id -from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlMergeJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -62,7 +62,7 @@ class MsSqlTypeMapper(TypeMapper): "int": "bigint", } - def to_db_integer_type(self, precision: Optional[int]) -> str: + def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: if precision is None: return "bigint" if precision <= 8: @@ -83,7 +83,7 @@ def from_db_type(self, db_type: str, precision: Optional[int], scale: Optional[i class MsSqlStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -133,14 +133,14 @@ def __init__(self, schema: Schema, config: MsSqlClientConfiguration) -> None: self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} self.type_mapper = MsSqlTypeMapper(self.capabilities) - def 
_create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return MsSqlMergeJob.from_table_chain(table_chain, self.sql_client) + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] - def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]: + def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]: # Override because mssql requires multiple columns in a single ADD COLUMN clause - return ["ADD \n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)] - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns @@ -152,8 +152,10 @@ def _get_column_def_sql(self, c: TColumnSchema) -> str: column_name = self.capabilities.escape_identifier(c["name"]) return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c['nullable'])}" - def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) def _from_db_type(self, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: return self.type_mapper.from_db_type(pq_t, precision, scale) diff --git a/dlt/destinations/postgres/postgres.py b/dlt/destinations/postgres/postgres.py index ead5ab6639..2812d1d4c4 100644 --- a/dlt/destinations/postgres/postgres.py +++ b/dlt/destinations/postgres/postgres.py @@ -5,9 +5,9 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat -from dlt.destinations.sql_jobs import SqlStagingCopyJob +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -59,7 +59,7 @@ class PostgresTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type(self, precision: Optional[int]) -> str: + def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: if precision is None: return "bigint" # Precision is number of bits @@ -79,7 +79,7 @@ def from_db_type(self, db_type: str, precision: Optional[int] = None, scale: Opt class PostgresStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): @@ -109,13 +109,15 @@ def __init__(self, schema: Schema, config: 
PostgresClientConfiguration) -> None: self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = PostgresTypeMapper(self.capabilities) - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: hints_str = " ".join(self.active_hints.get(h, "") for h in self.active_hints.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" - def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) def _from_db_type(self, pq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: return self.type_mapper.from_db_type(pq_t, precision, scale) diff --git a/dlt/destinations/redshift/redshift.py b/dlt/destinations/redshift/redshift.py index 944bd24581..888f27ae7c 100644 --- a/dlt/destinations/redshift/redshift.py +++ b/dlt/destinations/redshift/redshift.py @@ -14,10 +14,10 @@ from typing import ClassVar, Dict, List, Optional, Sequence, Any from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration +from dlt.common.destination.reference import NewLoadJob, CredentialsConfiguration, SupportsStagingDestination from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import InsertValuesJobClient @@ -76,7 +76,7 @@ class RedshiftTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type(self, precision: Optional[int]) -> str: + def to_db_integer_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: if precision is None: return "bigint" if precision <= 16: @@ -187,7 +187,7 @@ def gen_key_table_clauses(cls, root_table_name: str, staging_root_table_name: st return SqlMergeJob.gen_key_table_clauses(root_table_name, staging_root_table_name, key_clauses, for_delete) -class RedshiftClient(InsertValuesJobClient): +class RedshiftClient(InsertValuesJobClient, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -201,10 +201,10 @@ def __init__(self, schema: Schema, config: RedshiftClientConfiguration) -> None: self.config: RedshiftClientConfiguration = config self.type_mapper = RedshiftTypeMapper(self.capabilities) - def _create_merge_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return RedshiftMergeJob.from_table_chain(table_chain, self.sql_client) + def _create_merge_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, 
table_format: TTableFormat = None) -> str: hints_str = " ".join(HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() if c.get(h, False) is True) column_name = self.capabilities.escape_identifier(c["name"]) return f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/snowflake/snowflake.py b/dlt/destinations/snowflake/snowflake.py index b9046cde75..f433ec7e7d 100644 --- a/dlt/destinations/snowflake/snowflake.py +++ b/dlt/destinations/snowflake/snowflake.py @@ -1,13 +1,13 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any +from typing import ClassVar, Optional, Sequence, Tuple, List, Any from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration -from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults, AzureCredentials, AzureCredentialsWithoutDefaults +from dlt.common.destination.reference import FollowupJob, NewLoadJob, TLoadJobState, LoadJob, CredentialsConfiguration, SupportsStagingDestination +from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults, AzureCredentialsWithoutDefaults from dlt.common.data_types import TDataType from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType +from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.destinations.job_client_impl import SqlJobClientWithStaging @@ -17,7 +17,7 @@ from dlt.destinations.snowflake import capabilities from dlt.destinations.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient -from dlt.destinations.sql_jobs import SqlStagingCopyJob +from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams from dlt.destinations.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import NewReferenceJob from dlt.destinations.sql_client import SqlClientBase @@ -157,20 +157,19 @@ def exception(self) -> str: class SnowflakeStagingCopyJob(SqlStagingCopyJob): @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) - # drop destination table sql.append(f"DROP TABLE IF EXISTS {table_name};") # recreate destination table with data cloned from staging table sql.append(f"CREATE TABLE {table_name} CLONE {staging_table_name};") return sql -class SnowflakeClient(SqlJobClientWithStaging): +class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() def __init__(self, schema: Schema, config: SnowflakeClientConfiguration) -> None: @@ -201,12 +200,14 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> def restore_file_load(self, file_path: str) -> LoadJob: return EmptyLoadJob.from_file_path(file_path, 
"completed") - def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema]) -> List[str]: + def _make_add_column_sql(self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None) -> List[str]: # Override because snowflake requires multiple columns in a single ADD COLUMN clause - return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns)] - def _create_optimized_replace_job(self, table_chain: Sequence[TTableSchema]) -> NewLoadJob: - return SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client) + def _create_replace_followup_jobs(self, table_chain: Sequence[TTableSchema]) -> List[NewLoadJob]: + if self.config.replace_strategy == "staging-optimized": + return [SnowflakeStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + return super()._create_replace_followup_jobs(table_chain) def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool, separate_alters: bool = False) -> List[str]: sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) @@ -221,7 +222,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc def _from_db_type(self, bq_t: str, precision: Optional[int], scale: Optional[int]) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(c["name"]) return f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index c1137ee9ad..4e8393ed74 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Sequence, Tuple, cast +from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional import yaml from dlt.common.runtime.logger import pretty_format_exception @@ -11,24 +11,30 @@ from dlt.destinations.job_impl import NewLoadJobImpl from dlt.destinations.sql_client import SqlClientBase +class SqlJobParams(TypedDict): + replace: Optional[bool] + +DEFAULTS: SqlJobParams = { + "replace": False +} class SqlBaseJob(NewLoadJobImpl): """Sql base job for jobs that rely on the whole tablechain""" failed_text: str = "" @classmethod - def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> NewLoadJobImpl: + def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> NewLoadJobImpl: """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader. The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). """ - + params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) # type: ignore top_table = table_chain[0] file_info = ParsedLoadJobFileName(top_table["name"], uniq_id()[:10], 0, "sql") try: # Remove line breaks from multiline statements and write one SQL statement per line in output file # to support clients that need to execute one statement at a time (i.e. 
snowflake) - sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client)] + sql = [' '.join(stmt.splitlines()) for stmt in cls.generate_sql(table_chain, sql_client, params)] job = cls(file_info.job_id(), "running") job._save_text_file("\n".join(sql)) except Exception: @@ -39,7 +45,7 @@ def from_table_chain(cls, table_chain: Sequence[TTableSchema], sql_client: SqlCl return job @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: pass @@ -48,14 +54,15 @@ class SqlStagingCopyJob(SqlBaseJob): failed_text: str = "Tried to generate a staging copy sql job for the following tables:" @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: sql: List[str] = [] for table in table_chain: with sql_client.with_staging_dataset(staging=True): staging_table_name = sql_client.make_qualified_table_name(table["name"]) table_name = sql_client.make_qualified_table_name(table["name"]) columns = ", ".join(map(sql_client.capabilities.escape_identifier, get_columns_names_with_prop(table, "name"))) - sql.append(sql_client._truncate_table_sql(table_name)) + if params["replace"]: + sql.append(sql_client._truncate_table_sql(table_name)) sql.append(f"INSERT INTO {table_name}({columns}) SELECT {columns} FROM {staging_table_name};") return sql @@ -64,7 +71,7 @@ class SqlMergeJob(SqlBaseJob): failed_text: str = "Tried to generate a merge sql job for the following tables:" @classmethod - def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any]) -> List[str]: + def generate_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None) -> List[str]: """Generates a list of sql statements that merge the data in staging dataset with the data in destination dataset. The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). 
@@ -154,7 +161,7 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien unique_column: str = None root_key_column: str = None - insert_temp_table_sql: str = None + insert_temp_table_name: str = None if len(table_chain) == 1: @@ -176,10 +183,10 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien # get first unique column unique_column = sql_client.capabilities.escape_identifier(unique_columns[0]) # create temp table with unique identifier - create_delete_temp_table_sql, delete_temp_table_sql = cls.gen_delete_temp_table_sql(unique_column, key_table_clauses) + create_delete_temp_table_sql, delete_temp_table_name = cls.gen_delete_temp_table_sql(unique_column, key_table_clauses) sql.extend(create_delete_temp_table_sql) # delete top table - sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_sql});") + sql.append(f"DELETE FROM {root_table_name} WHERE {unique_column} IN (SELECT * FROM {delete_temp_table_name});") # delete other tables for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) @@ -192,10 +199,10 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien f"There is no root foreign key (ie _dlt_root_id) in child table {table['name']} so it is not possible to refer to top level table {root_table['name']} unique column {unique_column}" ) root_key_column = sql_client.capabilities.escape_identifier(root_key_columns[0]) - sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {delete_temp_table_sql});") + sql.append(f"DELETE FROM {table_name} WHERE {root_key_column} IN (SELECT * FROM {delete_temp_table_name});") # create temp table used to deduplicate, only when we have primary keys if primary_keys: - create_insert_temp_table_sql, insert_temp_table_sql = cls.gen_insert_temp_table_sql(staging_root_table_name, primary_keys, unique_column) + create_insert_temp_table_sql, insert_temp_table_name = cls.gen_insert_temp_table_sql(staging_root_table_name, primary_keys, unique_column) sql.extend(create_insert_temp_table_sql) # insert from staging to dataset, truncate staging table @@ -215,11 +222,11 @@ def gen_merge_sql(cls, table_chain: Sequence[TTableSchema], sql_client: SqlClien """ else: uniq_column = unique_column if table.get("parent") is None else root_key_column - insert_sql += f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_sql});" + insert_sql += f" WHERE {uniq_column} IN (SELECT * FROM {insert_temp_table_name});" if insert_sql.strip()[-1] != ";": insert_sql += ";" sql.append(insert_sql) # -- DELETE FROM {staging_table_name} WHERE 1=1; - return sql + return sql \ No newline at end of file diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index dcbb1a4261..3ddfee5904 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -1,6 +1,6 @@ from typing import Tuple, ClassVar, Dict, Optional -from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType +from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType, TTableFormat from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.utils import without_none @@ -20,15 +20,15 @@ class TypeMapper: def __init__(self, capabilities: DestinationCapabilitiesContext) -> None: self.capabilities = capabilities - def to_db_integer_type(self, precision: Optional[int]) -> str: + def to_db_integer_type(self, precision: 
Optional[int], table_format: TTableFormat = None) -> str: # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.) return self.sct_to_unbound_dbt["bigint"] - def to_db_type(self, column: TColumnSchema) -> str: + def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: precision, scale = column.get("precision"), column.get("scale") sc_t = column["data_type"] if sc_t == "bigint": - return self.to_db_integer_type(precision) + return self.to_db_integer_type(precision, table_format) bounded_template = self.sct_to_dbt.get(sc_t) if not bounded_template: return self.sct_to_unbound_dbt[sc_t] diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 4a50ad14ff..ec3f2bb47b 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -15,7 +15,7 @@ from dlt.common.pipeline import PipelineContext from dlt.common.source import _SOURCES, SourceInfo from dlt.common.schema.schema import Schema -from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns +from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns, TTableFormat from dlt.extract.utils import ensure_table_schema_columns_hint, simulate_func_call, wrap_compat_transformer, wrap_resource_gen from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.storages.schema_storage import SchemaStorage @@ -206,6 +206,7 @@ def resource( columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None ) -> DltResource: @@ -221,6 +222,7 @@ def resource( columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None ) -> Callable[[Callable[TResourceFunParams, Any]], DltResource]: @@ -236,6 +238,7 @@ def resource( columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: Literal[True] = True @@ -253,6 +256,7 @@ def resource( columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None ) -> DltResource: @@ -268,6 +272,7 @@ def resource( columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + table_format: TTableHintTemplate[TTableFormat] = None, selected: bool = True, spec: Type[BaseConfiguration] = None, standalone: bool = False, @@ -317,6 +322,8 @@ def resource( merge_key (str | Sequence[str]): A column name or a list of column names that define a merge key. Typically used with "merge" write disposition to remove overlapping data ranges ie. to keep a single record for a given day. 
This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes. + table_format (Literal["iceberg"], optional): Defines the storage format of the table. Currently only "iceberg" is supported on Athena, other destinations ignore this hint. + selected (bool, optional): When `True` `dlt pipeline` will extract and load this resource, if `False`, the resource will be ignored. spec (Type[BaseConfiguration], optional): A specification of configuration and secret values required by the source. @@ -339,6 +346,7 @@ def make_resource(_name: str, _section: str, _data: Any, incremental: Incrementa columns=columns, primary_key=primary_key, merge_key=merge_key, + table_format=table_format ) return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, data_from), incremental=incremental) diff --git a/dlt/extract/schema.py b/dlt/extract/schema.py index 309b3ab075..c1dfd1f7f5 100644 --- a/dlt/extract/schema.py +++ b/dlt/extract/schema.py @@ -3,7 +3,7 @@ from typing import List, TypedDict, cast, Any from dlt.common.schema.utils import DEFAULT_WRITE_DISPOSITION, merge_columns, new_column, new_table -from dlt.common.schema.typing import TColumnNames, TColumnProp, TColumnSchema, TPartialTableSchema, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns +from dlt.common.schema.typing import TColumnNames, TColumnProp, TColumnSchema, TPartialTableSchema, TTableSchemaColumns, TWriteDisposition, TAnySchemaColumns, TTableFormat from dlt.common.typing import TDataItem from dlt.common.utils import update_dict_nested from dlt.common.validation import validate_dict_ignoring_xkeys @@ -219,6 +219,7 @@ def new_table_template( columns: TTableHintTemplate[TAnySchemaColumns] = None, primary_key: TTableHintTemplate[TColumnNames] = None, merge_key: TTableHintTemplate[TColumnNames] = None, + table_format: TTableHintTemplate[TTableFormat] = None ) -> TTableSchemaTemplate: if columns is not None: validator = get_column_validator(columns) @@ -229,7 +230,7 @@ def new_table_template( validator = None # create a table schema template where hints can be functions taking TDataItem new_template: TTableSchemaTemplate = new_table( - table_name, parent_table_name, write_disposition=write_disposition, columns=columns # type: ignore + table_name, parent_table_name, write_disposition=write_disposition, columns=columns, table_format=table_format # type: ignore ) if not table_name: new_template.pop("name") diff --git a/dlt/load/load.py b/dlt/load/load.py index 2cae753978..ca8fff66df 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -2,7 +2,7 @@ from copy import copy from functools import reduce import datetime # noqa: 251 -from typing import Dict, List, Optional, Tuple, Set, Iterator +from typing import Dict, List, Optional, Tuple, Set, Iterator, Iterable, Callable from multiprocessing.pool import ThreadPool import os @@ -10,20 +10,19 @@ from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, SupportsPipeline -from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_write_disposition +from dlt.common.schema.utils import get_child_tables, get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.typing import StrAny from dlt.common.runners import TRunMetrics, Runnable, workermethod from dlt.common.runtime.collector import Collector, 
NULL_COLLECTOR from dlt.common.runtime.logger import pretty_format_exception from dlt.common.exceptions import TerminalValueError, DestinationTerminalException, DestinationTransientException -from dlt.common.schema import Schema +from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import TTableSchema, TWriteDisposition from dlt.common.storages import LoadStorage -from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, WithStagingDataset, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration +from dlt.common.destination.reference import DestinationClientDwhConfiguration, FollowupJob, JobClientBase, WithStagingDataset, DestinationReference, LoadJob, NewLoadJob, TLoadJobState, DestinationClientConfiguration, SupportsStagingDestination from dlt.destinations.job_impl import EmptyLoadJob -from dlt.destinations.exceptions import LoadJobUnknownTableException from dlt.load.configuration import LoaderConfiguration from dlt.load.exceptions import LoadClientJobFailed, LoadClientJobRetry, LoadClientUnsupportedWriteDisposition, LoadClientUnsupportedFileFormats @@ -69,19 +68,6 @@ def create_storage(self, is_storage_owner: bool) -> LoadStorage: ) return load_storage - @staticmethod - def get_load_table(schema: Schema, file_name: str) -> TTableSchema: - table_name = LoadStorage.parse_job_file_name(file_name).table_name - try: - # make a copy of the schema so modifications do not affect the original document - table = copy(schema.get_table(table_name)) - # add write disposition if not specified - in child tables - if "write_disposition" not in table: - table["write_disposition"] = get_write_disposition(schema.tables, table_name) - return table - except KeyError: - raise LoadJobUnknownTableException(table_name, file_name) - def get_destination_client(self, schema: Schema) -> JobClientBase: return self.destination.client(schema, self.initial_client_config) @@ -92,9 +78,9 @@ def is_staging_destination_job(self, file_path: str) -> bool: return self.staging_destination is not None and os.path.splitext(file_path)[1][1:] in self.staging_destination.capabilities().supported_loader_file_formats @contextlib.contextmanager - def maybe_with_staging_dataset(self, job_client: JobClientBase, table: TTableSchema) -> Iterator[None]: + def maybe_with_staging_dataset(self, job_client: JobClientBase, use_staging: bool) -> Iterator[None]: """Executes job client methods in context of staging dataset if `table` has `write_disposition` that requires it""" - if isinstance(job_client, WithStagingDataset) and table["write_disposition"] in job_client.get_stage_dispositions(): + if isinstance(job_client, WithStagingDataset) and use_staging: with job_client.with_staging_dataset(): yield else: @@ -105,18 +91,26 @@ def maybe_with_staging_dataset(self, job_client: JobClientBase, table: TTableSch def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> Optional[LoadJob]: job: LoadJob = None try: + is_staging_destination_job = self.is_staging_destination_job(file_path) + job_client = self.get_destination_client(schema) + # if we have a staging destination and the file is not a reference, send to staging - job_client = self.get_staging_destination_client(schema) if self.is_staging_destination_job(file_path) else self.get_destination_client(schema) - with job_client as job_client: + with (self.get_staging_destination_client(schema) if is_staging_destination_job else job_client) as client: job_info = 
self.load_storage.parse_job_file_name(file_path) if job_info.file_format not in self.load_storage.supported_file_formats: raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - table = self.get_load_table(schema, file_path) + table = client.get_load_table(job_info.table_name) if table["write_disposition"] not in ["append", "replace", "merge"]: raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path) - with self.maybe_with_staging_dataset(job_client, table): - job = job_client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id) + + if is_staging_destination_job: + use_staging_dataset = isinstance(job_client, SupportsStagingDestination) and job_client.should_load_data_to_staging_dataset_on_staging_destination(table) + else: + use_staging_dataset = isinstance(job_client, WithStagingDataset) and job_client.should_load_data_to_staging_dataset(table) + + with self.maybe_with_staging_dataset(client, use_staging_dataset): + job = client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id) except (DestinationTerminalException, TerminalValueError): # if job irreversibly cannot be started, mark it as failed logger.exception(f"Terminal problem when adding job {file_path}") @@ -173,13 +167,8 @@ def retrieve_jobs(self, client: JobClientBase, load_id: str, staging_client: Job return len(jobs), jobs - def get_new_jobs_info(self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition] = None) -> List[ParsedLoadJobFileName]: - jobs_info: List[ParsedLoadJobFileName] = [] - new_job_files = self.load_storage.list_new_jobs(load_id) - for job_file in new_job_files: - if dispositions is None or self.get_load_table(schema, job_file)["write_disposition"] in dispositions: - jobs_info.append(LoadStorage.parse_job_file_name(job_file)) - return jobs_info + def get_new_jobs_info(self, load_id: str) -> List[ParsedLoadJobFileName]: + return [LoadStorage.parse_job_file_name(job_file) for job_file in self.load_storage.list_new_jobs(load_id)] def get_completed_table_chain(self, load_id: str, schema: Schema, top_merged_table: TTableSchema, being_completed_job_id: str = None) -> List[TTableSchema]: """Gets a table chain starting from the `top_merged_table` containing only tables with completed/failed jobs. 
None is returned if there's any job that is not completed @@ -210,7 +199,7 @@ def create_followup_jobs(self, load_id: str, state: TLoadJobState, starting_job: starting_job_file_name = starting_job.file_name() if state == "completed" and not self.is_staging_destination_job(starting_job_file_name): client = self.destination.client(schema, self.initial_client_config) - top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, starting_job_file_name)["name"]) + top_job_table = get_top_level_table(schema.tables, starting_job.job_file_info().table_name) # if all tables of chain completed, create follow up jobs if table_chain := self.get_completed_table_chain(load_id, schema, top_job_table, starting_job.job_file_info().job_id()): if follow_up_jobs := client.create_table_chain_completed_followup_jobs(table_chain): @@ -278,54 +267,69 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) self.load_storage.complete_load_package(load_id, aborted) logger.info(f"All jobs completed, archiving package {load_id} with aborted set to {aborted}") - def get_table_chain_tables_for_write_disposition(self, load_id: str, schema: Schema, dispositions: List[TWriteDisposition]) -> Set[str]: + @staticmethod + def _get_table_chain_tables_with_filter(schema: Schema, f: Callable[[TTableSchema], bool], tables_with_jobs: Iterable[str]) -> Set[str]: """Get all jobs for tables with given write disposition and resolve the table chain""" result: Set[str] = set() - table_jobs = self.get_new_jobs_info(load_id, schema, dispositions) - for job in table_jobs: - top_job_table = get_top_level_table(schema.tables, self.get_load_table(schema, job.job_id())["name"]) - table_chain = get_child_tables(schema.tables, top_job_table["name"]) - for table in table_chain: - existing_jobs = self.load_storage.list_jobs_for_table(load_id, table["name"]) + for table_name in tables_with_jobs: + top_job_table = get_top_level_table(schema.tables, table_name) + if not f(top_job_table): + continue + for table in get_child_tables(schema.tables, top_job_table["name"]): # only add tables for tables that have jobs unless the disposition is replace - if not existing_jobs and top_job_table["write_disposition"] != "replace": + # TODO: this is a (formerly used) hack to make test_merge_on_keys_in_schema, + # we should change that test + if not table["name"] in tables_with_jobs and top_job_table["write_disposition"] != "replace": continue result.add(table["name"]) return result + @staticmethod + def _init_dataset_and_update_schema(job_client: JobClientBase, expected_update: TSchemaTables, update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False) -> TSchemaTables: + staging_text = "for staging dataset" if staging_info else "" + logger.info(f"Client for {job_client.config.destination_name} will start initialize storage {staging_text}") + job_client.initialize_storage() + logger.info(f"Client for {job_client.config.destination_name} will update schema to package schema {staging_text}") + applied_update = job_client.update_stored_schema(only_tables=update_tables, expected_update=expected_update) + logger.info(f"Client for {job_client.config.destination_name} will truncate tables {staging_text}") + job_client.initialize_storage(truncate_tables=truncate_tables) + return applied_update + + + def _init_client(self, job_client: JobClientBase, schema: Schema, expected_update: TSchemaTables, load_id: str, truncate_filter: Callable[[TTableSchema], bool], truncate_staging_filter: 
Callable[[TTableSchema], bool]) -> TSchemaTables: + + tables_with_jobs = set(job.table_name for job in self.get_new_jobs_info(load_id)) + dlt_tables = set(t["name"] for t in schema.dlt_tables()) + + # update the default dataset + truncate_tables = self._get_table_chain_tables_with_filter(schema, truncate_filter, tables_with_jobs) + applied_update = self._init_dataset_and_update_schema(job_client, expected_update, tables_with_jobs | dlt_tables, truncate_tables) + + # update the staging dataset if client supports this + if isinstance(job_client, WithStagingDataset): + if staging_tables := self._get_table_chain_tables_with_filter(schema, truncate_staging_filter, tables_with_jobs): + with job_client.with_staging_dataset(): + self._init_dataset_and_update_schema(job_client, expected_update, staging_tables | {schema.version_table_name}, staging_tables, staging_info=True) + + return applied_update + + def load_single_package(self, load_id: str, schema: Schema) -> None: # initialize analytical storage ie. create dataset required by passed schema - job_client: JobClientBase with self.get_destination_client(schema) as job_client: - expected_update = self.load_storage.begin_schema_update(load_id) - if expected_update is not None: - # update the default dataset - logger.info(f"Client for {job_client.config.destination_name} will start initialize storage") - job_client.initialize_storage() - logger.info(f"Client for {job_client.config.destination_name} will update schema to package schema") - all_jobs = self.get_new_jobs_info(load_id, schema) - all_tables = set(job.table_name for job in all_jobs) - dlt_tables = set(t["name"] for t in schema.dlt_tables()) - # only update tables that are present in the load package - applied_update = job_client.update_stored_schema(only_tables=all_tables | dlt_tables, expected_update=expected_update) - truncate_tables = self.get_table_chain_tables_for_write_disposition(load_id, schema, job_client.get_truncate_destination_table_dispositions()) - job_client.initialize_storage(truncate_tables=truncate_tables) - # initialize staging storage if needed - if self.staging_destination: + + if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: + + # init job client + applied_update = self._init_client(job_client, schema, expected_update, load_id, job_client.should_truncate_table_before_load, job_client.should_load_data_to_staging_dataset if isinstance(job_client, WithStagingDataset) else None) + + # init staging client + if self.staging_destination and isinstance(job_client, SupportsStagingDestination): with self.get_staging_destination_client(schema) as staging_client: - truncate_tables = self.get_table_chain_tables_for_write_disposition(load_id, schema, staging_client.get_truncate_destination_table_dispositions()) - staging_client.initialize_storage(truncate_tables) - # update the staging dataset if client supports this - if isinstance(job_client, WithStagingDataset): - if staging_tables := self.get_table_chain_tables_for_write_disposition(load_id, schema, job_client.get_stage_dispositions()): - with job_client.with_staging_dataset(): - logger.info(f"Client for {job_client.config.destination_name} will start initialize STAGING storage") - job_client.initialize_storage() - logger.info(f"Client for {job_client.config.destination_name} will UPDATE STAGING SCHEMA to package schema") - job_client.update_stored_schema(only_tables=staging_tables | {schema.version_table_name}, expected_update=expected_update) - logger.info(f"Client for 
{job_client.config.destination_name} will TRUNCATE STAGING TABLES: {staging_tables}") - job_client.initialize_storage(truncate_tables=staging_tables) + self._init_client(staging_client, schema, expected_update, load_id, job_client.should_truncate_table_before_load_on_staging_destination, job_client.should_load_data_to_staging_dataset_on_staging_destination) + self.load_storage.commit_schema_update(load_id, applied_update) + # initialize staging destination and spool or retrieve unfinished jobs if self.staging_destination: with self.get_staging_destination_client(schema) as staging_client: diff --git a/docs/website/docs/dlt-ecosystem/destinations/athena.md b/docs/website/docs/dlt-ecosystem/destinations/athena.md index 74771ba74f..9bd1682e97 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/athena.md +++ b/docs/website/docs/dlt-ecosystem/destinations/athena.md @@ -6,7 +6,7 @@ keywords: [aws, athena, glue catalog] # AWS Athena / Glue Catalog -The athena destination stores data as parquet files in s3 buckets and creates [external tables in aws athena](https://docs.aws.amazon.com/athena/latest/ug/creating-tables.html). You can then query those tables with athena sql commands which will then scan the whole folder of parquet files and return the results. This destination works very similar to other sql based destinations, with the exception of the merge write disposition not being supported at this time. dlt metadata will be stored in the same bucket as the parquet files, but as iceberg tables. +The athena destination stores data as parquet files in s3 buckets and creates [external tables in aws athena](https://docs.aws.amazon.com/athena/latest/ug/creating-tables.html). You can then query those tables with athena sql commands which will then scan the whole folder of parquet files and return the results. This destination works very similar to other sql based destinations, with the exception of the merge write disposition not being supported at this time. dlt metadata will be stored in the same bucket as the parquet files, but as iceberg tables. Athena additionally supports writing individual data tables as iceberg tables, so they may be manipulated later; a common use case is to strip gdpr data from them. ## Setup Guide ### 1. Initialize the dlt project @@ -110,16 +110,37 @@ Using a staging destination is mandatory when using the athena destination. If y If you decide to change the [filename layout](./filesystem#data-loading) from the default value, keep the following in mind so that athena can reliable build your tables: - You need to provide the `{table_name}` placeholder and this placeholder needs to be followed by a forward slash - You need to provide the `{file_id}` placeholder and it needs to be somewhere after the `{table_name}` placeholder. - - {table_name} must be a first placeholder in the layout. + - {table_name} must be the first placeholder in the layout. ## Additional destination options +### iceberg data tables +You can save your tables as iceberg tables in athena. This will enable you to, for example, delete data from them later if you need to. To switch a resource to the iceberg table format, +supply the table_format argument like this: + +```python +@dlt.resource(table_format="iceberg") +def data() -> Iterable[TDataItem]: + ...
+``` + +Alternatively, you can set all tables to use the iceberg format with a config variable: + +```toml +[destination.athena] +force_iceberg = "True" +``` + +For every table created as an iceberg table, the athena destination will create a regular athena table in the staging dataset of both the filesystem and the athena glue catalog and then +copy all data into the final iceberg table that lives with the non-iceberg tables in the same dataset on both the filesystem and the glue catalog. Switching from an iceberg table to a regular table, or vice versa, +is not supported. ### dbt support -Athena is supported via `dbt-athena-community`. Credentials are passed into `aws_access_key_id` and `aws_secret_access_key` of generated dbt profile. -Athena adapter requires that you setup **region_name** in Athena configuration below. You can also setup table catalog name to change the default: **awsdatacatalog** +Athena is supported via `dbt-athena-community`. Credentials are passed into `aws_access_key_id` and `aws_secret_access_key` of generated dbt profile. Iceberg tables are supported, but you need to make sure that you materialize your models as iceberg tables if your source table is iceberg. We encountered problems with materializing +datetime columns due to the different precision of iceberg (nanosecond) and regular Athena tables (millisecond). +The Athena adapter requires that you set up **region_name** in the Athena configuration below. You can also set up the table catalog name to change the default: **awsdatacatalog** ```toml [destination.athena] aws_data_catalog="awsdatacatalog" diff --git a/tests/load/athena_iceberg/__init__.py b/tests/load/athena_iceberg/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py new file mode 100644 index 0000000000..72772b0e2d --- /dev/null +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -0,0 +1,78 @@ + +import pytest +import os +import datetime # noqa: I251 +from typing import Iterator, Any + +import dlt +from dlt.common import pendulum +from dlt.common.utils import uniq_id +from tests.load.pipeline.utils import load_table_counts +from tests.cases import table_update_and_row, assert_all_data_types_row +from tests.pipeline.utils import assert_load_info + +from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration + +from tests.utils import skip_if_not_active +from dlt.destinations.exceptions import DatabaseTerminalException + + +skip_if_not_active("athena") + + +def test_iceberg() -> None: + """ + We write two tables, one with the iceberg flag, one without. We expect the iceberg table and its subtables to accept update commands + and the other table to reject them.
+ """ + os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = "s3://dlt-ci-test-bucket" + + pipeline = dlt.pipeline(pipeline_name="aaaaathena-iceberg", destination="athena", staging="filesystem", full_refresh=True) + + def items() -> Iterator[Any]: + yield { + "id": 1, + "name": "item", + "sub_items": [{ + "id": 101, + "name": "sub item 101" + },{ + "id": 101, + "name": "sub item 102" + }] + } + + @dlt.resource(name="items_normal", write_disposition="append") + def items_normal(): + yield from items() + + @dlt.resource(name="items_iceberg", write_disposition="append", table_format="iceberg") + def items_iceberg(): + yield from items() + + print(pipeline.run([items_normal, items_iceberg])) + + # see if we have athena tables with items + table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values() ]) + assert table_counts["items_normal"] == 1 + assert table_counts["items_normal__sub_items"] == 2 + assert table_counts["_dlt_loads"] == 1 + + assert table_counts["items_iceberg"] == 1 + assert table_counts["items_iceberg__sub_items"] == 2 + + with pipeline.sql_client() as client: + client.execute_sql("SELECT * FROM items_normal") + + # modifying regular athena table will fail + with pytest.raises(DatabaseTerminalException) as dbex: + client.execute_sql("UPDATE items_normal SET name='new name'") + assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex) + with pytest.raises(DatabaseTerminalException) as dbex: + client.execute_sql("UPDATE items_normal__sub_items SET name='super new name'") + assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex) + + # modifying iceberg table will succeed + client.execute_sql("UPDATE items_iceberg SET name='new name'") + client.execute_sql("UPDATE items_iceberg__sub_items SET name='super new name'") + diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index e55f5b2964..37c1f0c607 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -58,6 +58,8 @@ def test_run_jaffle_package(destination_config: DestinationTestConfiguration, db @pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: from docs.examples.chess.chess import chess + if not destination_config.supports_dbt: + pytest.skip("dbt is not supported for this destination configuration") # provide chess url via environ os.environ["CHESS_URL"] = "https://api.chess.com/pub/" @@ -95,6 +97,8 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven @pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name) def test_run_chess_dbt_to_other_dataset(destination_config: DestinationTestConfiguration, dbt_venv: Venv) -> None: from docs.examples.chess.chess import chess + if not destination_config.supports_dbt: + pytest.skip("dbt is not supported for this destination configuration") # provide chess url via environ os.environ["CHESS_URL"] = "https://api.chess.com/pub/" diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 9ee3a5b947..d39556ab2f 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -25,7 +25,7 @@ def test_replace_disposition(destination_config: 
DestinationTestConfiguration, r # TODO: start storing _dlt_loads with right json content increase_loads = lambda x: x if destination_config.destination == "filesystem" else x + 1 - increase_state_loads = lambda info: len([job for job in info.load_packages[0].jobs["completed_jobs"] if job.job_file_info.table_name == "_dlt_pipeline_state" and job.job_file_info.file_format != "reference"]) + increase_state_loads = lambda info: len([job for job in info.load_packages[0].jobs["completed_jobs"] if job.job_file_info.table_name == "_dlt_pipeline_state" and job.job_file_info.file_format not in ["sql", "reference"]]) # filesystem does not have versions and child tables def norm_table_counts(counts: Dict[str, int], *child_tables: str) -> Dict[str, int]: @@ -72,6 +72,7 @@ def append_items(): "name": f"item {index}", } + # first run with offset 0 info = pipeline.run([load_items, append_items], loader_file_format=destination_config.file_format) assert_load_info(info) @@ -98,6 +99,7 @@ def append_items(): "_dlt_loads": dlt_loads, "_dlt_version": dlt_versions } + # check trace assert pipeline.last_trace.last_normalize_info.row_counts == { "append_items": 12, diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index dcea7bd94d..aaa89ebfb1 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -100,16 +100,7 @@ def test_get_new_jobs_info() -> None: ) # no write disposition specified - get all new jobs - assert len(load.get_new_jobs_info(load_id, schema)) == 2 - # empty list - none - assert len(load.get_new_jobs_info(load_id, schema, [])) == 0 - # two appends - assert len(load.get_new_jobs_info(load_id, schema, ["append"])) == 2 - assert len(load.get_new_jobs_info(load_id, schema, ["replace"])) == 0 - assert len(load.get_new_jobs_info(load_id, schema, ["replace", "append"])) == 2 - - load.load_storage.start_job(load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl") - assert len(load.get_new_jobs_info(load_id, schema, ["replace", "append"])) == 1 + assert len(load.get_new_jobs_info(load_id)) == 2 def test_get_completed_table_chain_single_job_per_table() -> None: diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 1a842f534b..35394ed1c6 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -428,14 +428,14 @@ def test_load_with_all_types(client: SqlJobClientBase, write_disposition: TWrite client.schema.bump_version() client.update_stored_schema() - if write_disposition in client.get_stage_dispositions(): # type: ignore[attr-defined] + if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined] with client.with_staging_dataset(): # type: ignore[attr-defined] # create staging for merge dataset client.initialize_storage() client.update_stored_schema() with client.sql_client.with_staging_dataset( - write_disposition in client.get_stage_dispositions() # type: ignore[attr-defined] + client.should_load_data_to_staging_dataset(client.schema.tables[table_name]) # type: ignore[attr-defined] ): canonical_name = client.sql_client.make_qualified_table_name(table_name) # write row @@ -493,7 +493,7 @@ def test_write_dispositions(client: SqlJobClientBase, write_disposition: TWriteD with io.BytesIO() as f: write_dataset(client, f, [table_row], TABLE_UPDATE_COLUMNS_SCHEMA) query = f.getvalue().decode() - if write_disposition in client.get_stage_dispositions(): # type: ignore[attr-defined] + if 
client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined] # load to staging dataset on merge with client.with_staging_dataset(): # type: ignore[attr-defined] expect_load_file(client, file_storage, query, t) diff --git a/tests/load/utils.py b/tests/load/utils.py index 9941bbb55e..b2c87e56ae 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -26,7 +26,7 @@ from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.job_client_impl import SqlJobClientBase -from tests.utils import ACTIVE_DESTINATIONS, IMPLEMENTED_DESTINATIONS, SQL_DESTINATIONS +from tests.utils import ACTIVE_DESTINATIONS, IMPLEMENTED_DESTINATIONS, SQL_DESTINATIONS, EXCLUDED_DESTINATION_CONFIGURATIONS from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_UPDATE, TABLE_ROW_ALL_DATA_TYPES, assert_all_data_types_row # bucket urls @@ -53,6 +53,8 @@ class DestinationTestConfiguration: staging_iam_role: Optional[str] = None extra_info: Optional[str] = None supports_merge: bool = True # TODO: take it from client base class + force_iceberg: bool = False + supports_dbt: bool = True @property def name(self) -> str: @@ -72,6 +74,7 @@ def setup(self) -> None: os.environ['DESTINATION__FILESYSTEM__BUCKET_URL'] = self.bucket_url or "" os.environ['DESTINATION__STAGE_NAME'] = self.stage_name or "" os.environ['DESTINATION__STAGING_IAM_ROLE'] = self.staging_iam_role or "" + os.environ['DESTINATION__FORCE_ICEBERG'] = str(self.force_iceberg) or "" """For the filesystem destinations we disable compression to make analyzing the result easier""" if self.destination == "filesystem": @@ -108,15 +111,14 @@ def destinations_configs( destination_configs += [DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS if destination != "athena"] # athena needs filesystem staging, which will be automatically set, we have to supply a bucket url though destination_configs += [DestinationTestConfiguration(destination="athena", supports_merge=False, bucket_url=AWS_BUCKET)] + destination_configs += [DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, force_iceberg=True, supports_merge=False, supports_dbt=False, extra_info="iceberg")] if default_vector_configs: # for now only weaviate destination_configs += [DestinationTestConfiguration(destination="weaviate")] - if default_staging_configs or all_staging_configs: destination_configs += [ - DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, supports_merge=False), DestinationTestConfiguration(destination="redshift", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, staging_iam_role="arn:aws:iam::267388281016:role/redshift_s3_read", extra_info="s3-role"), DestinationTestConfiguration(destination="bigquery", staging="filesystem", file_format="parquet", bucket_url=GCS_BUCKET, extra_info="gcs-authorization"), DestinationTestConfiguration(destination="snowflake", staging="filesystem", file_format="jsonl", bucket_url=GCS_BUCKET, stage_name="PUBLIC.dlt_gcs_stage", extra_info="gcs-integration"), @@ -153,6 +155,9 @@ def destinations_configs( if exclude: destination_configs = [conf for conf in destination_configs if conf.destination not in exclude] + # filter out excluded configs + destination_configs = [conf for conf in destination_configs if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS] + return destination_configs @@ -171,7 +176,7 @@ def load_table(name: 
str) -> Dict[str, TTableSchemaColumns]: def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: str, table_name: str, status = "completed") -> LoadJob: file_name = ParsedLoadJobFileName(table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format).job_id() file_storage.save(file_name, query.encode("utf-8")) - table = Load.get_load_table(client.schema, file_name) + table = client.get_load_table(table_name) job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id()) while job.state() == "running": sleep(0.5) diff --git a/tests/utils.py b/tests/utils.py index bde02b3fbf..2eba788542 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -25,13 +25,19 @@ TEST_STORAGE_ROOT = "_storage" + # destination constants IMPLEMENTED_DESTINATIONS = {"athena", "duckdb", "bigquery", "redshift", "postgres", "snowflake", "filesystem", "weaviate", "dummy", "motherduck", "mssql"} NON_SQL_DESTINATIONS = {"filesystem", "weaviate", "dummy", "motherduck"} SQL_DESTINATIONS = IMPLEMENTED_DESTINATIONS - NON_SQL_DESTINATIONS +# exclude destination configs (for now used for athena and athena iceberg separation) +EXCLUDED_DESTINATION_CONFIGURATIONS = set(dlt.config.get("EXCLUDED_DESTINATION_CONFIGURATIONS", list) or set()) + + # filter out active destinations for current tests ACTIVE_DESTINATIONS = set(dlt.config.get("ACTIVE_DESTINATIONS", list) or IMPLEMENTED_DESTINATIONS) + ACTIVE_SQL_DESTINATIONS = SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS) ACTIVE_NON_SQL_DESTINATIONS = NON_SQL_DESTINATIONS.intersection(ACTIVE_DESTINATIONS)
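The `load.py` refactor above replaces write-disposition lists with per-table predicates: `_get_table_chain_tables_with_filter` applies a boolean filter to each top-level table with jobs and, when it matches, pulls in the table's child chain. The sketch below is a simplified, self-contained illustration of that selection logic using plain dicts; the toy schema and helper functions are assumptions for illustration only, not dlt's schema utilities.

```python
from typing import Callable, Dict, Iterable, Set

# toy schema: table name -> hints (real dlt tables carry much more metadata)
TABLES: Dict[str, Dict[str, str]] = {
    "items": {"write_disposition": "replace"},
    "items__sub_items": {"write_disposition": "replace", "parent": "items"},
    "events": {"write_disposition": "append"},
}

def top_level(name: str) -> str:
    # walk parent links up to the root of the table chain
    while "parent" in TABLES[name]:
        name = TABLES[name]["parent"]
    return name

def child_chain(root: str) -> Iterable[str]:
    # the root plus its direct children (one level is enough for this toy schema)
    yield root
    for name, hints in TABLES.items():
        if hints.get("parent") == root:
            yield name

def tables_with_filter(f: Callable[[Dict[str, str]], bool], tables_with_jobs: Iterable[str]) -> Set[str]:
    # select whole table chains whose top-level table satisfies the predicate
    result: Set[str] = set()
    for name in tables_with_jobs:
        root = top_level(name)
        if f(TABLES[root]):
            result.update(child_chain(root))
    return result

def should_truncate(table: Dict[str, str]) -> bool:
    # example predicate: truncate before load only for replace-style chains
    return table["write_disposition"] == "replace"

print(tables_with_filter(should_truncate, ["items", "events"]))  # -> {'items', 'items__sub_items'}
```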