Commit
Merge pull request #959 from dlt-hub/devel
dlt 0.4.4 release master merge
rudolfix authored Feb 11, 2024
2 parents a8f3338 + 882b29b commit f1633e5
Showing 15 changed files with 4,445 additions and 3,940 deletions.
2 changes: 1 addition & 1 deletion dlt/common/libs/pydantic.py
@@ -140,7 +140,7 @@ def pydantic_to_table_schema_columns(
# This case is for a single field schema/model
# we need to generate snake_case field names
# and return flattened field schemas
schema_hints = pydantic_to_table_schema_columns(field.annotation)
schema_hints = pydantic_to_table_schema_columns(inner_type)

for field_name, hints in schema_hints.items():
schema_key = snake_case_naming_convention.make_path(name, field_name)
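The one-line change above passes the unwrapped inner_type, rather than the raw field.annotation, when flattening a single-field model into snake_case column paths. Purely as an illustration of that flattening idea, here is a self-contained sketch (assuming Pydantic v2; the model names, the snake_case helper, and the "__" path separator are hypothetical, not dlt's actual implementation):

import re
from typing import Any, Dict

from pydantic import BaseModel


def snake_case(name: str) -> str:
    # naive CamelCase/mixedCase -> snake_case conversion
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


class Address(BaseModel):
    city: str
    zipCode: str


class User(BaseModel):
    address: Address  # single nested field whose schema gets flattened


def flatten_columns(model: type[BaseModel], prefix: str = "") -> Dict[str, Any]:
    columns: Dict[str, Any] = {}
    for name, field in model.model_fields.items():
        annotation = field.annotation
        path = f"{prefix}{snake_case(name)}"
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            # recurse into the nested model, joining names with "__"
            columns.update(flatten_columns(annotation, prefix=f"{path}__"))
        else:
            columns[path] = annotation
    return columns


print(flatten_columns(User))
# {'address__city': <class 'str'>, 'address__zip_code': <class 'str'>}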
46 changes: 41 additions & 5 deletions dlt/common/typing.py
@@ -25,7 +25,26 @@
runtime_checkable,
IO,
)
from typing_extensions import TypeAlias, ParamSpec, Concatenate, Annotated, get_args, get_origin

from typing_extensions import (
Annotated,
Never,
ParamSpec,
TypeAlias,
Concatenate,
get_args,
get_origin,
)

try:
from types import UnionType # type: ignore[attr-defined]
except ImportError:
# Since new Union syntax was introduced in Python 3.10
# we need to substitute it here for older versions.
# it is defined as type(int | str) but for us having it
# as shown here should suffice because it is valid only
# in versions of Python>=3.10.
UnionType = Never

from dlt.common.pendulum import timedelta, pendulum

@@ -103,18 +122,35 @@ def extract_type_if_modifier(t: Type[Any]) -> Type[Any]:


def is_union_type(hint: Type[Any]) -> bool:
if get_origin(hint) is Union:
# We need to handle UnionType because with Python>=3.10
# new Optional syntax was introduced which treats Optionals
# as unions and probably internally there is no additional
# type hints to handle this edge case, see the examples below
# >>> type(str | int)
# <class 'types.UnionType'>
# >>> type(str | None)
# <class 'types.UnionType'>
# type(Union[int, str])
# <class 'typing._GenericAlias'>
origin = get_origin(hint)
if origin is Union or origin is UnionType:
return True

if hint := extract_type_if_modifier(hint):
return is_union_type(hint)

return False


def is_optional_type(t: Type[Any]) -> bool:
if get_origin(t) is Union:
return type(None) in get_args(t)
origin = get_origin(t)
is_union = origin is Union or origin is UnionType
if is_union and type(None) in get_args(t):
return True

if t := extract_type_if_modifier(t):
return is_optional_type(t)

return False


@@ -232,7 +268,7 @@ def get_generic_type_argument_from_instance(


def copy_sig(
wrapper: Callable[TInputArgs, Any]
wrapper: Callable[TInputArgs, Any],
) -> Callable[[Callable[..., TReturnVal]], Callable[TInputArgs, TReturnVal]]:
"""Copies docstring and signature from wrapper to func but keeps the func return value type"""

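The typing.py changes teach is_union_type and is_optional_type to accept both the classic typing.Union spelling and the PEP 604 X | Y syntax added in Python 3.10, whose get_origin is types.UnionType instead of typing.Union; on older interpreters the import fails and a harmless sentinel is substituted. A minimal standalone sketch of the same idea (simplified: it drops the extract_type_if_modifier recursion and uses None rather than typing_extensions.Never as the fallback sentinel):

from typing import Any, Optional, Type, Union, get_args, get_origin

try:
    from types import UnionType  # Python >= 3.10, the origin of `int | str`
except ImportError:
    UnionType = None  # sentinel that never matches get_origin(...)


def is_union_type(hint: Type[Any]) -> bool:
    origin = get_origin(hint)
    return origin is Union or (UnionType is not None and origin is UnionType)


def is_optional_type(hint: Type[Any]) -> bool:
    return is_union_type(hint) and type(None) in get_args(hint)


assert is_union_type(Union[int, str])
assert is_optional_type(Optional[int])
# The PEP 604 spellings only parse on Python >= 3.10:
# assert is_union_type(int | str) and is_optional_type(str | None)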
16 changes: 7 additions & 9 deletions dlt/destinations/impl/athena/athena.py
@@ -351,7 +351,9 @@ def _from_db_type(
return self.type_mapper.from_db_type(hive_t, precision, scale)

def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
return f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
return (
f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}"
)

def _get_table_update_sql(
self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool
@@ -376,19 +378,15 @@ def _get_table_update_sql(
# use qualified table names
qualified_table_name = self.sql_client.make_qualified_ddl_table_name(table_name)
if is_iceberg and not generate_alter:
sql.append(
f"""CREATE TABLE {qualified_table_name}
sql.append(f"""CREATE TABLE {qualified_table_name}
({columns})
LOCATION '{location}'
TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');"""
)
TBLPROPERTIES ('table_type'='ICEBERG', 'format'='parquet');""")
elif not generate_alter:
sql.append(
f"""CREATE EXTERNAL TABLE {qualified_table_name}
sql.append(f"""CREATE EXTERNAL TABLE {qualified_table_name}
({columns})
STORED AS PARQUET
LOCATION '{location}';"""
)
LOCATION '{location}';""")
# alter table to add new columns at the end
else:
sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""")
10 changes: 6 additions & 4 deletions dlt/destinations/impl/bigquery/bigquery.py
@@ -252,9 +252,9 @@ def _get_table_update_sql(
elif (c := partition_list[0])["data_type"] == "date":
sql[0] = f"{sql[0]}\nPARTITION BY {self.capabilities.escape_identifier(c['name'])}"
elif (c := partition_list[0])["data_type"] == "timestamp":
sql[
0
] = f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
sql[0] = (
f"{sql[0]}\nPARTITION BY DATE({self.capabilities.escape_identifier(c['name'])})"
)
# Automatic partitioning of an INT64 type requires us to be prescriptive - we treat the column as a UNIX timestamp.
# This is due to the bounds requirement of GENERATE_ARRAY function for partitioning.
# The 10,000 partitions limit makes it infeasible to cover the entire `bigint` range.
@@ -272,7 +272,9 @@

def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str:
name = self.capabilities.escape_identifier(c["name"])
return f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
return (
f"{name} {self.type_mapper.to_db_type(c, table_format)} {self._gen_not_null(c.get('nullable', True))}"
)

def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
schema_table: TTableSchemaColumns = {}
28 changes: 18 additions & 10 deletions dlt/destinations/impl/databricks/databricks.py
@@ -166,12 +166,14 @@ def __init__(
else:
raise LoadJobTerminalException(
file_path,
f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and azure buckets are supported",
f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and"
" azure buckets are supported",
)
else:
raise LoadJobTerminalException(
file_path,
"Cannot load from local file. Databricks does not support loading from local files. Configure staging with an s3 or azure storage bucket.",
"Cannot load from local file. Databricks does not support loading from local files."
" Configure staging with an s3 or azure storage bucket.",
)

# decide on source format, stage_file_path will either be a local file or a bucket path
@@ -181,27 +183,33 @@
if not config.get("data_writer.disable_compression"):
raise LoadJobTerminalException(
file_path,
"Databricks loader does not support gzip compressed JSON files. Please disable compression in the data writer configuration: https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression",
"Databricks loader does not support gzip compressed JSON files. Please disable"
" compression in the data writer configuration:"
" https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression",
)
if table_schema_has_type(table, "decimal"):
raise LoadJobTerminalException(
file_path,
"Databricks loader cannot load DECIMAL type columns from json files. Switch to parquet format to load decimals.",
"Databricks loader cannot load DECIMAL type columns from json files. Switch to"
" parquet format to load decimals.",
)
if table_schema_has_type(table, "binary"):
raise LoadJobTerminalException(
file_path,
"Databricks loader cannot load BINARY type columns from json files. Switch to parquet format to load byte values.",
"Databricks loader cannot load BINARY type columns from json files. Switch to"
" parquet format to load byte values.",
)
if table_schema_has_type(table, "complex"):
raise LoadJobTerminalException(
file_path,
"Databricks loader cannot load complex columns (lists and dicts) from json files. Switch to parquet format to load complex types.",
"Databricks loader cannot load complex columns (lists and dicts) from json"
" files. Switch to parquet format to load complex types.",
)
if table_schema_has_type(table, "date"):
raise LoadJobTerminalException(
file_path,
"Databricks loader cannot load DATE type columns from json files. Switch to parquet format to load dates.",
"Databricks loader cannot load DATE type columns from json files. Switch to"
" parquet format to load dates.",
)

source_format = "JSON"
@@ -311,7 +319,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non

def _get_storage_table_query_columns(self) -> List[str]:
fields = super()._get_storage_table_query_columns()
fields[
1
] = "full_data_type" # Override because this is the only way to get data type with precision
fields[1] = ( # Override because this is the only way to get data type with precision
"full_data_type"
)
return fields
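The Databricks hunk above is mostly line re-wrapping, but the checks it touches encode a real constraint: JSON staging files must be uncompressed and cannot carry decimal, binary, complex, or date columns, so the loader fails early and points at parquet instead. A hedged sketch of that kind of pre-flight scan over a simplified, dict-shaped table schema (the helper name and schema layout are assumptions, not dlt internals):

from typing import Any, Dict, Iterable

UNSUPPORTED_IN_JSON = ("decimal", "binary", "complex", "date")


def unsupported_json_columns(
    table: Dict[str, Any], types: Iterable[str] = UNSUPPORTED_IN_JSON
) -> Dict[str, str]:
    """Return {column_name: data_type} for columns a JSON file cannot represent."""
    blocked = set(types)
    return {
        name: hints["data_type"]
        for name, hints in table.get("columns", {}).items()
        if hints.get("data_type") in blocked
    }


table = {
    "name": "payments",
    "columns": {"amount": {"data_type": "decimal"}, "note": {"data_type": "text"}},
}
bad = unsupported_json_columns(table)
if bad:
    print(f"Switch to parquet: JSON staging cannot load columns {bad}")
# -> Switch to parquet: JSON staging cannot load columns {'amount': 'decimal'}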
6 changes: 2 additions & 4 deletions dlt/destinations/impl/snowflake/snowflake.py
@@ -175,15 +175,13 @@ def __init__(
f'PUT file://{file_path} @{stage_name}/"{load_id}" OVERWRITE = TRUE,'
" AUTO_COMPRESS = FALSE"
)
client.execute_sql(
f"""COPY INTO {qualified_table_name}
client.execute_sql(f"""COPY INTO {qualified_table_name}
{from_clause}
{files_clause}
{credentials_clause}
FILE_FORMAT = {source_format}
MATCH_BY_COLUMN_NAME='CASE_INSENSITIVE'
"""
)
""")
if stage_file_path and not keep_staged_files:
client.execute_sql(f"REMOVE {stage_file_path}")

31 changes: 23 additions & 8 deletions dlt/extract/incremental/__init__.py
@@ -119,7 +119,7 @@ def __init__(
self.start_value: Any = initial_value
"""Value of last_value at the beginning of current pipeline run"""
self.resource_name: Optional[str] = None
self.primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key
self._primary_key: Optional[TTableHintTemplate[TColumnNames]] = primary_key
self.allow_external_schedulers = allow_external_schedulers

self._cached_state: IncrementalColumnState = None
@@ -133,6 +133,18 @@ def __init__(

self._transformers: Dict[str, IncrementalTransform] = {}

@property
def primary_key(self) -> Optional[TTableHintTemplate[TColumnNames]]:
return self._primary_key

@primary_key.setter
def primary_key(self, value: str) -> None:
# set key in incremental and data type transformers
self._primary_key = value
if self._transformers:
for transform in self._transformers.values():
transform.primary_key = value

def _make_transforms(self) -> None:
types = [("arrow", ArrowIncremental), ("json", JsonIncremental)]
for dt, kls in types:
@@ -143,7 +155,7 @@ def _make_transforms(self) -> None:
self.end_value,
self._cached_state,
self.last_value_func,
self.primary_key,
self._primary_key,
)

@classmethod
@@ -163,7 +175,7 @@ def copy(self) -> "Incremental[TCursorValue]":
self.cursor_path,
initial_value=self.initial_value,
last_value_func=self.last_value_func,
primary_key=self.primary_key,
primary_key=self._primary_key,
end_value=self.end_value,
allow_external_schedulers=self.allow_external_schedulers,
)
@@ -178,7 +190,7 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue
>>>
>>> my_resource(updated=incremental(initial_value='2023-01-01', end_value='2023-02-01'))
"""
kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self.primary_key)
kwargs = dict(self, last_value_func=self.last_value_func, primary_key=self._primary_key)
for key, value in dict(
other, last_value_func=other.last_value_func, primary_key=other.primary_key
).items():
@@ -395,7 +407,6 @@ def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]:
return rows

transformer = self._get_transformer(rows)
transformer.primary_key = self.primary_key

if isinstance(rows, list):
return [
@@ -476,7 +487,7 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
elif isinstance(p.default, Incremental):
new_incremental = p.default.copy()

if not new_incremental or new_incremental.is_partial():
if (not new_incremental or new_incremental.is_partial()) and not self._incremental:
if is_optional_type(p.annotation):
bound_args.arguments[p.name] = None # Remove partial spec
return func(*bound_args.args, **bound_args.kwargs)
@@ -486,15 +497,16 @@ def _wrap(*args: Any, **kwargs: Any) -> Any:
)
# pass Generic information from annotation to new_incremental
if (
not hasattr(new_incremental, "__orig_class__")
new_incremental
and not hasattr(new_incremental, "__orig_class__")
and p.annotation
and get_args(p.annotation)
):
new_incremental.__orig_class__ = p.annotation # type: ignore

# set the incremental only if not yet set or if it was passed explicitly
# NOTE: the _incremental may be also set by applying hints to the resource see `set_template` in `DltResource`
if p.name in bound_args.arguments or not self._incremental:
if (new_incremental and p.name in bound_args.arguments) or not self._incremental:
self._incremental = new_incremental
self._incremental.resolve()
# in case of transformers the bind will be called before this wrapper is set: because transformer is called for a first time late in the pipe
@@ -531,6 +543,9 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:
return item
if self._incremental.primary_key is None:
self._incremental.primary_key = self.primary_key
elif self.primary_key is None:
# propagate from incremental
self.primary_key = self._incremental.primary_key
return self._incremental(item, meta)


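The incremental changes turn primary_key into a property whose setter pushes the value into any transformers that were already cached, replacing the per-call assignment that was removed from __call__; the binding logic in _wrap also now keeps an incremental applied via resource hints when the wrapped signature only carries a partial spec. A minimal standalone sketch of the propagation pattern, with illustrative class names that are not dlt's actual classes:

from typing import Dict, Optional


class Transform:
    def __init__(self, primary_key: Optional[str] = None) -> None:
        self.primary_key = primary_key


class Incremental:
    def __init__(self, primary_key: Optional[str] = None) -> None:
        self._primary_key = primary_key
        self._transformers: Dict[str, Transform] = {}

    @property
    def primary_key(self) -> Optional[str]:
        return self._primary_key

    @primary_key.setter
    def primary_key(self, value: Optional[str]) -> None:
        # propagate to every transformer created before the key was known
        self._primary_key = value
        for transform in self._transformers.values():
            transform.primary_key = value

    def get_transformer(self, kind: str) -> Transform:
        # transformers are created lazily and pick up the current key
        return self._transformers.setdefault(kind, Transform(self._primary_key))


inc = Incremental()
t = inc.get_transformer("json")
inc.primary_key = "id"  # a late assignment still reaches the cached transformer
assert t.primary_key == "id"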
6 changes: 3 additions & 3 deletions dlt/pipeline/pipeline.py
@@ -1163,9 +1163,9 @@ def _set_context(self, is_active: bool) -> None:
# set destination context on activation
if self.destination:
# inject capabilities context
self._container[
DestinationCapabilitiesContext
] = self._get_destination_capabilities()
self._container[DestinationCapabilitiesContext] = (
self._get_destination_capabilities()
)
else:
# remove destination context on deactivation
if DestinationCapabilitiesContext in self._container: