Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Oct 11, 2023
1 parent 0924bc5 commit ab8a01c
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 27 deletions.
17 changes: 16 additions & 1 deletion dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
from typing import ClassVar, Final, Optional, NamedTuple, Literal, Sequence, Iterable, Type, Protocol, Union, TYPE_CHECKING, cast, List, ContextManager, Dict, Any
from contextlib import contextmanager
import datetime # noqa: 251
from copy import deepcopy

from dlt.common import logger
from dlt.common.exceptions import IdentifierTooLongException, InvalidDestinationReference, UnknownDestinationModule
from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.typing import TWriteDisposition
from dlt.common.schema.exceptions import InvalidDatasetName
from dlt.common.schema.utils import get_load_table
from dlt.common.schema.utils import get_write_disposition, get_table_format
from dlt.common.configuration import configspec
from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration
from dlt.common.configuration.accessors import config
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
from dlt.common.schema.utils import is_complete_column
from dlt.common.schema.exceptions import UnknownTableException
from dlt.common.storages import FileStorage
from dlt.common.storages.load_storage import ParsedLoadJobFileName
from dlt.common.utils import get_module_name
Expand Down Expand Up @@ -287,6 +289,19 @@ def _verify_schema(self) -> None:
if not is_complete_column(column):
logger.warning(f"A column {column_name} in table {table_name} in schema {self.schema.name} is incomplete. It was not bound to the data during normalizations stage and its data type is unknown. Did you add this column manually in code ie. as a merge key?")

def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
try:
# make a copy of the schema so modifications do not affect the original document
table = deepcopy(self.schema.tables[table_name])
# add write disposition if not specified - in child tables
if "write_disposition" not in table:
table["write_disposition"] = get_write_disposition(self.schema.tables, table_name)
if "table_format" not in table:
table["table_format"] = get_table_format(self.schema.tables, table_name)
return table
except KeyError:
raise UnknownTableException(table_name)


class WithStateSync(ABC):

Expand Down
12 changes: 0 additions & 12 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,18 +536,6 @@ def get_top_level_table(tables: TSchemaTables, table_name: str) -> TTableSchema:
return get_top_level_table(tables, parent)
return table

def get_load_table(tables: "TSchemaTables", table_name: str) -> "TTableSchema":
    """Return a copy of ``tables[table_name]`` prepared for the load stage.

    Missing "write_disposition" / "table_format" entries (typical for child
    tables) are filled in from the table's ancestry.

    Args:
        tables: the schema's table collection.
        table_name: name of the table to resolve.

    Raises:
        UnknownTableException: if ``table_name`` is not present in ``tables``.
    """
    from copy import deepcopy

    try:
        # BUGFIX: use a deep copy — a shallow copy() shares nested structures
        # (e.g. the "columns" dict), so callers mutating the returned table
        # would silently corrupt the original schema document
        table = deepcopy(tables[table_name])
        # add write disposition if not specified - in child tables
        if "write_disposition" not in table:
            table["write_disposition"] = get_write_disposition(tables, table_name)
        if "table_format" not in table:
            table["table_format"] = get_table_format(tables, table_name)
        return table
    except KeyError as key_err:
        # chain the cause so tracebacks show the originating lookup failure
        raise UnknownTableException(table_name) from key_err

def get_child_tables(tables: TSchemaTables, table_name: str) -> List[TTableSchema]:
"""Get child tables for table name and return a list of tables ordered by ancestry so the child tables are always after their parents"""
Expand Down
10 changes: 8 additions & 2 deletions dlt/destinations/athena/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dlt.common import logger
from dlt.common.utils import without_none
from dlt.common.data_types import TDataType
from dlt.common.schema import TColumnSchema, Schema
from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema
from dlt.common.schema.typing import TTableSchema, TColumnType, TWriteDisposition
from dlt.common.schema.utils import table_schema_has_type, get_table_format
from dlt.common.destination import DestinationCapabilitiesContext
Expand Down Expand Up @@ -325,7 +325,7 @@ def _get_table_update_sql(self, table_name: str, new_columns: Sequence[TColumnSc

# for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries
# or if we are in iceberg mode, we create iceberg tables for all tables
is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or (self._is_iceberg_table(self.schema.tables[table_name]) and not self.in_staging_mode)
is_iceberg = (self.schema.tables[table_name].get("write_disposition", None) == "skip") or self._is_iceberg_table(self.schema.tables[table_name])
columns = ", ".join([self._get_column_def_sql(c) for c in new_columns])

# this will fail if the table prefix is not properly defined
Expand Down Expand Up @@ -381,6 +381,12 @@ def table_needs_staging(self, table: TTableSchema) -> bool:
if self._is_iceberg_table(table):
return True
return super().table_needs_staging(table)

def get_load_table(self, table_name: str, staging: bool = False) -> TTableSchema:
    """Resolve the load table via the base client, then drop the "iceberg"
    table format when resolving for the staging dataset (presumably so the
    staging copy is created as a regular, non-iceberg table — confirm with
    the Athena loader).
    """
    load_table = super().get_load_table(table_name, staging)
    if staging and load_table.get("table_format") == "iceberg":
        del load_table["table_format"]
    return load_table

@staticmethod
def is_dbapi_exception(ex: Exception) -> bool:
Expand Down
10 changes: 2 additions & 8 deletions dlt/destinations/job_client_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,16 +429,10 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None:

class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset):

in_staging_mode: bool = False

@contextlib.contextmanager
def with_staging_dataset(self)-> Iterator["SqlJobClientBase"]:
try:
with self.sql_client.with_staging_dataset(True):
self.in_staging_mode = True
yield self
finally:
self.in_staging_mode = False
with self.sql_client.with_staging_dataset(True):
yield self

def table_needs_staging(self, table: TTableSchema) -> bool:
if table["write_disposition"] == "merge":
Expand Down
7 changes: 5 additions & 2 deletions dlt/load/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from dlt.common.configuration import with_config, known_sections
from dlt.common.configuration.accessors import config
from dlt.common.pipeline import LoadInfo, SupportsPipeline
from dlt.common.schema.utils import get_child_tables, get_top_level_table, get_load_table
from dlt.common.schema.utils import get_child_tables, get_top_level_table
from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState
from dlt.common.typing import StrAny
from dlt.common.runners import TRunMetrics, Runnable, workermethod
Expand Down Expand Up @@ -98,9 +98,12 @@ def w_spool_job(self: "Load", file_path: str, load_id: str, schema: Schema) -> O
if job_info.file_format not in self.load_storage.supported_file_formats:
raise LoadClientUnsupportedFileFormats(job_info.file_format, self.capabilities.supported_loader_file_formats, file_path)
logger.info(f"Will load file {file_path} with table name {job_info.table_name}")
table = get_load_table(schema.tables, job_info.table_name)
table = job_client.get_load_table(job_info.table_name)
if table["write_disposition"] not in ["append", "replace", "merge"]:
raise LoadClientUnsupportedWriteDisposition(job_info.table_name, table["write_disposition"], file_path)

table_needs_staging = isinstance(job_client, WithStagingDataset) and job_client.table_needs_staging(table)
table = job_client.get_load_table(job_info.table_name, table_needs_staging)
with self.maybe_with_staging_dataset(job_client, table):
job = job_client.start_file_load(table, self.load_storage.storage.make_full_path(file_path), load_id)
except (DestinationTerminalException, TerminalValueError):
Expand Down
4 changes: 2 additions & 2 deletions tests/load/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dlt.common.data_writers import DataWriter
from dlt.common.schema import TColumnSchema, TTableSchemaColumns, Schema
from dlt.common.storages import SchemaStorage, FileStorage, SchemaStorageConfiguration
from dlt.common.schema.utils import new_table, get_load_table
from dlt.common.schema.utils import new_table
from dlt.common.storages.load_storage import ParsedLoadJobFileName, LoadStorage
from dlt.common.typing import StrAny
from dlt.common.utils import uniq_id
Expand Down Expand Up @@ -170,7 +170,7 @@ def load_table(name: str) -> Dict[str, TTableSchemaColumns]:
def expect_load_file(client: JobClientBase, file_storage: FileStorage, query: str, table_name: str, status = "completed") -> LoadJob:
file_name = ParsedLoadJobFileName(table_name, uniq_id(), 0, client.capabilities.preferred_loader_file_format).job_id()
file_storage.save(file_name, query.encode("utf-8"))
table = get_load_table(client.schema.tables, table_name)
table = client.get_load_table(table_name)
job = client.start_file_load(table, file_storage.make_full_path(file_name), uniq_id())
while job.state() == "running":
sleep(0.5)
Expand Down

0 comments on commit ab8a01c

Please sign in to comment.