From ab8a01ca286e34afad213751ada0ee4caf1a3765 Mon Sep 17 00:00:00 2001
From: Dave
Date: Wed, 11 Oct 2023 13:52:56 +0200
Subject: [PATCH] move get_load_table from schema utils to the destination job client

---
 dlt/common/destination/reference.py | 17 ++++++++++++++++-
 dlt/common/schema/utils.py          | 12 ------------
 dlt/destinations/athena/athena.py   | 10 ++++++++--
 dlt/destinations/job_client_impl.py | 10 ++--------
 dlt/load/load.py                    |  7 +++++--
 tests/load/utils.py                 |  4 ++--
 6 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index 5fea462159..6f5689ed20 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -4,18 +4,20 @@
 from typing import ClassVar, Final, Optional, NamedTuple, Literal, Sequence, Iterable, Type, Protocol, Union, TYPE_CHECKING, cast, List, ContextManager, Dict, Any
 from contextlib import contextmanager
 import datetime  # noqa: 251
+from copy import deepcopy
 
 from dlt.common import logger
 from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule
 from dlt.common.schema import Schema, TTableSchema, TSchemaTables
 from dlt.common.schema.typing import TWriteDisposition
 from dlt.common.schema.exceptions import InvalidDatasetName
-from dlt.common.schema.utils import get_load_table
+from dlt.common.schema.utils import get_write_disposition, get_table_format
 from dlt.common.configuration import configspec
 from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration
 from dlt.common.configuration.accessors import config
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.schema.utils import is_complete_column
+from dlt.common.schema.exceptions import UnknownTableException
 from dlt.common.storages import FileStorage
 from dlt.common.storages.load_storage import ParsedLoadJobFileName
 from dlt.common.utils import get_module_name
@@ -287,6 +289,19 @@ def _verify_schema(self) -> None:
             if not is_complete_column(column):
                 logger.warning(f"A column {column_name} in table {table_name} in schema {self.schema.name} is incomplete. It was not bound to the data during normalizations stage and its data type is unknown. Did you add this column manually in code ie. as a merge key?")
+
+    def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
+        try:
+            # make a copy of the schema so modifications do not affect the original document
+            table = deepcopy(self.schema.tables[table_name])
+            # add write disposition if not specified - in child tables
+            if "write_disposition" not in table:
+                table["write_disposition"] = get_write_disposition(self.schema.tables, table_name)
+            if "table_format" not in table:
+                table["table_format"] = get_table_format(self.schema.tables, table_name)
+            return table
+        except KeyError:
+            raise UnknownTableException(table_name)
 
 
 class WithStateSync(ABC):
 
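The new JobClientBase.get_load_table above resolves the effective schema for a load job's table: child tables inherit write_disposition and table_format from their parent chain, and deepcopy (rather than the shallow copy used before) means callers can mutate the returned table, including its nested column dicts, without corrupting the live schema. Below is a standalone sketch of that inheritance behavior using plain dicts instead of dlt's Schema class; all names and values in it are illustrative only, not dlt's API:

    from copy import deepcopy
    from typing import Any, Dict

    tables: Dict[str, Dict[str, Any]] = {
        "events": {"name": "events", "write_disposition": "merge", "table_format": "iceberg"},
        "events__payload": {"name": "events__payload", "parent": "events"},  # child: no explicit settings
    }

    def inherited(table_name: str, prop: str) -> Any:
        # walk up the parent chain until the property is found; None if never set
        table = tables[table_name]
        while prop not in table and "parent" in table:
            table = tables[table["parent"]]
        return table.get(prop)

    def get_load_table(table_name: str) -> Dict[str, Any]:
        # deepcopy so that mutating the result cannot leak back into the schema
        table = deepcopy(tables[table_name])
        table.setdefault("write_disposition", inherited(table_name, "write_disposition"))
        table.setdefault("table_format", inherited(table_name, "table_format"))
        return table

    assert get_load_table("events__payload")["write_disposition"] == "merge"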
as a merge key?") + def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + try: + # make a copy of the schema so modifications do not affect the original document + table = deepcopy(self.schema.tables[table_name]) + # add write disposition if not specified - in child tables + if "write_disposition" not in table: + table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) + if "table_format" not in table: + table["table_format"] = get_table_format(self.schema.tables, table_name) + return table + except KeyError: + raise UnknownTableException(table_name) + class WithStateSync(ABC): diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 93f0913550..cc265891d1 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -536,18 +536,6 @@ def get_top_level_table(tables: TSchemaTables, table_name: str) -> TTableSchema: return get_top_level_table(tables, parent) return table -def get_load_table(tables: TSchemaTables, table_name: str) -> TTableSchema: - try: - # make a copy of the schema so modifications do not affect the original document - table = copy(tables[table_name]) - # add write disposition if not specified - in child tables - if "write_disposition" not in table: - table["write_disposition"] = get_write_disposition(tables, table_name) - if "table_format" not in table: - table["table_format"] = get_table_format(tables, table_name) - return table - except KeyError: - raise UnknownTableException(table_name) def get_child_tables(tables: TSchemaTables, table_name: str) -> List[TTableSchema]: """Get child tables for table name and return a list of tables ordered by ancestry so the child tables are always after their parents""" diff --git a/dlt/destinations/athena/athena.py b/dlt/destinations/athena/athena.py index c9d6f3abb7..514d868047 100644 --- a/dlt/destinations/athena/athena.py +++ b/dlt/destinations/athena/athena.py @@ -15,7 +15,7 @@ from dlt.common import logger from dlt.common.utils import without_none from dlt.common.data_types import TDataType -from dlt.common.schema import TColumnSchema, Schema +from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition from dlt.common.schema.utils import table_schema_has_type, get_table_format from dlt.common.destination import DestinationCapabilitiesContext @@ -325,7 +325,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables - is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or (self._is_iceberg_table(self.schema.tables[table_name]) and not self.in_staging_mode) + is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or self._is_iceberg_table(self.schema.tables[table_name]) columns = ", ".join([self._get_column_def_sql(c) for c in new_columns]) # this will fail if the table prefix is not properly defined @@ -381,6 +381,12 @@ def table_needs_staging(self, table: TTableSchema) -> bool: if self._is_iceberg_table(table): return True return super().table_needs_staging(table) + + def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: + table = super().get_load_table(table_name, staging) + if staging and table.get("table_format", None) == "iceberg": + 
diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py
index ec3afced94..e2768dcf76 100644
--- a/dlt/destinations/job_client_impl.py
+++ b/dlt/destinations/job_client_impl.py
@@ -429,16 +429,10 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None:
 
 class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset):
 
-    in_staging_mode: bool = False
-
     @contextlib.contextmanager
     def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]:
-        try:
-            with self.sql_client.with_staging_dataset(True):
-                self.in_staging_mode = True
-                yield self
-        finally:
-            self.in_staging_mode = False
+        with self.sql_client.with_staging_dataset(True):
+            yield self
 
     def table_needs_staging(self, table: TTableSchema) -> bool:
         if table["write_disposition"] == "merge":
diff --git a/dlt/load/load.py b/dlt/load/load.py
index ddce9bf8e9..c57d18fa0f 100644
--- a/dlt/load/load.py
+++ b/dlt/load/load.py
@@ -10,7 +10,7 @@
 from dlt.common.configuration import with_config, known_sections
 from dlt.common.configuration.accessors import config
 from dlt.common.pipeline import LoadInfo, SupportsPipeline
-from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_load_table
+from dlt.common.schema.utils import get_child_tables, get_top_level_table
 from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState
 from dlt.common.typing import StrAny
 from dlt.common.runners import TRunMetrics, Runnable, workermethod
@@ -98,9 +98,12 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O
             if job_info.file_format not in self.load_storage.supported_file_formats:
                 raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path)
             logger.info(f"Will load file {file_path} with table name {job_info.table_name}")
-            table = get_load_table(schema.tables, job_info.table_name)
+            table = job_client.get_load_table(job_info.table_name)
             if table["write_disposition"] not in ["append", "replace", "merge"]:
                 raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path)
+
+            table_needs_staging = isinstance(job_client, WithStagingDataset) and job_client.table_needs_staging(table)
+            table = job_client.get_load_table(job_info.table_name, table_needs_staging)
             with self.maybe_with_staging_dataset(job_client, table):
                 job = job_client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id)
         except (DestinationTerminalException, TerminalValueError):
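Together with the job_client_impl.py hunk, the loader change replaces hidden mutable state (in_staging_mode, toggled by a context manager) with an explicit staging argument: w_spool_job now resolves the table twice, once to decide whether staging is needed at all and once with that decision applied, so destination overrides such as Athena's can react. Simplified, the new flow is:

    # first pass: effective settings decide whether the table goes through staging
    table = job_client.get_load_table(job_info.table_name)
    needs_staging = isinstance(job_client, WithStagingDataset) and job_client.table_needs_staging(table)
    # second pass: re-resolve with the staging flag so overrides can adjust the result
    table = job_client.get_load_table(job_info.table_name, needs_staging)

Making the flag a parameter keeps get_load_table a pure function of its inputs, which is easier to reason about than a client that silently changes behavior depending on which dataset it was last pointed at.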
diff --git a/tests/load/utils.py b/tests/load/utils.py
index 9fd4f033b7..2af16af496 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -17,7 +17,7 @@
 from dlt.common.data_writers import DataWriter
 from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema
 from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration
-from dlt.common.schema.utils import new_table, get_load_table
+from dlt.common.schema.utils import new_table
 from dlt.common.storages.load_storage import ParsedLoadJobFileName, LoadStorage
 from dlt.common.typing import StrAny
 from dlt.common.utils import uniq_id
@@ -170,7 +170,7 @@ def load_table(name: str) -> Dict[str, TTableSchemaColumns]:
 def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: str, table_name: str, status = "completed") -> LoadJob:
     file_name = ParsedLoadJobFileName(table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format).job_id()
     file_storage.save(file_name, query.encode("utf-8"))
-    table = get_load_table(client.schema.tables, table_name)
+    table = client.get_load_table(table_name)
     job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id())
     while job.state() == "running":
         sleep(0.5)
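With the function moved onto the job client, everything that previously called get_load_table(schema.tables, name) from schema utils now asks the client, which lets each destination customize the result. Assuming a configured JobClientBase instance named client (setup omitted, names illustrative), usage reads:

    table = client.get_load_table("events")                         # variant for the final dataset
    staging_table = client.get_load_table("events", staging=True)   # staging variant (Athena drops "iceberg")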