From 44ba39f3f6f20378f07bf335b42f687eeac87fd7 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 27 Mar 2024 13:19:18 +0100 Subject: [PATCH 01/22] bumps for prerelease 0.4.8a1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 444541aa43..861fbcd5c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.4.8a0" +version = "0.4.8a1" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Ty Dunn "] From 2cdb2968d764c92912c1b6b72305b564aace393a Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 22:59:10 +0200 Subject: [PATCH 02/22] requires password and database in motherduck credentials --- .../impl/motherduck/configuration.py | 12 +++--- tests/load/duckdb/test_motherduck_client.py | 37 +++++++++++++++---- 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/motherduck/configuration.py b/dlt/destinations/impl/motherduck/configuration.py index 3179295c54..c4888d3fe4 100644 --- a/dlt/destinations/impl/motherduck/configuration.py +++ b/dlt/destinations/impl/motherduck/configuration.py @@ -17,6 +17,8 @@ class MotherDuckCredentials(DuckDbBaseCredentials): drivername: Final[str] = dataclasses.field(default="md", init=False, repr=False, compare=False) # type: ignore username: str = "motherduck" + password: TSecretValue = None + database: str = "my_db" read_only: bool = False # open database read/write @@ -47,13 +49,11 @@ def parse_native_representation(self, native_value: Any) -> None: super().parse_native_representation(native_value) self._token_to_password() - def on_resolved(self) -> None: + def on_partial(self) -> None: + """Takes a token from query string and reuses it as a password""" self._token_to_password() - if self.drivername == MOTHERDUCK_DRIVERNAME and not self.password: - raise ConfigurationValueError( - "Motherduck schema 'md' was specified without corresponding token or password. The" - " required format of connection string is: md:///?token=" - ) + if not self.is_partial(): + self.resolve() @configspec diff --git a/tests/load/duckdb/test_motherduck_client.py b/tests/load/duckdb/test_motherduck_client.py index ba60e0de6d..7a156d8ab4 100644 --- a/tests/load/duckdb/test_motherduck_client.py +++ b/tests/load/duckdb/test_motherduck_client.py @@ -1,6 +1,7 @@ import os import pytest +from dlt.common.configuration.exceptions import ConfigFieldMissingException from dlt.common.configuration.resolve import resolve_configuration from dlt.destinations.impl.motherduck.configuration import ( @@ -13,17 +14,39 @@ skip_if_not_active("motherduck") -def test_motherduck_database() -> None: - # set HOME env otherwise some internal components in ducdkb (HTTPS) do not initialize - os.environ["HOME"] = "/tmp" - # os.environ.pop("HOME", None) - - cred = MotherDuckCredentials("md:///?token=TOKEN") - print(dict(cred)) +def test_motherduck_configuration() -> None: + cred = MotherDuckCredentials("md:///dlt_data?token=TOKEN") + # print(dict(cred)) assert cred.password == "TOKEN" + assert cred.database == "dlt_data" + assert cred.is_partial() is False + assert cred.is_resolved() is True + cred = MotherDuckCredentials() cred.parse_native_representation("md:///?token=TOKEN") assert cred.password == "TOKEN" + assert cred.database == "" + assert cred.is_partial() is False + assert cred.is_resolved() is False + + # password or token are mandatory + with pytest.raises(ConfigFieldMissingException) as conf_ex: + resolve_configuration(MotherDuckCredentials()) + assert conf_ex.value.fields == ["password"] + + os.environ["CREDENTIALS__PASSWORD"] = "pwd" + config = resolve_configuration(MotherDuckCredentials()) + assert config.password == "pwd" + + del os.environ["CREDENTIALS__PASSWORD"] + os.environ["CREDENTIALS__QUERY"] = '{"token": "tok"}' + config = resolve_configuration(MotherDuckCredentials()) + assert config.password == "tok" + + +def test_motherduck_connect() -> None: + # set HOME env otherwise some internal components in ducdkb (HTTPS) do not initialize + os.environ["HOME"] = "/tmp" config = resolve_configuration( MotherDuckClientConfiguration()._bind_dataset_name(dataset_name="test"), From f9ab06e6f94f59d3617d31dec23403403849ef28 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 23:01:54 +0200 Subject: [PATCH 03/22] identifies data writer by both file format and source item format, adds csv writer for arrow and object(wip) --- dlt/common/data_writers/__init__.py | 10 +- dlt/common/data_writers/buffered.py | 35 +- dlt/common/data_writers/exceptions.py | 19 + dlt/common/data_writers/writers.py | 327 ++++++++++++++---- .../common/data_writers/test_data_writers.py | 3 - tests/libs/test_arrow_csv_writer.py | 128 +++++++ .../load/pipeline/test_filesystem_pipeline.py | 23 ++ .../test_pipeline_file_format_resolver.py | 10 +- 8 files changed, 459 insertions(+), 96 deletions(-) create mode 100644 tests/libs/test_arrow_csv_writer.py diff --git a/dlt/common/data_writers/__init__.py b/dlt/common/data_writers/__init__.py index 04c5d04328..931bda962b 100644 --- a/dlt/common/data_writers/__init__.py +++ b/dlt/common/data_writers/__init__.py @@ -1,4 +1,9 @@ -from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics, TLoaderFileFormat +from dlt.common.data_writers.writers import ( + DataWriter, + DataWriterMetrics, + TDataItemFormat, + FileWriterSpec, +) from dlt.common.data_writers.buffered import BufferedDataWriter, new_file_id from dlt.common.data_writers.escape import ( escape_redshift_literal, @@ -8,8 +13,9 @@ __all__ = [ "DataWriter", + "FileWriterSpec", "DataWriterMetrics", - "TLoaderFileFormat", + "TDataItemFormat", "BufferedDataWriter", "new_file_id", "escape_redshift_literal", diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index b10b1d14b9..1db18b065e 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,23 +1,20 @@ import gzip import time -from typing import ClassVar, List, IO, Any, Optional, Type, TypeVar, Generic +from typing import ClassVar, List, IO, Any, Optional, Type, Generic from dlt.common.typing import TDataItem, TDataItems -from dlt.common.data_writers import TLoaderFileFormat from dlt.common.data_writers.exceptions import ( BufferedDataWriterClosed, DestinationCapabilitiesRequired, InvalidFileNameTemplateException, ) -from dlt.common.data_writers.writers import DataWriter, DataWriterMetrics +from dlt.common.data_writers.writers import TWriter, DataWriter, DataWriterMetrics, FileWriterSpec from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.configuration import with_config, known_sections, configspec from dlt.common.configuration.specs import BaseConfiguration from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.utils import uniq_id -TWriter = TypeVar("TWriter", bound=DataWriter) - def new_file_id() -> str: """Creates new file id which is globally unique within table_name scope""" @@ -38,7 +35,7 @@ class BufferedDataWriterConfiguration(BaseConfiguration): @with_config(spec=BufferedDataWriterConfiguration) def __init__( self, - file_format: TLoaderFileFormat, + writer_spec: FileWriterSpec, file_name_template: str, *, buffer_max_items: int = 5000, @@ -47,10 +44,13 @@ def __init__( disable_compression: bool = False, _caps: DestinationCapabilitiesContext = None ): - self.file_format = file_format - self._file_format_spec = DataWriter.data_format_from_file_format(self.file_format) - if self._file_format_spec.requires_destination_capabilities and not _caps: - raise DestinationCapabilitiesRequired(file_format) + self.writer_spec = writer_spec + if self.writer_spec.requires_destination_capabilities and not _caps: + raise DestinationCapabilitiesRequired(self.writer_spec.file_format) + self.writer_cls = DataWriter.class_factory( + writer_spec.file_format, writer_spec.data_item_format + ) + self._supports_schema_changes = self.writer_spec.supports_schema_changes self._caps = _caps # validate if template has correct placeholders self.file_name_template = file_name_template @@ -61,9 +61,7 @@ def __init__( self.file_max_items = file_max_items # the open function is either gzip.open or open self.open = ( - gzip.open - if self._file_format_spec.supports_compression and not disable_compression - else open + gzip.open if self.writer_spec.supports_compression and not disable_compression else open ) self._current_columns: TTableSchemaColumns = None @@ -87,8 +85,9 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int # rotate file if columns changed and writer does not allow for that # as the only allowed change is to add new column (no updates/deletes), we detect the change by comparing lengths if ( - self._writer - and not self._writer.data_format().supports_schema_changes + self._current_columns is not None + and (self._writer or self._supports_schema_changes == "False") + and self._supports_schema_changes != "True" and len(columns) != len(self._current_columns) ): assert len(columns) > len(self._current_columns) @@ -183,7 +182,7 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: metrics = self._flush_and_close_file(allow_empty_file) self._file_name = ( - self.file_name_template % new_file_id() + "." + self._file_format_spec.file_extension + self.file_name_template % new_file_id() + "." + self.writer_spec.file_extension ) self._created = time.time() return metrics @@ -193,11 +192,11 @@ def _flush_items(self, allow_empty_file: bool = False) -> None: # we only open a writer when there are any items in the buffer and first flush is requested if not self._writer: # create new writer and write header - if self._file_format_spec.is_binary_format: + if self.writer_spec.is_binary_format: self._file = self.open(self._file_name, "wb") # type: ignore else: self._file = self.open(self._file_name, "wt", encoding="utf-8") # type: ignore - self._writer = DataWriter.from_file_format(self.file_format, self._file, caps=self._caps) # type: ignore[assignment] + self._writer = self.writer_cls(self._file, caps=self._caps) # type: ignore[assignment] self._writer.write_header(self._current_columns) # write buffer if self._buffered_items: diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index d3a073cf4e..ac339ba31c 100644 --- a/dlt/common/data_writers/exceptions.py +++ b/dlt/common/data_writers/exceptions.py @@ -27,3 +27,22 @@ def __init__(self, file_format: TLoaderFileFormat): super().__init__( f"Writer for {file_format} requires destination capabilities which were not provided." ) + + +class DataWriterNotFound(DataWriterException): + def __init__(self, file_format: TLoaderFileFormat, data_item_format: str): + self.file_format = file_format + self.data_item_format = data_item_format + super().__init__( + f"Can't find a file writer for file format {file_format} and item format" + f" {data_item_format}" + ) + + +class InvalidDataItem(DataWriterException): + def __init__(self, file_format: TLoaderFileFormat, data_item_format: str, details: str): + self.file_format = file_format + self.data_item_format = data_item_format + super().__init__( + f"A data item of type {data_item_format} cannot be written as {file_format}: {details}" + ) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 2aadb010e0..67e5466d39 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -1,4 +1,5 @@ import abc +import csv from dataclasses import dataclass from typing import ( IO, @@ -7,17 +8,19 @@ ClassVar, Dict, List, + Literal, Optional, Sequence, Tuple, Type, NamedTuple, - overload, + TypeVar, ) from dlt.common import json from dlt.common.configuration import configspec, known_sections, with_config from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.data_writers.exceptions import DataWriterNotFound, InvalidDataItem from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.typing import StrAny @@ -26,12 +29,20 @@ from dlt.common.libs.pyarrow import pyarrow as pa +TDataItemFormat = Literal["arrow", "object"] +TWriter = TypeVar("TWriter", bound="DataWriter") + + @dataclass -class TFileFormatSpec: +class FileWriterSpec: file_format: TLoaderFileFormat + """format of the output file""" + data_item_format: TDataItemFormat + """format of the input data""" file_extension: str is_binary_format: bool - supports_schema_changes: bool + supports_schema_changes: Literal["True", "Buffer", "False"] + """File format supports changes of schema: True - at any moment, Buffer - in memory buffer before opening file, False - not at all""" requires_destination_capabilities: bool = False supports_compression: bool = False @@ -64,15 +75,13 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N self._caps = caps self.items_count = 0 - @abc.abstractmethod - def write_header(self, columns_schema: TTableSchemaColumns) -> None: + def write_header(self, columns_schema: TTableSchemaColumns) -> None: # noqa pass def write_data(self, rows: Sequence[Any]) -> None: self.items_count += len(rows) - @abc.abstractmethod - def write_footer(self) -> None: + def write_footer(self) -> None: # noqa pass def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: @@ -82,66 +91,66 @@ def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> @classmethod @abc.abstractmethod - def data_format(cls) -> TFileFormatSpec: + def writer_spec(cls) -> FileWriterSpec: pass @classmethod def from_file_format( - cls, file_format: TLoaderFileFormat, f: IO[Any], caps: DestinationCapabilitiesContext = None + cls, + file_format: TLoaderFileFormat, + data_item_format: TDataItemFormat, + f: IO[Any], + caps: DestinationCapabilitiesContext = None, ) -> "DataWriter": - return cls.class_factory(file_format)(f, caps) + return cls.class_factory(file_format, data_item_format)(f, caps) @classmethod - def from_destination_capabilities( - cls, caps: DestinationCapabilitiesContext, f: IO[Any] - ) -> "DataWriter": - return cls.class_factory(caps.preferred_loader_file_format)(f, caps) + def writer_spec_from_file_format( + cls, file_format: TLoaderFileFormat, data_item_format: TDataItemFormat + ) -> FileWriterSpec: + return cls.class_factory(file_format, data_item_format).writer_spec() @classmethod - def data_format_from_file_format(cls, file_format: TLoaderFileFormat) -> TFileFormatSpec: - return cls.class_factory(file_format).data_format() + def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat: + """Simple heuristic to get data item format from file extension""" + if extension == "typed-jsonl": + return "object" + elif extension == "parquet": + return "arrow" + else: + raise ValueError(f"Cannot figure out data item format for extension {extension}") @staticmethod - def class_factory(file_format: TLoaderFileFormat) -> Type["DataWriter"]: - if file_format == "jsonl": - return JsonlWriter - elif file_format == "puae-jsonl": - return JsonlListPUAEncodeWriter - elif file_format == "insert_values": - return InsertValuesWriter - elif file_format == "parquet": - return ParquetDataWriter # type: ignore - elif file_format == "arrow": - return ArrowWriter # type: ignore - else: - raise ValueError(file_format) + def class_factory( + file_format: TLoaderFileFormat, data_item_format: TDataItemFormat + ) -> Type["DataWriter"]: + for writer in ALL_WRITERS: + spec = writer.writer_spec() + if spec.file_format == file_format and spec.data_item_format == data_item_format: + return writer + raise DataWriterNotFound(file_format, data_item_format) class JsonlWriter(DataWriter): - def write_header(self, columns_schema: TTableSchemaColumns) -> None: - pass - def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) for row in rows: json.dump(row, self._f) self._f.write(b"\n") - def write_footer(self) -> None: - pass - @classmethod - def data_format(cls) -> TFileFormatSpec: - return TFileFormatSpec( + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( "jsonl", + "object", file_extension="jsonl", is_binary_format=True, - supports_schema_changes=True, + supports_schema_changes="True", supports_compression=True, ) -class JsonlListPUAEncodeWriter(JsonlWriter): +class TypedJsonlListWriter(JsonlWriter): def write_data(self, rows: Sequence[Any]) -> None: # skip JsonlWriter when calling super super(JsonlWriter, self).write_data(rows) @@ -151,12 +160,13 @@ def write_data(self, rows: Sequence[Any]) -> None: self._f.write(b"\n") @classmethod - def data_format(cls) -> TFileFormatSpec: - return TFileFormatSpec( - "puae-jsonl", - file_extension="jsonl", + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( + "typed-jsonl", + "object", + file_extension="typed-jsonl", is_binary_format=True, - supports_schema_changes=True, + supports_schema_changes="True", supports_compression=True, ) @@ -217,12 +227,13 @@ def write_footer(self) -> None: self._f.write(";") @classmethod - def data_format(cls) -> TFileFormatSpec: - return TFileFormatSpec( + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( "insert_values", + "object", file_extension="insert_values", is_binary_format=False, - supports_schema_changes=False, + supports_schema_changes="Buffer", supports_compression=True, requires_destination_capabilities=True, ) @@ -230,9 +241,9 @@ def data_format(cls) -> TFileFormatSpec: @configspec class ParquetDataWriterConfiguration(BaseConfiguration): - flavor: str = "spark" - version: str = "2.4" - data_page_size: int = 1024 * 1024 + flavor: Optional[str] = None # could be ie. "spark" + version: Optional[str] = "2.4" + data_page_size: Optional[int] = None timestamp_precision: str = "us" timestamp_timezone: str = "UTC" row_group_size: Optional[int] = None @@ -247,9 +258,9 @@ def __init__( f: IO[Any], caps: DestinationCapabilitiesContext = None, *, - flavor: str = "spark", - version: str = "2.4", - data_page_size: int = 1024 * 1024, + flavor: Optional[str] = None, + version: Optional[str] = "2.4", + data_page_size: Optional[int] = None, timestamp_timezone: str = "UTC", row_group_size: Optional[int] = None, ) -> None: @@ -303,8 +314,8 @@ def write_data(self, rows: Sequence[Any]) -> None: # replace complex types with json for key in self.complex_indices: for row in rows: - if key in row: - row[key] = json.dumps(row[key]) + if (value := row.get(key)) is not None: + row[key] = json.dumps(value) table = pyarrow.Table.from_pylist(rows, schema=self.schema) # Write @@ -315,18 +326,61 @@ def write_footer(self) -> None: self.writer = None @classmethod - def data_format(cls) -> TFileFormatSpec: - return TFileFormatSpec( + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( "parquet", + "object", "parquet", - True, - False, + is_binary_format=True, + supports_schema_changes="Buffer", requires_destination_capabilities=True, supports_compression=False, ) -class ArrowWriter(ParquetDataWriter): +class CsvWriter(DataWriter): + def __init__( + self, f: IO[Any], caps: DestinationCapabilitiesContext = None, delimiter: str = "," + ) -> None: + super().__init__(f, caps) + self.delimiter = delimiter + self.writer: csv.DictWriter[str] = None + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + self._columns_schema = columns_schema + self.writer = csv.DictWriter( + self._f, + fieldnames=list(columns_schema.keys()), + extrasaction="ignore", + dialect=csv.unix_dialect, + delimiter=self.delimiter, + ) + self.writer.writeheader() + + def write_data(self, rows: Sequence[Any]) -> None: + self.writer.writerows(rows) + # count rows that got written + self.items_count += sum(len(row) for row in rows) + + def write_footer(self) -> None: + if self.writer is None: + self.writer = None + self._first_schema = None + + @classmethod + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( + "csv", + "object", + file_extension="csv", + is_binary_format=False, + supports_schema_changes="False", + requires_destination_capabilities=False, + supports_compression=True, + ) + + +class ArrowToParquetWriter(ParquetDataWriter): def write_header(self, columns_schema: TTableSchemaColumns) -> None: # Schema will be written as-is from the arrow table self._column_schema = columns_schema @@ -334,12 +388,9 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: def write_data(self, rows: Sequence[Any]) -> None: from dlt.common.libs.pyarrow import pyarrow - rows = list(rows) - if not rows: - return - first = rows[0] - self.writer = self.writer or self._create_writer(first.schema) for row in rows: + if not self.writer: + self.writer = self._create_writer(row.schema) if isinstance(row, pyarrow.Table): self.writer.write_table(row, row_group_size=self.parquet_row_group_size) elif isinstance(row, pyarrow.RecordBatch): @@ -355,12 +406,152 @@ def write_footer(self) -> None: return super().write_footer() @classmethod - def data_format(cls) -> TFileFormatSpec: - return TFileFormatSpec( + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( + "parquet", "arrow", file_extension="parquet", is_binary_format=True, - supports_schema_changes=False, + supports_schema_changes="False", requires_destination_capabilities=False, supports_compression=False, ) + + +class ArrowToCsvWriter(DataWriter): + def __init__( + self, f: IO[Any], caps: DestinationCapabilitiesContext = None, delimiter: bytes = b"," + ) -> None: + super().__init__(f, caps) + self.delimiter = delimiter + self.writer: Any = None + + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + self._columns_schema = columns_schema + + def write_data(self, rows: Sequence[Any]) -> None: + from dlt.common.libs.pyarrow import pyarrow + import pyarrow.csv + + for row in rows: + if isinstance(row, (pyarrow.Table, pyarrow.RecordBatch)): + if not self.writer: + try: + self.writer = pyarrow.csv.CSVWriter( + self._f, + row.schema, + write_options=pyarrow.csv.WriteOptions( + include_header=True, delimiter=self.delimiter + ), + ) + self._first_schema = row.schema + except pyarrow.ArrowInvalid as inv_ex: + if "Unsupported Type" in str(inv_ex): + raise InvalidDataItem( + "csv", + "arrow", + "Arrow data contains a column that cannot be written to csv file" + f" ({inv_ex}). Remove nested columns (struct, map) or convert them" + " to json strings.", + ) + raise + # make sure that Schema stays the same + if not row.schema.equals(self._first_schema): + raise InvalidDataItem( + "csv", + "arrow", + "Arrow schema changed without rotating the file. This may be internal" + " error or misuse of the writer.\nFirst" + f" schema:\n{self._first_schema}\n\nCurrent schema:\n{row.schema}", + ) + + # write headers only on the first write + try: + self.writer.write(row) + except pyarrow.ArrowInvalid as inv_ex: + if "Invalid UTF8 payload" in str(inv_ex): + raise InvalidDataItem( + "csv", + "arrow", + "Arrow data contains string or binary columns with invalid UTF-8" + " characters. Remove binary columns or replace their content with a hex" + " representation: \\x... while keeping data type as binary.", + ) + raise + else: + raise ValueError(f"Unsupported type {type(row)}") + # count rows that got written + self.items_count += row.num_rows + + def write_footer(self) -> None: + if self.writer is None: + # write empty file + self._f.write( + self.delimiter.join( + [col["name"].encode("utf-8") for col in self._columns_schema.values()] + ) + ) + else: + self.writer.close() + self.writer = None + self._first_schema = None + + @classmethod + def writer_spec(cls) -> FileWriterSpec: + return FileWriterSpec( + "csv", + "arrow", + file_extension="csv", + is_binary_format=True, + supports_schema_changes="False", + requires_destination_capabilities=False, + supports_compression=True, + ) + + +class ArrowToObjectAdapter: + """A mixin that will convert object writer into arrow writer.""" + + def write_data(self, rows: Sequence[Any]) -> None: + for batch in rows: + # convert to object data item format + super().write_data(batch.to_pylist()) # type: ignore[misc] + + @staticmethod + def convert_spec(base: Type[DataWriter]) -> FileWriterSpec: + spec = base.writer_spec() + spec.data_item_format = "arrow" + return spec + + +class ArrowToInsertValuesWriter(ArrowToObjectAdapter, InsertValuesWriter): + @classmethod + def writer_spec(cls) -> FileWriterSpec: + return cls.convert_spec(InsertValuesWriter) + + +class ArrowToJsonlWriter(ArrowToObjectAdapter, JsonlWriter): + @classmethod + def writer_spec(cls) -> FileWriterSpec: + return cls.convert_spec(JsonlWriter) + + +class ArrowToTypedJsonlListWriter(ArrowToObjectAdapter, TypedJsonlListWriter): + @classmethod + def writer_spec(cls) -> FileWriterSpec: + return cls.convert_spec(TypedJsonlListWriter) + + +# ArrowToCsvWriter +ALL_WRITERS: List[Type[DataWriter]] = [ + JsonlWriter, + TypedJsonlListWriter, + InsertValuesWriter, + ParquetDataWriter, + CsvWriter, + ArrowToParquetWriter, + ArrowToInsertValuesWriter, + ArrowToJsonlWriter, + ArrowToTypedJsonlListWriter, + ArrowToCsvWriter, +] diff --git a/tests/common/data_writers/test_data_writers.py b/tests/common/data_writers/test_data_writers.py index ac4f118229..456ac64996 100644 --- a/tests/common/data_writers/test_data_writers.py +++ b/tests/common/data_writers/test_data_writers.py @@ -23,9 +23,6 @@ EMPTY_DATA_WRITER_METRICS, InsertValuesWriter, JsonlWriter, - JsonlListPUAEncodeWriter, - ParquetDataWriter, - ArrowWriter, ) from tests.common.utils import load_json_case, row_to_column_schemas diff --git a/tests/libs/test_arrow_csv_writer.py b/tests/libs/test_arrow_csv_writer.py new file mode 100644 index 0000000000..91038f01c4 --- /dev/null +++ b/tests/libs/test_arrow_csv_writer.py @@ -0,0 +1,128 @@ +import csv +from copy import copy +import pytest +import pyarrow.parquet as pq + +from dlt.common.data_writers.exceptions import InvalidDataItem +from dlt.common.data_writers.writers import ArrowToCsvWriter, ParquetDataWriter +from dlt.common.libs.pyarrow import remove_columns + +from tests.common.data_writers.utils import get_writer +from tests.cases import ( + TABLE_UPDATE_COLUMNS_SCHEMA, + TABLE_ROW_ALL_DATA_TYPES_DATETIMES, + TABLE_ROW_ALL_DATA_TYPES, + arrow_table_all_data_types, +) +from tests.utils import write_version, autouse_test_storage, preserve_environ + + +def test_csv_writer_all_data_fields() -> None: + data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + + # write parquet and read it + with get_writer(ParquetDataWriter) as pq_writer: + pq_writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(pq_writer.closed_files[0].file_path, "rb") as f: + table = pq.read_table(f) + + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 1 data + assert len(rows) == 2 + + # compare headers + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + # compare row + assert len(rows[1]) == len(list(TABLE_ROW_ALL_DATA_TYPES.values())) + # compare none values + for actual, expected, col_name in zip( + rows[1], TABLE_ROW_ALL_DATA_TYPES.values(), TABLE_UPDATE_COLUMNS_SCHEMA.keys() + ): + if expected is None: + assert actual == "", f"{col_name} is not recovered as None" + + # write again with several arrows + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item([table, table], TABLE_UPDATE_COLUMNS_SCHEMA) + writer.write_data_item(table.to_batches(), TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 3 data + assert len(rows) == 4 + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + assert rows[1] == rows[2] == rows[3] + + # simulate non announced schema change + base_table = remove_columns(table, ["col9_null"]) + base_column_schema = copy(TABLE_UPDATE_COLUMNS_SCHEMA) + base_column_schema.pop("col9_null") + + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item([base_table, table], TABLE_UPDATE_COLUMNS_SCHEMA) + + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + writer.write_data_item(base_table, TABLE_UPDATE_COLUMNS_SCHEMA) + + # schema change will rotate the file + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(base_table, base_column_schema) + writer.write_data_item([table, table], TABLE_UPDATE_COLUMNS_SCHEMA) + + assert len(writer.closed_files) == 2 + + # first file + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 1 data + assert len(rows) == 2 + assert rows[0] == list(base_column_schema.keys()) + # second file + with open(writer.closed_files[1].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 2 data + assert len(rows) == 3 + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) + + +def test_non_utf8_binary() -> None: + data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + data["col7"] += b"\x8e" # type: ignore[operator] + + # write parquet and read it + with get_writer(ParquetDataWriter) as pq_writer: + pq_writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(pq_writer.closed_files[0].file_path, "rb") as f: + table = pq.read_table(f) + + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + + +def test_arrow_struct() -> None: + item, _ = arrow_table_all_data_types("table", include_json=True, include_time=False) + with pytest.raises(InvalidDataItem): + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_data_item(item, TABLE_UPDATE_COLUMNS_SCHEMA) + + +def test_csv_writer_empty() -> None: + with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: + writer.write_empty_file(TABLE_UPDATE_COLUMNS_SCHEMA) + + with open(writer.closed_files[0].file_path, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # only header + assert len(rows) == 1 + + assert rows[0] == list(TABLE_UPDATE_COLUMNS_SCHEMA.keys()) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 8fc4adc0c3..6d33b477fc 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -1,3 +1,4 @@ +import csv import posixpath from pathlib import Path @@ -7,6 +8,7 @@ from dlt.destinations.impl.filesystem.filesystem import FilesystemClient, LoadFilesystemJob from dlt.common.schema.typing import LOADS_TABLE_NAME +from tests.cases import arrow_table_all_data_types from tests.utils import skip_if_not_active skip_if_not_active("filesystem") @@ -97,6 +99,27 @@ def some_source(): assert len(replace_files) == 1 +def test_pipeline_csv_filesystem_destination() -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + # store locally + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage" + pipeline = dlt.pipeline( + pipeline_name="parquet_test_" + uniq_id(), + destination="filesystem", + dataset_name="parquet_test_" + uniq_id(), + ) + + item, _ = arrow_table_all_data_types("table", include_json=False, include_time=True) + info = pipeline.run(item, table_name="table", loader_file_format="csv") + info.raise_on_failed_jobs() + job = info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + with open(job, "r", encoding="utf-8") as f: + rows = list(csv.reader(f, dialect=csv.unix_dialect)) + # header + 3 data rows + assert len(rows) == 4 + + def test_pipeline_parquet_filesystem_destination() -> None: import pyarrow.parquet as pq # Module is evaluated by other tests diff --git a/tests/pipeline/test_pipeline_file_format_resolver.py b/tests/pipeline/test_pipeline_file_format_resolver.py index 588ad720a5..82afcd7dfb 100644 --- a/tests/pipeline/test_pipeline_file_format_resolver.py +++ b/tests/pipeline/test_pipeline_file_format_resolver.py @@ -47,7 +47,7 @@ def __init__(self) -> None: # check invalid input with pytest.raises(DestinationIncompatibleLoaderFileFormatException): - assert p._resolve_loader_file_format("some", "some", destcp, None, "csv") # type: ignore[arg-type] + assert p._resolve_loader_file_format("some", "some", destcp, None, "tsv") # type: ignore[arg-type] # check staging resolution with clear preference destcp.supported_staging_file_formats = ["jsonl", "insert_values", "parquet"] @@ -57,18 +57,18 @@ def __init__(self) -> None: # check invalid input with pytest.raises(DestinationIncompatibleLoaderFileFormatException): - p._resolve_loader_file_format("some", "some", destcp, stagecp, "csv") # type: ignore[arg-type] + p._resolve_loader_file_format("some", "some", destcp, stagecp, "tsv") # type: ignore[arg-type] # check staging resolution where preference does not match destcp.supported_staging_file_formats = ["insert_values", "parquet"] - destcp.preferred_staging_file_format = "csv" # type: ignore[assignment] + destcp.preferred_staging_file_format = "tsv" # type: ignore[assignment] stagecp.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] assert p._resolve_loader_file_format("some", "some", destcp, stagecp, None) == "insert_values" assert p._resolve_loader_file_format("some", "some", destcp, stagecp, "parquet") == "parquet" # check incompatible staging - destcp.supported_staging_file_formats = ["insert_values", "csv"] # type: ignore[list-item] - destcp.preferred_staging_file_format = "csv" # type: ignore[assignment] + destcp.supported_staging_file_formats = ["insert_values", "tsv"] # type: ignore[list-item] + destcp.preferred_staging_file_format = "tsv" # type: ignore[assignment] stagecp.supported_loader_file_formats = ["jsonl", "parquet"] with pytest.raises(DestinationIncompatibleLoaderFileFormatException): p._resolve_loader_file_format("some", "some", destcp, stagecp, None) From e8ad1f994b60d709c70089ce87b446a90aa12306 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 23:02:44 +0200 Subject: [PATCH 04/22] adds postgres csv writer via COPY --- dlt/destinations/impl/destination/__init__.py | 7 +- .../impl/destination/configuration.py | 2 +- dlt/destinations/impl/destination/factory.py | 11 +-- dlt/destinations/impl/dummy/__init__.py | 2 +- dlt/destinations/impl/postgres/__init__.py | 2 +- dlt/destinations/impl/postgres/postgres.py | 39 +++++++-- tests/load/pipeline/test_postgres.py | 83 +++++++++++++++++++ 7 files changed, 127 insertions(+), 19 deletions(-) create mode 100644 tests/load/pipeline/test_postgres.py diff --git a/dlt/destinations/impl/destination/__init__.py b/dlt/destinations/impl/destination/__init__.py index 560c9d4eda..f985119f26 100644 --- a/dlt/destinations/impl/destination/__init__.py +++ b/dlt/destinations/impl/destination/__init__.py @@ -1,15 +1,14 @@ from typing import Optional -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat def capabilities( - preferred_loader_file_format: TLoaderFileFormat = "puae-jsonl", + preferred_loader_file_format: TLoaderFileFormat = "typed-jsonl", naming_convention: str = "direct", max_table_nesting: Optional[int] = 0, ) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext.generic_capabilities(preferred_loader_file_format) - caps.supported_loader_file_formats = ["puae-jsonl", "parquet"] + caps.supported_loader_file_formats = ["typed-jsonl", "parquet"] caps.supports_ddl_transactions = False caps.supports_transactions = False caps.naming_convention = naming_convention diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index 30e54a8313..bad7e4e3cc 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -19,7 +19,7 @@ class CustomDestinationClientConfiguration(DestinationClientConfiguration): destination_type: Final[str] = dataclasses.field(default="destination", init=False, repr=False, compare=False) # type: ignore destination_callable: Optional[Union[str, TDestinationCallable]] = None # noqa: A003 - loader_file_format: TLoaderFileFormat = "puae-jsonl" + loader_file_format: TLoaderFileFormat = "typed-jsonl" batch_size: int = 10 skip_dlt_columns_and_tables: bool = True max_table_nesting: int = 0 diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py index 7cca8f2202..8395c66ac8 100644 --- a/dlt/destinations/impl/destination/factory.py +++ b/dlt/destinations/impl/destination/factory.py @@ -5,19 +5,18 @@ from types import ModuleType from dlt.common.typing import AnyFun -from dlt.common.destination import Destination, DestinationCapabilitiesContext -from dlt.destinations.exceptions import DestinationTransientException +from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.configuration import known_sections, with_config, get_fun_spec from dlt.common.configuration.exceptions import ConfigurationValueError from dlt.common import logger +from dlt.common.utils import get_callable_name, is_inner_callable +from dlt.destinations.exceptions import DestinationTransientException from dlt.destinations.impl.destination.configuration import ( CustomDestinationClientConfiguration, TDestinationCallable, ) from dlt.destinations.impl.destination import capabilities -from dlt.common.data_writers import TLoaderFileFormat -from dlt.common.utils import get_callable_name, is_inner_callable if t.TYPE_CHECKING: from dlt.destinations.impl.destination.destination import DestinationClient @@ -38,7 +37,9 @@ class DestinationInfo(t.NamedTuple): class destination(Destination[CustomDestinationClientConfiguration, "DestinationClient"]): def capabilities(self) -> DestinationCapabilitiesContext: return capabilities( - preferred_loader_file_format=self.config_params.get("loader_file_format", "puae-jsonl"), + preferred_loader_file_format=self.config_params.get( + "loader_file_format", "typed-jsonl" + ), naming_convention=self.config_params.get("naming_convention", "direct"), max_table_nesting=self.config_params.get("max_table_nesting", None), ) diff --git a/dlt/destinations/impl/dummy/__init__.py b/dlt/destinations/impl/dummy/__init__.py index 37b2e77c8a..e09f7d07a9 100644 --- a/dlt/destinations/impl/dummy/__init__.py +++ b/dlt/destinations/impl/dummy/__init__.py @@ -21,7 +21,7 @@ def _configure(config: DummyClientConfiguration = config.value) -> DummyClientCo def capabilities() -> DestinationCapabilitiesContext: config = _configure() additional_formats: List[TLoaderFileFormat] = ( - ["reference"] if config.create_followup_jobs else [] + ["reference"] if config.create_followup_jobs else [] # type:ignore[list-item] ) caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = config.loader_file_format diff --git a/dlt/destinations/impl/postgres/__init__.py b/dlt/destinations/impl/postgres/__init__.py index 43e6af1996..bdb9297210 100644 --- a/dlt/destinations/impl/postgres/__init__.py +++ b/dlt/destinations/impl/postgres/__init__.py @@ -9,7 +9,7 @@ def capabilities() -> DestinationCapabilitiesContext: # https://www.postgresql.org/docs/current/limits.html caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "insert_values" - caps.supported_loader_file_formats = ["insert_values"] + caps.supported_loader_file_formats = ["insert_values", "csv"] caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] caps.escape_identifier = escape_postgres_identifier diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index f8fa3e341a..b585967196 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -1,23 +1,19 @@ from typing import ClassVar, Dict, Optional, Sequence, List, Any -from dlt.common.wei import EVM_DECIMAL_PRECISION -from dlt.common.destination.reference import NewLoadJob +from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.storages.file_storage import FileStorage from dlt.destinations.sql_jobs import SqlStagingCopyJob, SqlJobParams - from dlt.destinations.insert_job_client import InsertValuesJobClient - from dlt.destinations.impl.postgres import capabilities from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.type_mapping import TypeMapper - HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} @@ -104,6 +100,29 @@ def generate_sql( return sql +class PostgresCsvCopyJob(LoadJob, FollowupJob): + def __init__(self, table_name: str, file_path: str, sql_client: Psycopg2SqlClient) -> None: + super().__init__(FileStorage.get_file_name_from_file_path(file_path)) + + with FileStorage.open_zipsafe_ro(file_path, "rb") as f: + # all headers in first line + headers = f.readline().decode("utf-8").strip() + qualified_table_name = sql_client.make_qualified_table_name(table_name) + copy_sql = "COPY %s (%s) FROM STDIN WITH CSV DELIMITER ',' NULL ''" % ( + qualified_table_name, + headers, + ) + with sql_client.begin_transaction(): + with sql_client.native_connection.cursor() as cursor: + cursor.copy_expert(copy_sql, f, size=8192) + + def state(self) -> TLoadJobState: + return "completed" + + def exception(self) -> str: + raise NotImplementedError() + + class PostgresClient(InsertValuesJobClient): capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities() @@ -111,10 +130,16 @@ def __init__(self, schema: Schema, config: PostgresClientConfiguration) -> None: sql_client = Psycopg2SqlClient(config.normalize_dataset_name(schema), config.credentials) super().__init__(schema, config, sql_client) self.config: PostgresClientConfiguration = config - self.sql_client = sql_client + self.sql_client: Psycopg2SqlClient = sql_client self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} self.type_mapper = PostgresTypeMapper(self.capabilities) + def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: + job = super().start_file_load(table, file_path, load_id) + if not job and file_path.endswith("csv"): + job = PostgresCsvCopyJob(table["name"], file_path, self.sql_client) + return job + def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py new file mode 100644 index 0000000000..bf57bb0c4e --- /dev/null +++ b/tests/load/pipeline/test_postgres.py @@ -0,0 +1,83 @@ +import os +import hashlib +import random +from string import ascii_lowercase +import pytest + +from dlt.common.utils import uniq_id + +from tests.load.pipeline.utils import destinations_configs, DestinationTestConfiguration +from tests.cases import arrow_table_all_data_types, prepare_shuffled_tables +from tests.pipeline.utils import assert_data_table_counts, assert_load_info, load_tables_to_dicts + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_postgres_load_csv_from_arrow(destination_config: DestinationTestConfiguration) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) + table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + + load_info = pipeline.run( + [shuffled_removed_column, shuffled_table, table], + table_name="table", + loader_file_format="csv", + ) + assert_load_info(load_info) + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + assert_data_table_counts(pipeline, {"table": 5432 * 3}) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_postgres_encoded_binary(destination_config: DestinationTestConfiguration) -> None: + import pyarrow + + os.environ["RESTORE_FROM_DESTINATION"] = "False" + blob = hashlib.sha3_256(random.choice(ascii_lowercase).encode()).digest() + # encode as \x... which postgres understands + blob_table = pyarrow.Table.from_pylist([{"hash": b"\\x" + blob.hex().encode("ascii")}]) + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) + load_info = pipeline.run(blob_table, table_name="table", loader_file_format="csv") + assert_load_info(load_info) + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + + data = load_tables_to_dicts(pipeline, "table") + # print(bytes(data["table"][0]["hash"])) + # data in postgres equals unencoded blob + assert bytes(data["table"][0]["hash"]) == blob + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_postgres_empty_csv_from_arrow(destination_config: DestinationTestConfiguration) -> None: + os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + os.environ["RESTORE_FROM_DESTINATION"] = "False" + pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) + table, _ = arrow_table_all_data_types("table", include_json=False) + + load_info = pipeline.run( + table.schema.empty_table(), table_name="table", loader_file_format="csv" + ) + assert_load_info(load_info) + job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path + assert job.endswith("csv") + assert_data_table_counts(pipeline, {"table": 0}) + with pipeline.sql_client() as client: + with client.execute_query('SELECT * FROM "table"') as cur: + columns = [col.name for col in cur.description] + assert len(cur.fetchall()) == 0 + + # all columns in order + assert columns == list(pipeline.default_schema.get_table_columns("table").keys()) From 9ac63fbb5cb96919a32056da2cfd0ea77f9ba76e Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 23:03:47 +0200 Subject: [PATCH 05/22] improves arrow and parquet tests, adds arrow normalization edge cases --- dlt/common/libs/pyarrow.py | 69 ++++++--- tests/libs/test_parquet_writer.py | 86 ++++------- tests/load/pipeline/test_arrow_loading.py | 42 +++++- tests/pipeline/test_arrow_sources.py | 169 ++++++++++++++++++++-- 4 files changed, 273 insertions(+), 93 deletions(-) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index c1fbfbff85..cb19c8c00a 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,6 +1,18 @@ from datetime import datetime, date # noqa: I251 from pendulum.tz import UTC -from typing import Any, Tuple, Optional, Union, Callable, Iterable, Iterator, Sequence, Tuple +from typing import ( + Any, + Dict, + Mapping, + Tuple, + Optional, + Union, + Callable, + Iterable, + Iterator, + Sequence, + Tuple, +) from dlt import version from dlt.common import pendulum @@ -23,7 +35,6 @@ "Install pyarrow to be allow to load arrow tables, panda frames and to use parquet files.", ) - TAnyArrowItem = Union[pyarrow.Table, pyarrow.RecordBatch] @@ -205,19 +216,12 @@ def rename_columns(item: TAnyArrowItem, new_column_names: Sequence[str]) -> TAny raise TypeError(f"Unsupported data item type {type(item)}") -def normalize_py_arrow_schema( - item: TAnyArrowItem, +def should_normalize_arrow_schema( + schema: pyarrow.Schema, columns: TTableSchemaColumns, naming: NamingConvention, - caps: DestinationCapabilitiesContext, -) -> TAnyArrowItem: - """Normalize arrow `item` schema according to the `columns`. - - 1. arrow schema field names will be normalized according to `naming` - 2. arrows columns will be reordered according to `columns` - 3. empty columns will be inserted if they are missing, types will be generated using `caps` - """ - rename_mapping = get_normalized_arrow_fields_mapping(item, naming) +) -> Tuple[bool, Mapping[str, str], Dict[str, str], TTableSchemaColumns]: + rename_mapping = get_normalized_arrow_fields_mapping(schema, naming) rev_mapping = {v: k for k, v in rename_mapping.items()} dlt_table_prefix = naming.normalize_table_identifier(DLT_NAME_PREFIX) @@ -230,12 +234,31 @@ def normalize_py_arrow_schema( } # check if nothing to rename - if list(rename_mapping.keys()) == list(rename_mapping.values()): - # check if nothing to reorder - if list(rename_mapping.keys())[: len(columns)] == list(columns.keys()): - return item + skip_normalize = ( + list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys()) + ) + return not skip_normalize, rename_mapping, rev_mapping, columns + + +def normalize_py_arrow_item( + item: TAnyArrowItem, + columns: TTableSchemaColumns, + naming: NamingConvention, + caps: DestinationCapabilitiesContext, +) -> TAnyArrowItem: + """Normalize arrow `item` schema according to the `columns`. + 1. arrow schema field names will be normalized according to `naming` + 2. arrows columns will be reordered according to `columns` + 3. empty columns will be inserted if they are missing, types will be generated using `caps` + """ schema = item.schema + should_normalize, rename_mapping, rev_mapping, columns = should_normalize_arrow_schema( + schema, columns, naming + ) + if not should_normalize: + return item + new_fields = [] new_columns = [] @@ -268,10 +291,10 @@ def normalize_py_arrow_schema( return item.__class__.from_arrays(new_columns, schema=pyarrow.schema(new_fields)) -def get_normalized_arrow_fields_mapping(item: TAnyArrowItem, naming: NamingConvention) -> StrStr: +def get_normalized_arrow_fields_mapping(schema: pyarrow.Schema, naming: NamingConvention) -> StrStr: """Normalizes schema field names and returns mapping from original to normalized name. Raises on name clashes""" norm_f = naming.normalize_identifier - name_mapping = {n.name: norm_f(n.name) for n in item.schema} + name_mapping = {n.name: norm_f(n.name) for n in schema} # verify if names uniquely normalize normalized_names = set(name_mapping.values()) if len(name_mapping) != len(normalized_names): @@ -301,17 +324,17 @@ def py_arrow_to_table_schema_columns(schema: pyarrow.Schema) -> TTableSchemaColu return result -def get_row_count(parquet_file: TFileOrPath) -> int: - """Get the number of rows in a parquet file. +def get_parquet_metadata(parquet_file: TFileOrPath) -> Tuple[int, pyarrow.Schema]: + """Gets parquet file metadata (including row count and schema) Args: parquet_file (str): path to parquet file Returns: - int: number of rows + FileMetaData: file metadata """ with pyarrow.parquet.ParquetFile(parquet_file) as reader: - return reader.metadata.num_rows # type: ignore[no-any-return] + return reader.metadata.num_rows, reader.schema_arrow def is_arrow_item(item: Any) -> bool: diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index b1c19114fe..3b4239f2b0 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -3,37 +3,17 @@ import pyarrow.parquet as pq import datetime # noqa: 251 -from dlt.common import pendulum, Decimal +from dlt.common import pendulum, Decimal, json from dlt.common.configuration import inject_section -from dlt.common.data_writers.buffered import BufferedDataWriter from dlt.common.data_writers.writers import ParquetDataWriter from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext from dlt.common.schema.utils import new_column from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.time import ensure_pendulum_date, ensure_pendulum_datetime -from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES -from tests.utils import TEST_STORAGE_ROOT, write_version, autouse_test_storage, preserve_environ - - -def get_writer( - _format: TLoaderFileFormat = "insert_values", - buffer_max_items: int = 10, - file_max_items: int = 10, - file_max_bytes: int = None, - _caps: DestinationCapabilitiesContext = None, -) -> BufferedDataWriter[ParquetDataWriter]: - caps = _caps or DestinationCapabilitiesContext.generic_capabilities() - caps.preferred_loader_file_format = _format - file_template = os.path.join(TEST_STORAGE_ROOT, f"{_format}.%s") - return BufferedDataWriter( - _format, - file_template, - buffer_max_items=buffer_max_items, - _caps=caps, - file_max_items=file_max_items, - file_max_bytes=file_max_bytes, - ) +from tests.common.data_writers.utils import get_writer +from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES_DATETIMES +from tests.utils import write_version, autouse_test_storage, preserve_environ def test_parquet_writer_schema_evolution_with_big_buffer() -> None: @@ -42,7 +22,7 @@ def test_parquet_writer_schema_evolution_with_big_buffer() -> None: c3 = new_column("col3", "text") c4 = new_column("col4", "text") - with get_writer("parquet") as writer: + with get_writer(ParquetDataWriter) as writer: writer.write_data_item( [{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3} ) @@ -65,7 +45,7 @@ def test_parquet_writer_schema_evolution_with_small_buffer() -> None: c3 = new_column("col3", "text") c4 = new_column("col4", "text") - with get_writer("parquet", buffer_max_items=4, file_max_items=50) as writer: + with get_writer(ParquetDataWriter, buffer_max_items=4, file_max_items=50) as writer: for _ in range(0, 20): writer.write_data_item( [{"col1": 1, "col2": 2, "col3": "3"}], {"col1": c1, "col2": c2, "col3": c3} @@ -92,7 +72,7 @@ def test_parquet_writer_json_serialization() -> None: c2 = new_column("col2", "bigint") c3 = new_column("col3", "complex") - with get_writer("parquet") as writer: + with get_writer(ParquetDataWriter) as writer: writer.write_data_item( [{"col1": 1, "col2": 2, "col3": {"hello": "dave"}}], {"col1": c1, "col2": c2, "col3": c3}, @@ -121,16 +101,11 @@ def test_parquet_writer_json_serialization() -> None: def test_parquet_writer_all_data_fields() -> None: - data = dict(TABLE_ROW_ALL_DATA_TYPES) - # fix dates to use pendulum - data["col4"] = ensure_pendulum_datetime(data["col4"]) # type: ignore[arg-type] - data["col10"] = ensure_pendulum_date(data["col10"]) # type: ignore[arg-type] - data["col11"] = pendulum.Time.fromisoformat(data["col11"]) # type: ignore[arg-type] - data["col4_precision"] = ensure_pendulum_datetime(data["col4_precision"]) # type: ignore[arg-type] - data["col11_precision"] = pendulum.Time.fromisoformat(data["col11_precision"]) # type: ignore[arg-type] + data = dict(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) - with get_writer("parquet") as writer: - writer.write_data_item([data], TABLE_UPDATE_COLUMNS_SCHEMA) + # this modifies original `data` + with get_writer(ParquetDataWriter) as writer: + writer.write_data_item([dict(data)], TABLE_UPDATE_COLUMNS_SCHEMA) # We want to test precision for these fields is trimmed to millisecond data["col4_precision"] = data["col4_precision"].replace( # type: ignore[attr-defined] @@ -142,20 +117,23 @@ def test_parquet_writer_all_data_fields() -> None: with open(writer.closed_files[0].file_path, "rb") as f: table = pq.read_table(f) - for key, value in data.items(): - # what we have is pandas Timezone which is naive - actual = table.column(key).to_pylist()[0] - if isinstance(value, datetime.datetime): - actual = ensure_pendulum_datetime(actual) - assert actual == value - - assert table.schema.field("col1_precision").type == pa.int16() - # flavor=spark only writes ns precision timestamp, so this is expected - assert table.schema.field("col4_precision").type == pa.timestamp("ns") - assert table.schema.field("col5_precision").type == pa.string() - assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) - assert table.schema.field("col7_precision").type == pa.binary(19) - assert table.schema.field("col11_precision").type == pa.time32("ms") + + for key, value in data.items(): + # what we have is pandas Timezone which is naive + actual = table.column(key).to_pylist()[0] + if isinstance(value, datetime.datetime): + actual = ensure_pendulum_datetime(actual) + if isinstance(value, dict): + actual = json.loads(actual) + assert actual == value + + assert table.schema.field("col1_precision").type == pa.int16() + # flavor=spark only writes ns precision timestamp, so this is expected + assert table.schema.field("col4_precision").type == pa.timestamp("ns") + assert table.schema.field("col5_precision").type == pa.string() + assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) + assert table.schema.field("col7_precision").type == pa.binary(19) + assert table.schema.field("col11_precision").type == pa.time32("ms") def test_parquet_writer_items_file_rotation() -> None: @@ -163,7 +141,7 @@ def test_parquet_writer_items_file_rotation() -> None: "col1": new_column("col1", "bigint"), } - with get_writer("parquet", file_max_items=10) as writer: + with get_writer(ParquetDataWriter, file_max_items=10) as writer: for i in range(0, 100): writer.write_data_item([{"col1": i}], columns) @@ -178,7 +156,7 @@ def test_parquet_writer_size_file_rotation() -> None: "col1": new_column("col1", "bigint"), } - with get_writer("parquet", file_max_bytes=2**8, buffer_max_items=2) as writer: + with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: for i in range(0, 100): writer.write_data_item([{"col1": i}], columns) @@ -194,7 +172,7 @@ def test_parquet_writer_config() -> None: os.environ["NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE"] = "America/New York" with inject_section(ConfigSectionContext(pipeline_name=None, sections=("normalize",))): - with get_writer("parquet", file_max_bytes=2**8, buffer_max_items=2) as writer: + with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: for i in range(0, 5): writer.write_data_item( [{"col1": i, "col2": pendulum.now()}], @@ -219,7 +197,7 @@ def test_parquet_writer_schema_from_caps() -> None: caps.wei_precision = (156, 78) # will be trimmed to dec256 caps.timestamp_precision = 9 # nanoseconds - with get_writer("parquet", file_max_bytes=2**8, buffer_max_items=2) as writer: + with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: for _ in range(0, 5): writer.write_data_item( [{"col1": Decimal("2617.27"), "col2": pendulum.now(), "col3": Decimal(2**250)}], diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 59cd90c535..2c649c18de 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -14,6 +14,7 @@ from dlt.common.utils import uniq_id from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.load.pipeline.utils import assert_table, assert_query_data, select_data +from tests.pipeline.utils import assert_load_info from tests.utils import preserve_environ from tests.cases import arrow_table_all_data_types, TArrowFormat @@ -26,19 +27,38 @@ ids=lambda x: x.name, ) @pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) -def test_load_item( +def test_load_arrow_item( item_type: Literal["pandas", "table", "record_batch"], destination_config: DestinationTestConfiguration, ) -> None: + # compression must be on for redshift + # os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_ID"] = "True" include_time = destination_config.destination not in ( "athena", "redshift", "databricks", - ) # athena/redshift can't load TIME columns from parquet + "synapse", + ) # athena/redshift can't load TIME columns + include_binary = not ( + destination_config.destination in ("redshift", "databricks") + and destination_config.file_format == "jsonl" + ) + include_decimal = not ( + destination_config.destination == "databricks" and destination_config.file_format == "jsonl" + ) + include_date = not ( + destination_config.destination == "databricks" and destination_config.file_format == "jsonl" + ) + item, records = arrow_table_all_data_types( - item_type, include_json=False, include_time=include_time + item_type, + include_json=False, + include_time=include_time, + include_decimal=include_decimal, + include_binary=include_binary, + include_date=include_date, ) pipeline = destination_config.setup_pipeline("arrow_" + uniq_id()) @@ -47,18 +67,28 @@ def test_load_item( def some_data(): yield item - load_info = pipeline.run(some_data(), loader_file_format=destination_config.file_format) + # use csv for postgres to get native arrow processing + file_format = ( + destination_config.file_format if destination_config.destination != "postgres" else "csv" + ) + + load_info = pipeline.run(some_data(), loader_file_format=file_format) + assert_load_info(load_info) # assert the table types some_table_columns = pipeline.default_schema.get_table("some_data")["columns"] assert some_table_columns["string"]["data_type"] == "text" assert some_table_columns["float"]["data_type"] == "double" assert some_table_columns["int"]["data_type"] == "bigint" assert some_table_columns["datetime"]["data_type"] == "timestamp" - assert some_table_columns["binary"]["data_type"] == "binary" - assert some_table_columns["decimal"]["data_type"] == "decimal" assert some_table_columns["bool"]["data_type"] == "bool" if include_time: assert some_table_columns["time"]["data_type"] == "time" + if include_binary: + assert some_table_columns["binary"]["data_type"] == "binary" + if include_decimal: + assert some_table_columns["decimal"]["data_type"] == "decimal" + if include_date: + assert some_table_columns["date"]["data_type"] == "date" qual_name = pipeline.sql_client().make_qualified_table_name("some_data") rows = [list(row) for row in select_data(pipeline, f"SELECT * FROM {qual_name}")] diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 4991afa002..96159648ea 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -1,8 +1,7 @@ import os +from typing import Any import pytest - import pandas as pd -import numpy as np import os import io import pyarrow as pa @@ -14,7 +13,13 @@ from dlt.pipeline.exceptions import PipelineStepFailed -from tests.cases import arrow_format_from_pandas, arrow_table_all_data_types, TArrowFormat +from tests.cases import ( + arrow_format_from_pandas, + arrow_item_from_table, + arrow_table_all_data_types, + prepare_shuffled_tables, + TArrowFormat, +) from tests.utils import preserve_environ @@ -52,7 +57,7 @@ def some_data(): with norm_storage.extracted_packages.storage.open_file(extract_files[0], "rb") as f: extracted_bytes = f.read() - info = pipeline.normalize() + info = pipeline.normalize(loader_file_format="parquet") assert info.row_counts["some_data"] == len(records) @@ -311,13 +316,157 @@ def some_data(): pipeline.run(item, table_name="some_data").raise_on_failed_jobs() # should be able to load arrow with a new column - # TODO: uncomment when load_id fixed in normalizer - # item, records = arrow_table_all_data_types(item_type, num_rows=200) - # item = item.append_column("static_int", [[0] * 200]) - # pipeline.run(item, table_name="some_data").raise_on_failed_jobs() + item, records = arrow_table_all_data_types(item_type, num_rows=200) + item = item.append_column("static_int", [[0] * 200]) + pipeline.run(item, table_name="some_data").raise_on_failed_jobs() - # schema = pipeline.default_schema - # assert schema.tables['some_data']['columns']['static_int']['data_type'] == 'bigint' + schema = pipeline.default_schema + assert schema.tables["some_data"]["columns"]["static_int"]["data_type"] == "bigint" + + +@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +def test_normalize_reorder_columns_separate_packages(item_type: TArrowFormat) -> None: + os.environ["RESTORE_FROM_DESTINATION"] = "False" + table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + + def _to_item(table: Any) -> Any: + return arrow_item_from_table(table, item_type) + + pipeline_name = "arrow_" + uniq_id() + # all arrows will be written to the same table in the destination + pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") + storage = pipeline._get_normalize_storage() + extract_info = pipeline.extract(_to_item(shuffled_removed_column), table_name="table") + job_file = extract_info.load_packages[0].jobs["new_jobs"][0].file_path + with storage.extracted_packages.storage.open_file(job_file, "rb") as f: + actual_tbl_no_binary = pa.parquet.read_table(f) + # schema must be same + assert actual_tbl_no_binary.schema.names == shuffled_removed_column.schema.names + assert actual_tbl_no_binary.schema.equals(shuffled_removed_column.schema) + # print(pipeline.default_schema.to_pretty_yaml()) + + extract_info = pipeline.extract(_to_item(shuffled_table), table_name="table") + job_file = extract_info.load_packages[0].jobs["new_jobs"][0].file_path + with storage.extracted_packages.storage.open_file(job_file, "rb") as f: + actual_tbl_shuffled = pa.parquet.read_table(f) + # shuffled has additional "binary column" which must be added at the end + shuffled_names = list(shuffled_table.schema.names) + shuffled_names.remove("binary") + shuffled_names.append("binary") + assert actual_tbl_shuffled.schema.names == shuffled_names + + extract_info = pipeline.extract(_to_item(table), table_name="table") + job_file = extract_info.load_packages[0].jobs["new_jobs"][0].file_path + with storage.extracted_packages.storage.open_file(job_file, "rb") as f: + actual_tbl = pa.parquet.read_table(f) + # orig table must be ordered exactly as shuffled table + assert actual_tbl.schema.names == shuffled_names + assert actual_tbl.schema.equals(actual_tbl_shuffled.schema) + + # now normalize everything to parquet + normalize_info = pipeline.normalize(loader_file_format="parquet") + print(normalize_info.asstr(verbosity=2)) + # we should have 3 load packages + assert len(normalize_info.load_packages) == 3 + assert normalize_info.row_counts["table"] == 5432 * 3 + + # load to duckdb + load_info = pipeline.load() + load_info.raise_on_failed_jobs() + + +@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +def test_normalize_reorder_columns_single_package(item_type: TArrowFormat) -> None: + os.environ["RESTORE_FROM_DESTINATION"] = "False" + # we do not want to rotate buffer + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "100000" + table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + + def _to_item(table: Any) -> Any: + return arrow_item_from_table(table, item_type) + + pipeline_name = "arrow_" + uniq_id() + # all arrows will be written to the same table in the destination + pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") + + # extract arrows one by one + extract_info = pipeline.extract( + [_to_item(shuffled_removed_column), _to_item(shuffled_table), _to_item(table)], + table_name="table", + ) + assert len(extract_info.load_packages) == 1 + # there was a schema change (binary column was added) + assert len(extract_info.load_packages[0].jobs["new_jobs"]) == 2 + + normalize_info = pipeline.normalize(loader_file_format="parquet") + assert len(normalize_info.load_packages) == 1 + assert normalize_info.row_counts["table"] == 5432 * 3 + # we have 2 jobs: one was imported and second one had to be normalized + assert len(normalize_info.load_packages[0].jobs["new_jobs"]) == 2 + load_storage = pipeline._get_load_storage() + for new_job in normalize_info.load_packages[0].jobs["new_jobs"]: + # all jobs must have the destination schemas + with load_storage.normalized_packages.storage.open_file(new_job.file_path, "rb") as f: + actual_tbl = pa.parquet.read_table(f) + shuffled_names = list(shuffled_table.schema.names) + # binary must be at the end + shuffled_names.remove("binary") + shuffled_names.append("binary") + assert actual_tbl.schema.names == shuffled_names + + pipeline.load().raise_on_failed_jobs() + + +@pytest.mark.parametrize("item_type", ["table", "pandas", "record_batch"]) +def test_normalize_reorder_columns_single_batch(item_type: TArrowFormat) -> None: + os.environ["RESTORE_FROM_DESTINATION"] = "False" + # we do not want to rotate buffer + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "100000" + table, shuffled_table, shuffled_removed_column = prepare_shuffled_tables() + + def _to_item(table: Any) -> Any: + return arrow_item_from_table(table, item_type) + + pipeline_name = "arrow_" + uniq_id() + # all arrows will be written to the same table in the destination + pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") + + # extract arrows in a single batch. this should unify the schema and generate just a single file + # that can be directly imported + extract_info = pipeline.extract( + [[_to_item(shuffled_removed_column), _to_item(shuffled_table), _to_item(table)]], + table_name="table", + ) + assert len(extract_info.load_packages) == 1 + # all arrow tables got normalized to the same schema so no rotation + assert len(extract_info.load_packages[0].jobs["new_jobs"]) == 1 + + shuffled_names = list(shuffled_table.schema.names) + # binary must be at the end + shuffled_names.remove("binary") + shuffled_names.append("binary") + + storage = pipeline._get_normalize_storage() + job_file = extract_info.load_packages[0].jobs["new_jobs"][0].file_path + with storage.extracted_packages.storage.open_file(job_file, "rb") as f: + actual_tbl = pa.parquet.read_table(f) + # must be exactly shuffled_schema like in all other cases + assert actual_tbl.schema.names == shuffled_names + + normalize_info = pipeline.normalize(loader_file_format="parquet") + assert len(normalize_info.load_packages) == 1 + assert normalize_info.row_counts["table"] == 5432 * 3 + # one job below that was imported without normalization + assert len(normalize_info.load_packages[0].jobs["new_jobs"]) == 1 + load_storage = pipeline._get_load_storage() + for new_job in normalize_info.load_packages[0].jobs["new_jobs"]: + # all jobs must have the destination schemas + with load_storage.normalized_packages.storage.open_file(new_job.file_path, "rb") as f: + actual_tbl = pa.parquet.read_table(f) + assert len(actual_tbl) == 5432 * 3 + assert actual_tbl.schema.names == shuffled_names + + pipeline.load().raise_on_failed_jobs() @pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) From 45ed7ffaf34750baba2aaaaa0df9d7e8bdf913b4 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 23:04:32 +0200 Subject: [PATCH 06/22] refactors extractors in extract, disables schema caches when processing multiple arrows --- dlt/extract/extract.py | 31 ++- dlt/extract/extractors.py | 112 +++++------ dlt/extract/storage.py | 49 ++--- dlt/extract/utils.py | 36 ++++ .../data_writers/test_buffered_writer.py | 102 +++++++--- .../data_writers/test_data_item_storage.py | 16 +- tests/extract/test_incremental.py | 188 +++++++++--------- tests/extract/utils.py | 6 +- 8 files changed, 297 insertions(+), 243 deletions(-) diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 2fc4fd77aa..75f22bb802 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -2,14 +2,13 @@ from collections.abc import Sequence as C_Sequence from copy import copy import itertools -from typing import List, Set, Dict, Optional, Set, Any +from typing import List, Dict, Any import yaml from dlt.common.configuration.container import Container from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs import ConfigSectionContext, known_sections -from dlt.common.data_writers import TLoaderFileFormat -from dlt.common.data_writers.writers import EMPTY_DATA_WRITER_METRICS +from dlt.common.data_writers.writers import EMPTY_DATA_WRITER_METRICS, TDataItemFormat from dlt.common.pipeline import ( ExtractDataInfo, ExtractInfo, @@ -38,7 +37,8 @@ from dlt.extract.source import DltSource from dlt.extract.resource import DltResource from dlt.extract.storage import ExtractStorage -from dlt.extract.extractors import JsonLExtractor, ArrowExtractor, Extractor +from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor +from dlt.extract.utils import get_data_item_format def data_to_sources( @@ -244,10 +244,10 @@ def _compute_metrics(self, load_id: str, source: DltSource) -> ExtractMetrics: } def _write_empty_files( - self, source: DltSource, extractors: Dict[TLoaderFileFormat, Extractor] + self, source: DltSource, extractors: Dict[TDataItemFormat, Extractor] ) -> None: schema = source.schema - json_extractor = extractors["puae-jsonl"] + json_extractor = extractors["object"] resources_with_items = set().union(*[e.resources_with_items for e in extractors.values()]) # find REPLACE resources that did not yield any pipe items and create empty jobs for them # NOTE: do not include tables that have never seen data @@ -296,13 +296,14 @@ def _extract_single_source( ) -> None: schema = source.schema collector = self.collector - extractors: Dict[TLoaderFileFormat, Extractor] = { - "puae-jsonl": JsonLExtractor( - load_id, self.extract_storage, schema, collector=collector + extractors: Dict[TDataItemFormat, Extractor] = { + "object": ObjectExtractor( + load_id, self.extract_storage.item_storages["object"], schema, collector=collector + ), + "arrow": ArrowExtractor( + load_id, self.extract_storage.item_storages["arrow"], schema, collector=collector ), - "arrow": ArrowExtractor(load_id, self.extract_storage, schema, collector=collector), } - last_item_format: Optional[TLoaderFileFormat] = None with collector(f"Extract {source.name}"): self._step_info_start_load_id(load_id) @@ -321,16 +322,10 @@ def _extract_single_source( delta = left_gens - curr_gens left_gens -= delta collector.update("Resources", delta) - signals.raise_if_signalled() - resource = source.resources[pipe_item.pipe.name] - # Fallback to last item's format or default (puae-jsonl) if the current item is an empty list - item_format = ( - Extractor.item_format(pipe_item.item) or last_item_format or "puae-jsonl" - ) + item_format = get_data_item_format(pipe_item.item) extractors[item_format].write_items(resource, pipe_item.item, pipe_item.meta) - last_item_format = item_format self._write_empty_files(source, extractors) if left_gens > 0: diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index b8e615aae4..c4b7653164 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -5,7 +5,6 @@ from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import BaseConfiguration, configspec from dlt.common.destination.capabilities import DestinationCapabilitiesContext -from dlt.common.data_writers import TLoaderFileFormat from dlt.common.exceptions import MissingDependencyException from dlt.common.runtime.collector import Collector, NULL_COLLECTOR @@ -22,7 +21,7 @@ from dlt.extract.hints import HintsMeta from dlt.extract.resource import DltResource from dlt.extract.items import TableNameMeta -from dlt.extract.storage import ExtractStorage, ExtractorItemStorage +from dlt.extract.storage import ExtractorItemStorage try: from dlt.common.libs import pyarrow @@ -49,8 +48,6 @@ def materialize_schema_item() -> MaterializedEmptyList: class Extractor: - file_format: TLoaderFileFormat - @configspec class ExtractorConfiguration(BaseConfiguration): _caps: Optional[DestinationCapabilitiesContext] = None @@ -59,7 +56,7 @@ class ExtractorConfiguration(BaseConfiguration): def __init__( self, load_id: str, - storage: ExtractStorage, + item_storage: ExtractorItemStorage, schema: Schema, collector: Collector = NULL_COLLECTOR, *, @@ -73,33 +70,12 @@ def __init__( self.resources_with_empty: Set[str] = set() """Track resources that received empty materialized list""" self.load_id = load_id + self.item_storage = item_storage self._table_contracts: Dict[str, TSchemaContractDict] = {} self._filtered_tables: Set[str] = set() self._filtered_columns: Dict[str, Dict[str, TSchemaEvolutionMode]] = {} - self._storage = storage self._caps = _caps or DestinationCapabilitiesContext.generic_capabilities() - @property - def storage(self) -> ExtractorItemStorage: - return self._storage.get_storage(self.file_format) - - @staticmethod - def item_format(items: TDataItems) -> Optional[TLoaderFileFormat]: - """Detect the loader file format of the data items based on type. - Currently this is either 'arrow' or 'puae-jsonl' - - Returns: - The loader file format or `None` if if can't be detected. - """ - for item in items if isinstance(items, list) else [items]: - # Assume all items in list are the same type - if (pyarrow and pyarrow.is_arrow_item(item)) or ( - pandas and isinstance(item, pandas.DataFrame) - ): - return "arrow" - return "puae-jsonl" - return None # Empty list is unknown format - def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None: """Write `items` to `resource` optionally computing table schemas and revalidating/filtering data""" if isinstance(meta, HintsMeta): @@ -121,7 +97,7 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No def write_empty_items_file(self, table_name: str) -> None: table_name = self.naming.normalize_table_identifier(table_name) - self.storage.write_empty_items_file(self.load_id, self.schema.name, table_name, None) + self.item_storage.write_empty_items_file(self.load_id, self.schema.name, table_name, None) def _get_static_table_name(self, resource: DltResource, meta: Any) -> Optional[str]: if resource._table_name_hint_fun: @@ -142,14 +118,14 @@ def _write_item( items: TDataItems, columns: TTableSchemaColumns = None, ) -> None: - new_rows_count = self.storage.write_data_item( + new_rows_count = self.item_storage.write_data_item( self.load_id, self.schema.name, table_name, items, columns ) self.collector.update(table_name, inc=new_rows_count) if new_rows_count > 0: self.resources_with_items.add(resource_name) else: - if isinstance(items, MaterializedEmptyList): + if isinstance(items, MaterializedEmptyList) or self.__class__ is ArrowExtractor: self.resources_with_empty.add(resource_name) def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> None: @@ -229,12 +205,14 @@ def _reset_contracts_cache(self) -> None: self._filtered_columns.clear() -class JsonLExtractor(Extractor): - file_format = "puae-jsonl" +class ObjectExtractor(Extractor): + """Extracts Python object data items into typed jsonl""" + + pass class ArrowExtractor(Extractor): - file_format = "arrow" + """Extracts arrow data items into parquet""" def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> None: static_table_name = self._get_static_table_name(resource, meta) @@ -256,12 +234,19 @@ def write_items(self, resource: DltResource, items: TDataItems, meta: Any) -> No ] super().write_items(resource, items, meta) + def _write_to_static_table( + self, resource: DltResource, table_name: str, items: TDataItems, meta: Any + ) -> None: + # contract cache not supported for arrow tables + self._reset_contracts_cache() + super()._write_to_static_table(resource, table_name, items, meta) + def _apply_contract_filters( self, item: "TAnyArrowItem", resource: DltResource, static_table_name: Optional[str] ) -> "TAnyArrowItem": """Removes the columns (discard value) or rows (discard rows) as indicated by contract filters.""" # convert arrow schema names into normalized names - rename_mapping = pyarrow.get_normalized_arrow_fields_mapping(item, self.naming) + rename_mapping = pyarrow.get_normalized_arrow_fields_mapping(item.schema, self.naming) # find matching columns and delete by original name table_name = static_table_name or self._get_dynamic_table_name(resource, item) filtered_columns = self._filtered_columns.get(table_name) @@ -301,38 +286,47 @@ def _write_item( columns = columns or self.schema.tables[table_name]["columns"] # Note: `items` is always a list here due to the conversion in `write_table` items = [ - pyarrow.normalize_py_arrow_schema(item, columns, self.naming, self._caps) + pyarrow.normalize_py_arrow_item(item, columns, self.naming, self._caps) for item in items ] + # write items one by one super()._write_item(table_name, resource_name, items, columns) def _compute_table( self, resource: DltResource, items: TDataItems, meta: Any ) -> TPartialTableSchema: - items = items[0] - computed_table = super()._compute_table(resource, items, Any) - - # Merge the columns to include primary_key and other hints that may be set on the resource - arrow_table = copy(computed_table) - arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(items.schema) - # normalize arrow table before merging - arrow_table = self.schema.normalize_table_identifiers(arrow_table) - # issue warnings when overriding computed with arrow - for col_name, column in arrow_table["columns"].items(): - if src_column := computed_table["columns"].get(col_name): - for hint_name, hint in column.items(): - if (src_hint := src_column.get(hint_name)) is not None: - if src_hint != hint: - logger.warning( - f"In resource: {resource.name}, when merging arrow schema on column" - f" {col_name}. The hint {hint_name} value {src_hint} defined in" - f" resource is overwritten from arrow with value {hint}." - ) - - # we must override the columns to preserve the order in arrow table - arrow_table["columns"] = update_dict_nested( - arrow_table["columns"], computed_table["columns"], keep_dst_values=True - ) + arrow_table: TTableSchema = None + + # several arrow tables will update the pipeline schema and we want that earlier + # arrow tables override the latter so the resultant schema is the same as if + # they are sent separately + for item in reversed(items): + computed_table = super()._compute_table(resource, item, Any) + # Merge the columns to include primary_key and other hints that may be set on the resource + if arrow_table: + utils.merge_table(computed_table, arrow_table) + else: + arrow_table = copy(computed_table) + arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema) + # normalize arrow table before merging + arrow_table = self.schema.normalize_table_identifiers(arrow_table) + # issue warnings when overriding computed with arrow + for col_name, column in arrow_table["columns"].items(): + if src_column := computed_table["columns"].get(col_name): + for hint_name, hint in column.items(): + if (src_hint := src_column.get(hint_name)) is not None: + if src_hint != hint: + logger.warning( + f"In resource: {resource.name}, when merging arrow schema on" + f" column {col_name}. The hint {hint_name} value" + f" {src_hint} defined in resource is overwritten from arrow" + f" with value {hint}." + ) + + # we must override the columns to preserve the order in arrow table + arrow_table["columns"] = update_dict_nested( + arrow_table["columns"], computed_table["columns"], keep_dst_values=True + ) return arrow_table diff --git a/dlt/extract/storage.py b/dlt/extract/storage.py index 251d7a5ce9..b76822a4f2 100644 --- a/dlt/extract/storage.py +++ b/dlt/extract/storage.py @@ -1,8 +1,7 @@ import os from typing import Dict, List -from dlt.common.data_writers import TLoaderFileFormat -from dlt.common.data_writers.writers import DataWriterMetrics +from dlt.common.data_writers import TDataItemFormat, DataWriterMetrics, DataWriter, FileWriterSpec from dlt.common.schema import Schema from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.storages import ( @@ -20,11 +19,9 @@ class ExtractorItemStorage(DataItemStorage): - load_file_type: TLoaderFileFormat - - def __init__(self, package_storage: PackageStorage) -> None: + def __init__(self, package_storage: PackageStorage, writer_spec: FileWriterSpec) -> None: """Data item storage using `storage` to manage load packages""" - super().__init__(self.load_file_type) + super().__init__(writer_spec) self.package_storage = package_storage def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: @@ -35,14 +32,6 @@ def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> return self.package_storage.storage.make_full_path(file_path) -class JsonLExtractorStorage(ExtractorItemStorage): - load_file_type: TLoaderFileFormat = "puae-jsonl" - - -class ArrowExtractorStorage(ExtractorItemStorage): - load_file_type: TLoaderFileFormat = "arrow" - - class ExtractStorage(NormalizeStorage): """Wrapper around multiple extractor storages with different file formats""" @@ -55,9 +44,13 @@ def __init__(self, config: NormalizeStorageConfiguration) -> None: self.new_packages = PackageStorage( FileStorage(os.path.join(self.storage.storage_path, self.new_packages_folder)), "new" ) - self._item_storages: Dict[TLoaderFileFormat, ExtractorItemStorage] = { - "puae-jsonl": JsonLExtractorStorage(self.new_packages), - "arrow": ArrowExtractorStorage(self.new_packages), + self.item_storages: Dict[TDataItemFormat, ExtractorItemStorage] = { + "object": ExtractorItemStorage( + self.new_packages, DataWriter.writer_spec_from_file_format("typed-jsonl", "object") + ), + "arrow": ExtractorItemStorage( + self.new_packages, DataWriter.writer_spec_from_file_format("parquet", "arrow") + ), } def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True) -> str: @@ -81,21 +74,18 @@ def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True self.new_packages.save_schema(load_id, schema) return load_id - def get_storage(self, loader_file_format: TLoaderFileFormat) -> ExtractorItemStorage: - return self._item_storages[loader_file_format] - def close_writers(self, load_id: str) -> None: - for storage in self._item_storages.values(): + for storage in self.item_storages.values(): storage.close_writers(load_id) def closed_files(self, load_id: str) -> List[DataWriterMetrics]: files = [] - for storage in self._item_storages.values(): + for storage in self.item_storages.values(): files.extend(storage.closed_files(load_id)) return files def remove_closed_files(self, load_id: str) -> None: - for storage in self._item_storages.values(): + for storage in self.item_storages.values(): storage.remove_closed_files(load_id) def commit_new_load_package(self, load_id: str, schema: Schema) -> None: @@ -117,16 +107,3 @@ def get_load_package_info(self, load_id: str) -> LoadPackageInfo: return self.new_packages.get_load_package_info(load_id) except LoadPackageNotFound: return self.extracted_packages.get_load_package_info(load_id) - - def write_data_item( - self, - file_format: TLoaderFileFormat, - load_id: str, - schema_name: str, - table_name: str, - item: TDataItems, - columns: TTableSchemaColumns, - ) -> None: - self.get_storage(file_format).write_data_item( - load_id, schema_name, table_name, item, columns - ) diff --git a/dlt/extract/utils.py b/dlt/extract/utils.py index 69edcab93d..2024796972 100644 --- a/dlt/extract/utils.py +++ b/dlt/extract/utils.py @@ -19,6 +19,7 @@ from collections.abc import Mapping as C_Mapping from functools import wraps, partial +from dlt.common.data_writers import TDataItemFormat from dlt.common.exceptions import MissingDependencyException from dlt.common.pipeline import reset_resource_state from dlt.common.schema.typing import TColumnNames, TAnySchemaColumns, TTableSchemaColumns @@ -42,6 +43,41 @@ pydantic = None +try: + from dlt.common.libs import pyarrow +except MissingDependencyException: + pyarrow = None + +try: + from dlt.common.libs.pandas import pandas +except MissingDependencyException: + pandas = None + + +def get_data_item_format(items: TDataItems) -> TDataItemFormat: + """Detect the format of the data item from `items`. + + Reverts to `object` for empty lists + + Returns: + The data file format. + """ + if not pyarrow and not pandas: + return "object" + + # Assume all items in list are the same type + try: + if isinstance(items, list): + items = items[0] + if (pyarrow and pyarrow.is_arrow_item(items)) or ( + pandas and isinstance(items, pandas.DataFrame) + ): + return "arrow" + except IndexError: + pass + return "object" + + def resolve_column_value( column_hint: TTableHintTemplate[TColumnNames], item: TDataItem ) -> Union[Any, List[Any]]: diff --git a/tests/extract/data_writers/test_buffered_writer.py b/tests/extract/data_writers/test_buffered_writer.py index aff49e06ac..a1b4be3999 100644 --- a/tests/extract/data_writers/test_buffered_writer.py +++ b/tests/extract/data_writers/test_buffered_writer.py @@ -1,22 +1,28 @@ import os import pytest import time -from typing import Iterator +from typing import Iterator, Type from dlt.common.data_writers.exceptions import BufferedDataWriterClosed -from dlt.common.data_writers.writers import DataWriterMetrics +from dlt.common.data_writers.writers import ( + DataWriter, + DataWriterMetrics, + InsertValuesWriter, + JsonlWriter, + ALL_WRITERS, +) from dlt.common.destination.capabilities import TLoaderFileFormat from dlt.common.schema.utils import new_column from dlt.common.storages.file_storage import FileStorage from dlt.common.typing import DictStrAny -from tests.common.data_writers.utils import ALL_WRITERS, get_writer +from tests.common.data_writers.utils import get_writer, ALL_OBJECT_WRITERS -@pytest.mark.parametrize("format_", ALL_WRITERS) -def test_write_no_item(format_: TLoaderFileFormat) -> None: - with get_writer(_format=format_) as writer: +@pytest.mark.parametrize("writer_type", ALL_WRITERS) +def test_write_no_item(writer_type: Type[DataWriter]) -> None: + with get_writer(writer=writer_type) as writer: pass assert writer.closed with pytest.raises(BufferedDataWriterClosed): @@ -28,7 +34,7 @@ def test_write_no_item(format_: TLoaderFileFormat) -> None: @pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) -def test_rotation_on_schema_change(disable_compression: bool) -> None: +def test_rotation_with_buffer_on_schema_change(disable_compression: bool) -> None: c1 = new_column("col1", "bigint") c2 = new_column("col2", "bigint") c3 = new_column("col3", "text") @@ -47,7 +53,9 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: return map(lambda x: {"col3": "col3_value"}, range(0, count)) # change schema before file first flush - with get_writer(disable_compression=disable_compression) as writer: + with get_writer( + InsertValuesWriter, file_max_items=100, disable_compression=disable_compression + ) as writer: writer.write_data_item(list(c1_doc(8)), t1) assert writer._current_columns == t1 # but different instance @@ -70,7 +78,7 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert "1,0" in content[-1] # data would flush and schema change - with get_writer() as writer: + with get_writer(InsertValuesWriter, file_max_items=100) as writer: writer.write_data_item(list(c1_doc(9)), t1) old_file = writer._file_name writer.write_data_item(list(c2_doc(1)), t2) # rotates here @@ -83,7 +91,7 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert writer._buffered_items == [] # file would rotate and schema change - with get_writer() as writer: + with get_writer(InsertValuesWriter, file_max_items=100) as writer: writer.file_max_items = 10 writer.write_data_item(list(c1_doc(9)), t1) old_file = writer._file_name @@ -97,7 +105,7 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert writer._buffered_items == [] # schema change after flush rotates file - with get_writer() as writer: + with get_writer(InsertValuesWriter, file_max_items=100) as writer: writer.write_data_item(list(c1_doc(11)), t1) writer.write_data_item(list(c2_doc(1)), t2) assert len(writer.closed_files) == 1 @@ -122,6 +130,40 @@ def c3_doc(count: int) -> Iterator[DictStrAny]: assert writer.closed_files[1].items_count == 22 +@pytest.mark.parametrize( + "disable_compression", [False, True], ids=["no_compression", "compression"] +) +def test_rotation_on_schema_change(disable_compression: bool) -> None: + c1 = new_column("col1", "bigint") + c2 = new_column("col2", "bigint") + + t1 = {"col1": c1} + t2 = {"col2": c2, "col1": c1} + + def c1_doc(count: int) -> Iterator[DictStrAny]: + return map(lambda x: {"col1": x}, range(0, count)) + + def c2_doc(count: int) -> Iterator[DictStrAny]: + return map(lambda x: {"col1": x, "col2": x * 2 + 1}, range(0, count)) + + # change schema before file first flush + with get_writer( + writer=JsonlWriter, file_max_items=100, disable_compression=disable_compression + ) as writer: + # mock spec + writer._supports_schema_changes = writer.writer_spec.supports_schema_changes = "False" + # write 1 doc + writer.write_data_item(list(c1_doc(1)), t1) + # in buffer + assert writer._file is None + assert len(writer._buffered_items) == 1 + writer.write_data_item(list(c2_doc(1)), t2) + # flushed because we force rotation with buffer flush + assert writer._file is None + assert len(writer._buffered_items) == 1 + assert len(writer.closed_files) == 2 + + @pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) @@ -139,7 +181,9 @@ def c2_doc(count: int) -> Iterator[DictStrAny]: return map(lambda x: {"col1": x, "col2": x * 2 + 1}, range(0, count)) # change schema before file first flush - with get_writer(_format="jsonl", disable_compression=disable_compression) as writer: + with get_writer( + writer=JsonlWriter, file_max_items=100, disable_compression=disable_compression + ) as writer: writer.write_data_item(list(c1_doc(15)), t1) # flushed assert writer._file is not None @@ -149,6 +193,7 @@ def c2_doc(count: int) -> Iterator[DictStrAny]: # only the initial 15 items written assert writer._writer.items_count == 15 # all written + assert len(writer.closed_files) == 1 with FileStorage.open_zipsafe_ro(writer.closed_files[-1].file_path, "r", encoding="utf-8") as f: content = f.readlines() assert content[-1] == '{"col1":1,"col2":3}\n' @@ -160,12 +205,12 @@ def c2_doc(count: int) -> Iterator[DictStrAny]: def test_writer_requiring_schema(disable_compression: bool) -> None: # assertion on flushing with pytest.raises(AssertionError): - with get_writer(disable_compression=disable_compression) as writer: + with get_writer(InsertValuesWriter, disable_compression=disable_compression) as writer: writer.write_data_item([{"col1": 1}], None) # just single schema is enough c1 = new_column("col1", "bigint") t1 = {"col1": c1} - with get_writer(disable_compression=disable_compression) as writer: + with get_writer(InsertValuesWriter, disable_compression=disable_compression) as writer: writer.write_data_item([{"col1": 1}], None) writer.write_data_item([{"col1": 1}], t1) @@ -174,7 +219,7 @@ def test_writer_requiring_schema(disable_compression: bool) -> None: "disable_compression", [True, False], ids=["no_compression", "compression"] ) def test_writer_optional_schema(disable_compression: bool) -> None: - with get_writer(_format="jsonl", disable_compression=disable_compression) as writer: + with get_writer(writer=JsonlWriter, disable_compression=disable_compression) as writer: writer.write_data_item([{"col1": 1}], None) writer.write_data_item([{"col1": 1}], None) @@ -182,13 +227,13 @@ def test_writer_optional_schema(disable_compression: bool) -> None: @pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) -@pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) -def test_write_empty_file(disable_compression: bool, format_: TLoaderFileFormat) -> None: +@pytest.mark.parametrize("writer_type", ALL_OBJECT_WRITERS) +def test_write_empty_file(disable_compression: bool, writer_type: Type[DataWriter]) -> None: # just single schema is enough c1 = new_column("col1", "bigint") t1 = {"col1": c1} now = time.time() - with get_writer(format_, disable_compression=disable_compression) as writer: + with get_writer(writer_type, disable_compression=disable_compression) as writer: metrics = writer.write_empty_file(t1) assert len(writer.closed_files) == 1 assert os.path.abspath(metrics.file_path) @@ -200,10 +245,10 @@ def test_write_empty_file(disable_compression: bool, format_: TLoaderFileFormat) assert writer.closed_files[0] == metrics -@pytest.mark.parametrize("format_", ALL_WRITERS) -def test_import_file(format_: TLoaderFileFormat) -> None: +@pytest.mark.parametrize("writer_type", ALL_WRITERS) +def test_import_file(writer_type: Type[DataWriter]) -> None: now = time.time() - with get_writer(format_) as writer: + with get_writer(writer_type) as writer: # won't destroy the original metrics = writer.import_file( "tests/extract/cases/imported.any", DataWriterMetrics("", 1, 231, 0, 0) @@ -220,13 +265,13 @@ def test_import_file(format_: TLoaderFileFormat) -> None: @pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) -@pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) -def test_gather_metrics(disable_compression: bool, format_: TLoaderFileFormat) -> None: +@pytest.mark.parametrize("writer_type", ALL_OBJECT_WRITERS) +def test_gather_metrics(disable_compression: bool, writer_type: Type[DataWriter]) -> None: now = time.time() c1 = new_column("col1", "bigint") t1 = {"col1": c1} with get_writer( - format_, disable_compression=disable_compression, buffer_max_items=2, file_max_items=2 + writer_type, disable_compression=disable_compression, buffer_max_items=2, file_max_items=2 ) as writer: time.sleep(0.55) count = writer.write_data_item([{"col1": 182812}, {"col1": -1}], t1) @@ -253,12 +298,15 @@ def test_gather_metrics(disable_compression: bool, format_: TLoaderFileFormat) - @pytest.mark.parametrize( "disable_compression", [True, False], ids=["no_compression", "compression"] ) -@pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) -def test_special_write_rotates(disable_compression: bool, format_: TLoaderFileFormat) -> None: +@pytest.mark.parametrize("writer_type", ALL_OBJECT_WRITERS) +def test_special_write_rotates(disable_compression: bool, writer_type: Type[DataWriter]) -> None: c1 = new_column("col1", "bigint") t1 = {"col1": c1} with get_writer( - format_, disable_compression=disable_compression, buffer_max_items=100, file_max_items=100 + writer_type, + disable_compression=disable_compression, + buffer_max_items=100, + file_max_items=100, ) as writer: writer.write_data_item([{"col1": 182812}, {"col1": -1}], t1) assert len(writer.closed_files) == 0 diff --git a/tests/extract/data_writers/test_data_item_storage.py b/tests/extract/data_writers/test_data_item_storage.py index 1e6327a3ba..feda51c229 100644 --- a/tests/extract/data_writers/test_data_item_storage.py +++ b/tests/extract/data_writers/test_data_item_storage.py @@ -1,14 +1,15 @@ import os +from typing import Type import pytest from dlt.common.configuration.container import Container -from dlt.common.data_writers.writers import DataWriterMetrics -from dlt.common.destination.capabilities import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.data_writers.writers import DataWriterMetrics, DataWriter +from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema.utils import new_column -from tests.common.data_writers.utils import ALL_WRITERS from dlt.common.storages.data_item_storage import DataItemStorage from tests.utils import TEST_STORAGE_ROOT +from tests.common.data_writers.utils import ALL_OBJECT_WRITERS class ItemTestStorage(DataItemStorage): @@ -16,12 +17,13 @@ def _get_data_item_path_template(self, load_id: str, schema_name: str, table_nam return os.path.join(TEST_STORAGE_ROOT, f"{load_id}.{schema_name}.{table_name}.%s") -@pytest.mark.parametrize("format_", ALL_WRITERS - {"arrow"}) -def test_write_items(format_: TLoaderFileFormat) -> None: +@pytest.mark.parametrize("writer_type", ALL_OBJECT_WRITERS) +def test_write_items(writer_type: Type[DataWriter]) -> None: + writer_spec = writer_type.writer_spec() with Container().injectable_context( - DestinationCapabilitiesContext.generic_capabilities(format_) + DestinationCapabilitiesContext.generic_capabilities(writer_spec.file_format) ): - item_storage = ItemTestStorage(format_) + item_storage = ItemTestStorage(writer_spec) c1 = new_column("col1", "bigint") t1 = {"col1": c1} count = item_storage.write_data_item( diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index a393706de7..d5775266f2 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -34,13 +34,13 @@ from tests.utils import ( data_item_length, data_to_item_format, - TDataItemFormat, - ALL_DATA_ITEM_FORMATS, + TestDataItemFormat, + ALL_TEST_DATA_ITEM_FORMATS, ) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_single_items_last_value_state_is_updated(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_single_items_last_value_state_is_updated(item_type: TestDataItemFormat) -> None: data = [ {"created_at": 425}, {"created_at": 426}, @@ -57,8 +57,10 @@ def some_data(created_at=dlt.sources.incremental("created_at")): assert s["last_value"] == 426 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_single_items_last_value_state_is_updated_transformer(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_single_items_last_value_state_is_updated_transformer( + item_type: TestDataItemFormat, +) -> None: data = [ {"created_at": 425}, {"created_at": 426}, @@ -76,8 +78,8 @@ def some_data(item, created_at=dlt.sources.incremental("created_at")): assert s["last_value"] == 426 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_batch_items_last_value_state_is_updated(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_batch_items_last_value_state_is_updated(item_type: TestDataItemFormat) -> None: data1 = [{"created_at": i} for i in range(5)] data2 = [{"created_at": i} for i in range(5, 10)] @@ -98,8 +100,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): assert s["last_value"] == 9 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_last_value_access_in_resource(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_last_value_access_in_resource(item_type: TestDataItemFormat) -> None: values = [] data = [{"created_at": i} for i in range(6)] @@ -117,8 +119,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): assert values == [None, 5] -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_unique_keys_are_deduplicated(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_unique_keys_are_deduplicated(item_type: TestDataItemFormat) -> None: data1 = [ {"created_at": 1, "id": "a"}, {"created_at": 2, "id": "b"}, @@ -157,8 +159,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): assert rows == [(1, "a"), (2, "b"), (3, "c"), (3, "d"), (3, "e"), (3, "f"), (4, "g")] -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_unique_rows_by_hash_are_deduplicated(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_unique_rows_by_hash_are_deduplicated(item_type: TestDataItemFormat) -> None: data1 = [ {"created_at": 1, "id": "a"}, {"created_at": 2, "id": "b"}, @@ -212,7 +214,7 @@ def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")): @pytest.mark.parametrize("item_type", ["arrow", "pandas"]) -def test_nested_cursor_path_arrow_fails(item_type: TDataItemFormat) -> None: +def test_nested_cursor_path_arrow_fails(item_type: TestDataItemFormat) -> None: data = [{"data": {"items": [{"created_at": 2}]}}] source_items = data_to_item_format(item_type, data) @@ -229,8 +231,8 @@ def some_data(created_at=dlt.sources.incremental("data.items[0].created_at")): assert ex.exception.json_path == "data.items[0].created_at" -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_explicit_initial_value(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_explicit_initial_value(item_type: TestDataItemFormat) -> None: @dlt.resource def some_data(created_at=dlt.sources.incremental("created_at")): data = [{"created_at": created_at.last_value}] @@ -245,8 +247,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): assert s["last_value"] == 4242 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_explicit_incremental_instance(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_explicit_incremental_instance(item_type: TestDataItemFormat) -> None: data = [{"inserted_at": 242, "some_uq": 444}] source_items = data_to_item_format(item_type, data) @@ -263,7 +265,7 @@ def some_data(incremental=dlt.sources.incremental("created_at", initial_value=0) @dlt.resource def some_data_from_config( call_no: int, - item_type: TDataItemFormat, + item_type: TestDataItemFormat, created_at: Optional[dlt.sources.incremental[str]] = dlt.secrets.value, ): assert created_at.cursor_path == "created_at" @@ -279,8 +281,8 @@ def some_data_from_config( yield from source_items -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_optional_incremental_from_config(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_optional_incremental_from_config(item_type: TestDataItemFormat) -> None: os.environ["SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__CURSOR_PATH"] = ( "created_at" ) @@ -293,8 +295,8 @@ def test_optional_incremental_from_config(item_type: TDataItemFormat) -> None: p.extract(some_data_from_config(2, item_type)) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_optional_incremental_not_passed(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_optional_incremental_not_passed(item_type: TestDataItemFormat) -> None: """Resource still runs when no incremental is passed""" data = [1, 2, 3] source_items = data_to_item_format(item_type, data) @@ -314,7 +316,7 @@ class OptionalIncrementalConfig(BaseConfiguration): @dlt.resource(spec=OptionalIncrementalConfig) def optional_incremental_arg_resource( - item_type: TDataItemFormat, incremental: Optional[dlt.sources.incremental[Any]] = None + item_type: TestDataItemFormat, incremental: Optional[dlt.sources.incremental[Any]] = None ) -> Any: data = [1, 2, 3] source_items = data_to_item_format(item_type, data) @@ -322,8 +324,8 @@ def optional_incremental_arg_resource( yield source_items -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_optional_arg_from_spec_not_passed(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_optional_arg_from_spec_not_passed(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(optional_incremental_arg_resource(item_type)) @@ -336,7 +338,7 @@ class SomeDataOverrideConfiguration(BaseConfiguration): # provide what to inject via spec. the spec contain the default @dlt.resource(spec=SomeDataOverrideConfiguration) def some_data_override_config( - item_type: TDataItemFormat, created_at: dlt.sources.incremental[str] = dlt.config.value + item_type: TestDataItemFormat, created_at: dlt.sources.incremental[str] = dlt.config.value ): assert created_at.cursor_path == "created_at" assert created_at.initial_value == "2000-02-03T00:00:00Z" @@ -345,8 +347,8 @@ def some_data_override_config( yield from source_items -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_override_initial_value_from_config(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_override_initial_value_from_config(item_type: TestDataItemFormat) -> None: # use the shortest possible config version # os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_OVERRIDE_CONFIG__CREATED_AT__INITIAL_VALUE'] = '2000-02-03T00:00:00Z' os.environ["CREATED_AT__INITIAL_VALUE"] = "2000-02-03T00:00:00Z" @@ -355,8 +357,8 @@ def test_override_initial_value_from_config(item_type: TDataItemFormat) -> None: p.extract(some_data_override_config(item_type)) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_override_primary_key_in_pipeline(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_override_primary_key_in_pipeline(item_type: TestDataItemFormat) -> None: """Primary key hint passed to pipeline is propagated through apply_hints""" data = [{"created_at": 22, "id": 2, "other_id": 5}, {"created_at": 22, "id": 2, "other_id": 6}] source_items = data_to_item_format(item_type, data) @@ -372,8 +374,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): p.extract(some_data, primary_key=["id", "other_id"]) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_composite_primary_key(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_composite_primary_key(item_type: TestDataItemFormat) -> None: data = [ {"created_at": 1, "isrc": "AAA", "market": "DE"}, {"created_at": 2, "isrc": "BBB", "market": "DE"}, @@ -412,8 +414,8 @@ def some_data(created_at=dlt.sources.incremental("created_at")): assert set(rows) == expected -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_last_value_func_min(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_last_value_func_min(item_type: TestDataItemFormat) -> None: data = [ {"created_at": 10}, {"created_at": 11}, @@ -456,8 +458,8 @@ def some_data(created_at=dlt.sources.incremental("created_at", last_value_func=l assert s["last_value"] == 11 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_cursor_datetime_type(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_datetime_type(item_type: TestDataItemFormat) -> None: initial_value = pendulum.now() data = [ {"created_at": initial_value + timedelta(minutes=1)}, @@ -482,8 +484,8 @@ def some_data(created_at=dlt.sources.incremental("created_at", initial_value)): assert s["last_value"] == initial_value + timedelta(minutes=4) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_descending_order_unique_hashes(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_descending_order_unique_hashes(item_type: TestDataItemFormat) -> None: """Resource returns items in descending order but using `max` last value function. Only hash matching last_value are stored. """ @@ -509,8 +511,8 @@ def some_data(created_at=dlt.sources.incremental("created_at", 20)): assert list(some_data()) == [] -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_unique_keys_json_identifiers(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_unique_keys_json_identifiers(item_type: TestDataItemFormat) -> None: """Uses primary key name that is matching the name of the JSON element in the original namespace but gets converted into destination namespace""" @dlt.resource(primary_key="DelTa") @@ -542,8 +544,8 @@ def some_data(last_timestamp=dlt.sources.incremental("ts")): assert rows2[-1][0] == 9 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_missing_primary_key(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_missing_primary_key(item_type: TestDataItemFormat) -> None: @dlt.resource(primary_key="DELTA") def some_data(last_timestamp=dlt.sources.incremental("ts")): data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] @@ -555,8 +557,8 @@ def some_data(last_timestamp=dlt.sources.incremental("ts")): assert py_ex.value.primary_key_column == "DELTA" -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_missing_cursor_field(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_missing_cursor_field(item_type: TestDataItemFormat) -> None: os.environ["COMPLETED_PROB"] = "1.0" # make it complete immediately @dlt.resource @@ -637,13 +639,13 @@ def some_data( assert list(some_data(last_timestamp=dlt.sources.incremental.EMPTY)) == [1] -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_filter_processed_items(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_filter_processed_items(item_type: TestDataItemFormat) -> None: """Checks if already processed items are filtered out""" @dlt.resource def standalone_some_data( - item_type: TDataItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp") + item_type: TestDataItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp") ): data = [ {"delta": i, "timestamp": (now or pendulum.now()).add(days=i).timestamp()} @@ -706,8 +708,8 @@ def some_data(step, last_timestamp=dlt.sources.incremental("ts")): p.run(r, destination="duckdb") -@pytest.mark.parametrize("item_type", set(ALL_DATA_ITEM_FORMATS) - {"json"}) -def test_start_value_set_to_last_value_arrow(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", set(ALL_TEST_DATA_ITEM_FORMATS) - {"json"}) +def test_start_value_set_to_last_value_arrow(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") now = pendulum.now() @@ -733,14 +735,14 @@ def some_data(first: bool, last_timestamp=dlt.sources.incremental("ts")): p.run(some_data(False)) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_replace_resets_state(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_replace_resets_state(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") now = pendulum.now() @dlt.resource def standalone_some_data( - item_type: TDataItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp") + item_type: TestDataItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp") ): data = [ {"delta": i, "timestamp": (now or pendulum.now()).add(days=i).timestamp()} @@ -817,8 +819,8 @@ def child(item): assert extracted[child._pipe.parent.name].write_disposition == "append" -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_incremental_as_transform(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_incremental_as_transform(item_type: TestDataItemFormat) -> None: now = pendulum.now().timestamp() @dlt.resource @@ -841,8 +843,8 @@ def some_data(): assert len(info.loads_ids) == 1 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_incremental_explicit_disable_unique_check(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_incremental_explicit_disable_unique_check(item_type: TestDataItemFormat) -> None: @dlt.resource(primary_key="delta") def some_data(last_timestamp=dlt.sources.incremental("ts", primary_key=())): data = [{"delta": i, "ts": pendulum.now().timestamp()} for i in range(-10, 10)] @@ -856,8 +858,8 @@ def some_data(last_timestamp=dlt.sources.incremental("ts", primary_key=())): assert s.state["incremental"]["ts"]["unique_hashes"] == [] -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_apply_hints_incremental(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_apply_hints_incremental(item_type: TestDataItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id()) data = [{"created_at": 1}, {"created_at": 2}, {"created_at": 3}] source_items = data_to_item_format(item_type, data) @@ -974,8 +976,8 @@ def _get_shuffled_events( assert [e for e in all_events if e["type"] == "WatchEvent"] == watch_events -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_timezone_naive_datetime(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_timezone_naive_datetime(item_type: TestDataItemFormat) -> None: """Resource has timezone naive datetime objects, but incremental stored state is converted to tz aware pendulum dates. Can happen when loading e.g. from sql database""" start_dt = datetime.now() @@ -1081,7 +1083,7 @@ def some_data( @dlt.resource def endless_sequence( - item_type: TDataItemFormat, + item_type: TestDataItemFormat, updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( "updated_at", initial_value=1 ), @@ -1093,8 +1095,8 @@ def endless_sequence( yield from source_items -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_chunked_ranges(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_chunked_ranges(item_type: TestDataItemFormat) -> None: """Load chunked ranges with end value along with incremental""" pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") @@ -1146,8 +1148,8 @@ def test_chunked_ranges(item_type: TDataItemFormat) -> None: assert items == expected_range -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_end_value_with_batches(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_end_value_with_batches(item_type: TestDataItemFormat) -> None: """Ensure incremental with end_value works correctly when resource yields lists instead of single items""" @dlt.resource @@ -1195,8 +1197,8 @@ def batched_sequence( assert items == list(range(1, 14)) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_load_with_end_value_does_not_write_state(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_load_with_end_value_does_not_write_state(item_type: TestDataItemFormat) -> None: """When loading chunk with initial/end value range. The resource state is untouched.""" pipeline = dlt.pipeline(pipeline_name="incremental_" + uniq_id(), destination="duckdb") @@ -1209,8 +1211,8 @@ def test_load_with_end_value_does_not_write_state(item_type: TDataItemFormat) -> assert pipeline.state.get("sources") is None -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_end_value_initial_value_errors(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_end_value_initial_value_errors(item_type: TestDataItemFormat) -> None: @dlt.resource def some_data( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at"), @@ -1264,8 +1266,8 @@ def custom_last_value(items): ) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_out_of_range_flags(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_out_of_range_flags(item_type: TestDataItemFormat) -> None: """Test incremental.start_out_of_range / end_out_of_range flags are set when items are filtered out""" @dlt.resource @@ -1341,8 +1343,8 @@ def ascending_single_item( pipeline.extract(ascending_single_item()) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_async_row_order_out_of_range(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_async_row_order_out_of_range(item_type: TestDataItemFormat) -> None: @dlt.resource async def descending( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( @@ -1358,8 +1360,8 @@ async def descending( assert data_item_length(data) == 48 - 10 + 1 # both bounds included -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_parallel_row_order_out_of_range(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_parallel_row_order_out_of_range(item_type: TestDataItemFormat) -> None: """Test automatic generator close for ordered rows""" @dlt.resource(parallelized=True) @@ -1376,8 +1378,8 @@ def descending( assert data_item_length(data) == 48 - 10 + 1 # both bounds included -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_transformer_row_order_out_of_range(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_transformer_row_order_out_of_range(item_type: TestDataItemFormat) -> None: out_of_range = [] @dlt.transformer @@ -1401,8 +1403,8 @@ def descending( assert len(out_of_range) == 3 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_row_order_out_of_range(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_row_order_out_of_range(item_type: TestDataItemFormat) -> None: """Test automatic generator close for ordered rows""" @dlt.resource @@ -1456,14 +1458,14 @@ def ascending_desc( assert data_item_length(data) == 45 - 22 -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) @pytest.mark.parametrize("order", ["random", "desc", "asc"]) @pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) @pytest.mark.parametrize( "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record") ) def test_unique_values_unordered_rows( - item_type: TDataItemFormat, order: str, primary_key: Any, deterministic: bool + item_type: TestDataItemFormat, order: str, primary_key: Any, deterministic: bool ) -> None: @dlt.resource(primary_key=primary_key) def random_ascending_chunks( @@ -1502,13 +1504,13 @@ def random_ascending_chunks( assert pipeline.last_trace.last_normalize_info.row_counts["random_ascending_chunks"] == rows -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) @pytest.mark.parametrize("primary_key", [[], None, "updated_at"]) # [], None, @pytest.mark.parametrize( "deterministic", (True, False), ids=("deterministic-record", "non-deterministic-record") ) def test_carry_unique_hashes( - item_type: TDataItemFormat, primary_key: Any, deterministic: bool + item_type: TestDataItemFormat, primary_key: Any, deterministic: bool ) -> None: # each day extends list of hashes and removes duplicates until the last day @@ -1593,8 +1595,8 @@ def _assert_state(r_: DltResource, day: int, info: NormalizeInfo) -> None: _assert_state(r_, 4, pipeline.last_trace.last_normalize_info) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_get_incremental_value_type(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_get_incremental_value_type(item_type: TestDataItemFormat) -> None: assert dlt.sources.incremental("id").get_incremental_value_type() is Any assert dlt.sources.incremental("id", initial_value=0).get_incremental_value_type() is int assert dlt.sources.incremental("id", initial_value=None).get_incremental_value_type() is Any @@ -1669,8 +1671,8 @@ def test_type_5( assert r.incremental._incremental.get_incremental_value_type() is Any -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_join_env_scheduler(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_join_env_scheduler(item_type: TestDataItemFormat) -> None: @dlt.resource def test_type_2( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( @@ -1696,8 +1698,8 @@ def test_type_2( assert data_item_to_list(item_type, result) == [{"updated_at": 2}] -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_join_env_scheduler_pipeline(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_join_env_scheduler_pipeline(item_type: TestDataItemFormat) -> None: @dlt.resource def test_type_2( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental( @@ -1729,8 +1731,8 @@ def test_type_2( pipeline.extract(r) -@pytest.mark.parametrize("item_type", ALL_DATA_ITEM_FORMATS) -def test_allow_external_schedulers(item_type: TDataItemFormat) -> None: +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_allow_external_schedulers(item_type: TestDataItemFormat) -> None: @dlt.resource() def test_type_2( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at"), diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 170781ba3c..5239c38de3 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -8,7 +8,7 @@ from dlt.extract.extract import ExtractStorage from dlt.extract.items import ItemTransform -from tests.utils import TDataItemFormat +from tests.utils import TestDataItemFormat def expect_extracted_file( @@ -46,7 +46,7 @@ def expect_extracted_file( class AssertItems(ItemTransform[TDataItem]): - def __init__(self, expected_items: Any, item_type: TDataItemFormat = "json") -> None: + def __init__(self, expected_items: Any, item_type: TestDataItemFormat = "json") -> None: self.expected_items = expected_items self.item_type = item_type @@ -55,7 +55,7 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: return item -def data_item_to_list(from_type: TDataItemFormat, values: List[TDataItem]): +def data_item_to_list(from_type: TestDataItemFormat, values: List[TDataItem]): if from_type in ["arrow", "arrow-batch"]: return values[0].to_pylist() elif from_type == "pandas": From fae59cdf5806634f2422f4fea9e8f399173ec09a Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 23:05:16 +0200 Subject: [PATCH 07/22] refactors item normalizers, adds arrow normalization, improves logging --- dlt/normalize/items_normalizers.py | 88 +++++++++++++++++++++--------- dlt/normalize/normalize.py | 86 ++++++++++++++--------------- tests/normalize/test_normalize.py | 77 +++++++++++++++++++------- 3 files changed, 164 insertions(+), 87 deletions(-) diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index fc1e152ff2..bf4073ddbf 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -3,6 +3,7 @@ from dlt.common import json, logger from dlt.common.data_writers import DataWriterMetrics +from dlt.common.data_writers.writers import ArrowToObjectAdapter from dlt.common.json import custom_pua_decode, may_have_pua from dlt.common.runtime import signals from dlt.common.schema.typing import TSchemaEvolutionMode, TTableSchemaColumns, TSchemaContractDict @@ -11,6 +12,7 @@ NormalizeStorage, LoadStorage, ) +from dlt.common.storages.data_item_storage import DataItemStorage from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import DictStrAny, TDataItem from dlt.common.schema import TSchemaUpdate, Schema @@ -30,13 +32,13 @@ class ItemsNormalizer: def __init__( self, - load_storage: LoadStorage, + item_storage: DataItemStorage, normalize_storage: NormalizeStorage, schema: Schema, load_id: str, config: NormalizeConfiguration, ) -> None: - self.load_storage = load_storage + self.item_storage = item_storage self.normalize_storage = normalize_storage self.schema = schema self.load_id = load_id @@ -49,13 +51,13 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch class JsonLItemsNormalizer(ItemsNormalizer): def __init__( self, - load_storage: LoadStorage, + item_storage: DataItemStorage, normalize_storage: NormalizeStorage, schema: Schema, load_id: str, config: NormalizeConfiguration, ) -> None: - super().__init__(load_storage, normalize_storage, schema, load_id, config) + super().__init__(item_storage, normalize_storage, schema, load_id, config) self._table_contracts: Dict[str, TSchemaContractDict] = {} self._filtered_tables: Set[str] = set() self._filtered_tables_columns: Dict[str, Dict[str, TSchemaEvolutionMode]] = {} @@ -174,7 +176,7 @@ def _normalize_chunk( # will be useful if we implement bad data sending to a table # we skip write when discovering schema for empty file if not skip_write: - self.load_storage.write_data_item( + self.item_storage.write_data_item( self.load_id, schema_name, table_name, row, columns ) except StopIteration: @@ -211,7 +213,7 @@ def __call__( root_table_name, [{}], False, skip_write=True ) schema_updates.append(partial_update) - self.load_storage.write_empty_items_file( + self.item_storage.write_empty_items_file( self.load_id, self.schema.name, root_table_name, @@ -224,7 +226,7 @@ def __call__( return schema_updates -class ParquetItemsNormalizer(ItemsNormalizer): +class ArrowItemsNormalizer(ItemsNormalizer): REWRITE_ROW_GROUPS = 1 def _write_with_dlt_columns( @@ -279,7 +281,10 @@ def _write_with_dlt_columns( ) items_count = 0 - as_py = self.load_storage.loader_file_format != "arrow" + columns_schema = schema.get_table_columns(root_table_name) + # if we use adapter to convert arrow to dicts, then normalization is not necessary + may_normalize = not issubclass(self.item_storage.writer_cls, ArrowToObjectAdapter) + should_normalize: bool = None with self.normalize_storage.extracted_packages.storage.open_file( extracted_items_file, "rb" ) as f: @@ -287,22 +292,35 @@ def _write_with_dlt_columns( f, new_columns, row_groups_per_read=self.REWRITE_ROW_GROUPS ): items_count += batch.num_rows - if as_py: - # Write python rows to jsonl, insert-values, etc... storage - batch = batch.to_pylist() - self.load_storage.write_data_item( + # we may need to normalize + if may_normalize and should_normalize is None: + should_normalize, _, _, _ = pyarrow.should_normalize_arrow_schema( + batch.schema, columns_schema, schema.naming + ) + if should_normalize: + logger.info( + f"When writing arrow table to {root_table_name} the schema requires" + " normalization because its shape does not match the actual schema of" + " destination table. Arrow table columns will be reordered and missing" + " columns will be added if needed." + ) + if should_normalize: + batch = pyarrow.normalize_py_arrow_item( + batch, columns_schema, schema.naming, self.config.destination_capabilities + ) + self.item_storage.write_data_item( load_id, schema.name, root_table_name, batch, - schema.get_table_columns(root_table_name), + columns_schema, ) if items_count == 0: - self.load_storage.write_empty_items_file( + self.item_storage.write_empty_items_file( load_id, schema.name, root_table_name, - self.schema.get_table_columns(root_table_name), + columns_schema, ) return [schema_update] @@ -328,24 +346,44 @@ def _fix_schema_precisions(self, root_table_name: str) -> List[TSchemaUpdate]: def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSchemaUpdate]: base_schema_update = self._fix_schema_precisions(root_table_name) + # read schema and counts from file metadata + from dlt.common.libs.pyarrow import get_parquet_metadata + + with self.normalize_storage.extracted_packages.storage.open_file( + extracted_items_file, "rb" + ) as f: + num_rows, arrow_schema = get_parquet_metadata(f) + file_metrics = DataWriterMetrics(extracted_items_file, num_rows, f.tell(), 0, 0) + add_dlt_id = self.config.parquet_normalizer.add_dlt_id add_dlt_load_id = self.config.parquet_normalizer.add_dlt_load_id - - if add_dlt_id or add_dlt_load_id or self.load_storage.loader_file_format != "arrow": + # if we need to add any columns or the file format is not parquet, we can't just import files + must_rewrite = ( + add_dlt_id or add_dlt_load_id or self.item_storage.writer_spec.file_format != "parquet" + ) + if not must_rewrite: + # in rare cases normalization may be needed + must_rewrite, _, _, _ = pyarrow.should_normalize_arrow_schema( + arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming + ) + if must_rewrite: + logger.info( + f"Table {root_table_name} parquet file {extracted_items_file} must be rewritten:" + f" add_dlt_id: {add_dlt_id} add_dlt_load_id: {add_dlt_load_id} destination file" + f" format: {self.item_storage.writer_spec.file_format} or due to required" + " normalization " + ) schema_update = self._write_with_dlt_columns( extracted_items_file, root_table_name, add_dlt_load_id, add_dlt_id ) return base_schema_update + schema_update - from dlt.common.libs.pyarrow import get_row_count - - with self.normalize_storage.extracted_packages.storage.open_file( - extracted_items_file, "rb" - ) as f: - file_metrics = DataWriterMetrics(extracted_items_file, get_row_count(f), f.tell(), 0, 0) - + logger.info( + f"Table {root_table_name} parquet file {extracted_items_file} will be directly imported" + " without normalization" + ) parts = ParsedLoadJobFileName.parse(extracted_items_file) - self.load_storage.import_items_file( + self.item_storage.import_items_file( self.load_id, self.schema.name, parts.table_name, diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 4a17b9eef8..28c2c81571 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -1,5 +1,4 @@ import os -import datetime # noqa: 251 import itertools from typing import Callable, List, Dict, NamedTuple, Sequence, Tuple, Set, Optional from concurrent.futures import Future, Executor @@ -8,9 +7,8 @@ from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.configuration.container import Container -from dlt.common.data_writers import DataWriterMetrics +from dlt.common.data_writers import DataWriter, DataWriterMetrics, TDataItemFormat from dlt.common.data_writers.writers import EMPTY_DATA_WRITER_METRICS -from dlt.common.destination import TLoaderFileFormat from dlt.common.runners import TRunMetrics, Runnable, NullExecutor from dlt.common.runtime import signals from dlt.common.runtime.collector import Collector, NULL_COLLECTOR @@ -39,7 +37,7 @@ from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed from dlt.normalize.items_normalizers import ( - ParquetItemsNormalizer, + ArrowItemsNormalizer, JsonLItemsNormalizer, ItemsNormalizer, ) @@ -89,7 +87,6 @@ def create_storages(self) -> None: # normalize saves in preferred format but can read all supported formats self.load_storage = LoadStorage( True, - self.config.destination_capabilities.preferred_loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, config=self.config._load_storage_config, ) @@ -105,42 +102,43 @@ def w_normalize_files( ) -> TWorkerRV: destination_caps = config.destination_capabilities schema_updates: List[TSchemaUpdate] = [] - item_normalizers: Dict[TLoaderFileFormat, ItemsNormalizer] = {} - - def _create_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: - """Creates a load storage for particular file_format""" - # TODO: capabilities.supported_*_formats can be None, it should have defaults - supported_formats = destination_caps.supported_loader_file_formats or [] - if file_format == "parquet": - if file_format in supported_formats: - supported_formats.append( - "arrow" - ) # TODO: Hack to make load storage use the correct writer - file_format = "arrow" - else: - # Use default storage if parquet is not supported to make normalizer fallback to read rows from the file - file_format = ( - destination_caps.preferred_loader_file_format - or destination_caps.preferred_staging_file_format - ) - else: - file_format = ( - destination_caps.preferred_loader_file_format - or destination_caps.preferred_staging_file_format - ) - return LoadStorage(False, file_format, supported_formats, loader_storage_config) + item_normalizers: Dict[TDataItemFormat, ItemsNormalizer] = {} + # Use default storage if parquet is not supported to make normalizer fallback to read rows from the file + preferred_file_format = ( + destination_caps.preferred_loader_file_format + or destination_caps.preferred_staging_file_format + ) + # TODO: capabilities.supported_*_formats can be None, it should have defaults + supported_formats = destination_caps.supported_loader_file_formats or [] # process all files with data items and write to buffered item storage with Container().injectable_context(destination_caps): schema = Schema.from_stored_schema(stored_schema) normalize_storage = NormalizeStorage(False, normalize_storage_config) + load_storage = LoadStorage(False, supported_formats, loader_storage_config) - def _get_items_normalizer(file_format: TLoaderFileFormat) -> ItemsNormalizer: - if file_format in item_normalizers: - return item_normalizers[file_format] - klass = ParquetItemsNormalizer if file_format == "parquet" else JsonLItemsNormalizer - norm = item_normalizers[file_format] = klass( - _create_load_storage(file_format), normalize_storage, schema, load_id, config + def _get_items_normalizer(item_format: TDataItemFormat) -> ItemsNormalizer: + if item_format in item_normalizers: + return item_normalizers[item_format] + item_storage = load_storage.create_item_storage(preferred_file_format, item_format) + if item_storage.writer_spec.file_format != preferred_file_format: + logger.warning( + f"For data items yielded as {item_format} job files in format" + f" {preferred_file_format} cannot be created." + f" {item_storage.writer_spec.file_format} jobs will be used instead." + ) + cls = ArrowItemsNormalizer if item_format == "arrow" else JsonLItemsNormalizer + logger.info( + f"Created items normalizer {cls.__name__} with writer" + f" {item_storage.writer_cls.__name__} for item format {item_format} and file" + f" format {item_storage.writer_spec.file_format}" + ) + norm = item_normalizers[item_format] = cls( + item_storage, + normalize_storage, + schema, + load_id, + config, ) return norm @@ -155,7 +153,9 @@ def _get_items_normalizer(file_format: TLoaderFileFormat) -> ItemsNormalizer: parsed_file_name.table_name ) root_tables.add(root_table_name) - normalizer = _get_items_normalizer(parsed_file_name.file_format) + normalizer = _get_items_normalizer( + DataWriter.item_format_from_file_extension(parsed_file_name.file_format) + ) logger.debug( f"Processing extracted items in {extracted_items_file} in load_id" f" {load_id} with table name {root_table_name} and schema {schema.name}" @@ -168,14 +168,14 @@ def _get_items_normalizer(file_format: TLoaderFileFormat) -> ItemsNormalizer: raise NormalizeJobFailed(load_id, job_id, str(exc)) from exc finally: for normalizer in item_normalizers.values(): - normalizer.load_storage.close_writers(load_id) + normalizer.item_storage.close_writers(load_id) - writer_metrics: List[DataWriterMetrics] = [] - for normalizer in item_normalizers.values(): - norm_metrics = normalizer.load_storage.closed_files(load_id) - writer_metrics.extend(norm_metrics) - logger.info(f"Processed all items in {len(extracted_items_files)} files") - return TWorkerRV(schema_updates, writer_metrics) + writer_metrics: List[DataWriterMetrics] = [] + for normalizer in item_normalizers.values(): + norm_metrics = normalizer.item_storage.closed_files(load_id) + writer_metrics.extend(norm_metrics) + logger.info(f"Processed all items in {len(extracted_items_files)} files") + return TWorkerRV(schema_updates, writer_metrics) def update_table(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None: for schema_update in schema_updates: diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 39a18c5de2..ad31e6240e 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -1,15 +1,12 @@ import pytest from fnmatch import fnmatch from typing import Dict, Iterator, List, Sequence, Tuple - -# from multiprocessing import get_start_method, Pool -# from multiprocessing.dummy import Pool as ThreadPool from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor from dlt.common import json +from dlt.common.destination.capabilities import TLoaderFileFormat from dlt.common.schema.schema import Schema from dlt.common.storages.exceptions import SchemaNotFoundError -from dlt.common.utils import uniq_id from dlt.common.typing import StrAny from dlt.common.data_types import TDataType from dlt.common.storages import NormalizeStorage, LoadStorage, ParsedLoadJobFileName, PackageStorage @@ -152,6 +149,7 @@ def test_normalize_filter_user_event( load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.user_load_v228_1"]) _, load_files = expect_load_package( rasa_normalize.load_storage, + caps.preferred_loader_file_format, load_id, [ "event", @@ -178,7 +176,10 @@ def test_normalize_filter_bot_event( rasa_normalize, ["event.event.bot_load_metadata_2987398237498798"] ) _, load_files = expect_load_package( - rasa_normalize.load_storage, load_id, ["event", "event_bot"] + rasa_normalize.load_storage, + caps.preferred_loader_file_format, + load_id, + ["event", "event_bot"], ) event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_bot"], 0) assert lines == 1 @@ -193,7 +194,10 @@ def test_preserve_slot_complex_value_json_l( ) -> None: load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) _, load_files = expect_load_package( - rasa_normalize.load_storage, load_id, ["event", "event_slot"] + rasa_normalize.load_storage, + caps.preferred_loader_file_format, + load_id, + ["event", "event_slot"], ) event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_slot"], 0) assert lines == 1 @@ -208,7 +212,10 @@ def test_preserve_slot_complex_value_insert( ) -> None: load_id = extract_and_normalize_cases(rasa_normalize, ["event.event.slot_session_metadata_1"]) _, load_files = expect_load_package( - rasa_normalize.load_storage, load_id, ["event", "event_slot"] + rasa_normalize.load_storage, + caps.preferred_loader_file_format, + load_id, + ["event", "event_slot"], ) event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event_slot"], 2) assert lines == 3 @@ -224,7 +231,9 @@ def test_normalize_many_events_insert( rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"] ) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) + _, load_files = expect_load_package( + rasa_normalize.load_storage, caps.preferred_loader_file_format, load_id, expected_tables + ) # return first values line from event_user file event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event"], 4) # 2 lines header + 3 lines data @@ -240,7 +249,9 @@ def test_normalize_many_events( rasa_normalize, ["event.event.many_load_2", "event.event.user_load_1"] ) expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] - _, load_files = expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) + _, load_files = expect_load_package( + rasa_normalize.load_storage, caps.preferred_loader_file_format, load_id, expected_tables + ) # return first values line from event_user file event_text, lines = get_line_from_file(rasa_normalize.load_storage, load_files["event"], 2) # 3 lines data @@ -316,10 +327,19 @@ def test_normalize_many_packages( # expect event tables if schema.name == "event": expected_tables = EXPECTED_USER_TABLES_RASA_NORMALIZER + ["event_bot", "event_action"] - expect_load_package(rasa_normalize.load_storage, load_id, expected_tables) + expect_load_package( + rasa_normalize.load_storage, + caps.preferred_loader_file_format, + load_id, + expected_tables, + ) if schema.name == "ethereum": expect_load_package( - rasa_normalize.load_storage, load_id, EXPECTED_ETH_TABLES, full_schema_update=False + rasa_normalize.load_storage, + caps.preferred_loader_file_format, + load_id, + EXPECTED_ETH_TABLES, + full_schema_update=False, ) assert set(schemas) == set(["ethereum", "event"]) @@ -348,7 +368,9 @@ def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Nor doc = {"str": "text", "int": 1} extract_items(raw_normalize.normalize_storage, [doc], Schema("evolution"), "doc") load_id = normalize_pending(raw_normalize) - _, table_files = expect_load_package(raw_normalize.load_storage, load_id, ["doc"]) + _, table_files = expect_load_package( + raw_normalize.load_storage, caps.preferred_loader_file_format, load_id, ["doc"] + ) get_line_from_file(raw_normalize.load_storage, table_files["doc"], 0) assert len(table_files["doc"]) == 1 schema = raw_normalize.schema_storage.load_schema("evolution") @@ -360,7 +382,9 @@ def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Nor doc2 = {"str": "text", "int": 1, "bool": True} extract_items(raw_normalize.normalize_storage, [doc, doc2, doc], schema, "doc") load_id = normalize_pending(raw_normalize) - _, table_files = expect_load_package(raw_normalize.load_storage, load_id, ["doc"]) + _, table_files = expect_load_package( + raw_normalize.load_storage, caps.preferred_loader_file_format, load_id, ["doc"] + ) assert len(table_files["doc"]) == 1 schema = raw_normalize.schema_storage.load_schema("evolution") doc_table = schema.get_table("doc") @@ -378,7 +402,9 @@ def test_schema_changes(caps: DestinationCapabilitiesContext, raw_normalize: Nor # extract_items(raw_normalize.normalize_storage, [doc3_2v, doc3_doc_v], schema, "doc") load_id = normalize_pending(raw_normalize) - _, table_files = expect_load_package(raw_normalize.load_storage, load_id, ["doc", "doc__comp"]) + _, table_files = expect_load_package( + raw_normalize.load_storage, caps.preferred_loader_file_format, load_id, ["doc", "doc__comp"] + ) assert len(table_files["doc"]) == 1 assert len(table_files["doc__comp"]) == 1 schema = raw_normalize.schema_storage.load_schema("evolution") @@ -405,7 +431,10 @@ def test_normalize_twice_with_flatten( ) -> None: load_id = extract_and_normalize_cases(raw_normalize, ["github.issues.load_page_5_duck"]) _, table_files = expect_load_package( - raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"] + raw_normalize.load_storage, + caps.preferred_loader_file_format, + load_id, + ["issues", "issues__labels", "issues__assignees"], ) assert len(table_files["issues"]) == 1 _, lines = get_line_from_file(raw_normalize.load_storage, table_files["issues"], 0) @@ -425,6 +454,7 @@ def assert_schema(_schema: Schema): load_id = extract_and_normalize_cases(raw_normalize, ["github.issues.load_page_5_duck"]) _, table_files = expect_load_package( raw_normalize.load_storage, + caps.preferred_loader_file_format, load_id, ["issues", "issues__labels", "issues__assignees"], full_schema_update=False, @@ -454,7 +484,10 @@ def test_normalize_retry(raw_normalize: Normalize) -> None: # subsequent run must succeed raw_normalize.run(None) _, table_files = expect_load_package( - raw_normalize.load_storage, load_id, ["issues", "issues__labels", "issues__assignees"] + raw_normalize.load_storage, + raw_normalize.config.destination_capabilities.preferred_loader_file_format, + load_id, + ["issues", "issues__labels", "issues__assignees"], ) assert len(table_files["issues"]) == 1 @@ -530,7 +563,7 @@ def extract_items( ) -> str: extractor = ExtractStorage(normalize_storage.config) load_id = extractor.create_load_package(schema) - extractor.write_data_item("puae-jsonl", load_id, schema.name, table_name, items, None) + extractor.item_storages["object"].write_data_item(load_id, schema.name, table_name, items, None) extractor.close_writers(load_id) extractor.commit_new_load_package(load_id, schema) return load_id @@ -541,7 +574,12 @@ def normalize_event_user( ) -> Tuple[List[str], Dict[str, List[str]]]: expected_user_tables = expected_user_tables or EXPECTED_USER_TABLES_RASA_NORMALIZER load_id = extract_and_normalize_cases(normalize, [case]) - return expect_load_package(normalize.load_storage, load_id, expected_user_tables) + return expect_load_package( + normalize.load_storage, + normalize.config.destination_capabilities.preferred_loader_file_format, + load_id, + expected_user_tables, + ) def extract_and_normalize_cases(normalize: Normalize, cases: Sequence[str]) -> str: @@ -597,6 +635,7 @@ def load_or_create_schema(normalize: Normalize, schema_name: str) -> Schema: def expect_load_package( load_storage: LoadStorage, + file_format: TLoaderFileFormat, load_id: str, expected_tables: Sequence[str], full_schema_update: bool = True, @@ -621,7 +660,7 @@ def expect_load_package( expected_table, "*", validate_components=False, - loader_file_format=load_storage.loader_file_format, + loader_file_format=file_format, ) # files are in normalized//new_jobs file_path = load_storage.normalized_packages.get_job_file_path( From 7b4a0ed4e45ea37a047a842dbb5f0d1fc3f881ae Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 3 Apr 2024 23:06:07 +0200 Subject: [PATCH 08/22] removes internal file formats from loader file formats, renames, tests improvements --- dlt/common/destination/capabilities.py | 15 +--- dlt/common/destination/exceptions.py | 6 +- dlt/common/storages/data_item_storage.py | 38 +++----- dlt/common/storages/exceptions.py | 13 ++- dlt/common/storages/load_package.py | 15 ++-- dlt/common/storages/load_storage.py | 74 +++++++++------ dlt/common/versioned_state.py | 3 +- dlt/destinations/insert_job_client.py | 9 -- dlt/destinations/sql_client.py | 1 - dlt/helpers/airflow_helper.py | 2 +- dlt/load/load.py | 18 ++-- dlt/pipeline/configuration.py | 2 +- dlt/pipeline/helpers.py | 2 +- dlt/pipeline/pipeline.py | 18 ++-- dlt/sources/helpers/rest_client/auth.py | 21 ++--- dlt/sources/helpers/rest_client/client.py | 24 ++--- dlt/sources/helpers/rest_client/detector.py | 7 +- dlt/sources/helpers/rest_client/paginators.py | 4 +- docs/website/docs/reference/performance.md | 5 ++ tests/cases.py | 90 ++++++++++++++++--- tests/common/data_writers/utils.py | 33 ++++--- tests/common/storages/test_load_package.py | 10 +-- tests/common/storages/test_load_storage.py | 8 +- tests/common/storages/utils.py | 37 +++++++- tests/destinations/test_custom_destination.py | 6 +- tests/load/pipeline/test_redshift.py | 2 +- tests/load/test_dummy_client.py | 6 +- tests/load/utils.py | 10 ++- tests/pipeline/test_schema_contracts.py | 14 +-- tests/sources/helpers/rest_client/conftest.py | 16 +--- .../helpers/rest_client/test_client.py | 4 +- .../helpers/rest_client/test_detector.py | 8 +- .../helpers/rest_client/test_paginators.py | 8 +- tests/utils.py | 8 +- 34 files changed, 297 insertions(+), 240 deletions(-) diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 7a64f32ea3..6b06d8287e 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -6,24 +6,15 @@ from dlt.common.utils import identity from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - from dlt.common.wei import EVM_DECIMAL_PRECISION # known loader file formats # jsonl - new line separated json documents -# puae-jsonl - internal extract -> normalize format bases on jsonl +# typed-jsonl - internal extract -> normalize format bases on jsonl # insert_values - insert SQL statements # sql - any sql statement -TLoaderFileFormat = Literal[ - "jsonl", "puae-jsonl", "insert_values", "sql", "parquet", "reference", "arrow" -] +TLoaderFileFormat = Literal["jsonl", "typed-jsonl", "insert_values", "parquet", "csv"] ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) -# file formats used internally by dlt -INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"sql", "reference", "arrow"} -# file formats that may be chosen by the user -EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = ( - set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS -) @configspec @@ -67,7 +58,7 @@ def generic_capabilities( ) -> "DestinationCapabilitiesContext": caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = preferred_loader_file_format - caps.supported_loader_file_formats = ["jsonl", "insert_values", "parquet"] + caps.supported_loader_file_formats = ["jsonl", "insert_values", "parquet", "csv"] caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] caps.escape_identifier = identity diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index 1b5423ff02..cd8f50bcce 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -77,9 +77,9 @@ def __init__( ) else: msg = ( - f"Unsupported file format {file_format} destination {destination}. Supported" - f" formats: {supported_formats_str}. Check the staging option in the dlt.pipeline" - " for additional formats." + f"Unsupported file format {file_format} in destination {destination}. Supported" + f" formats: {supported_formats_str}. If {destination} supports loading data via" + " staging bucket, more formats may be available." ) super().__init__(msg) diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 816c6bc494..5b1e360789 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -3,15 +3,22 @@ from abc import ABC, abstractmethod from dlt.common import logger -from dlt.common.destination import TLoaderFileFormat from dlt.common.schema import TTableSchemaColumns from dlt.common.typing import StrAny, TDataItems -from dlt.common.data_writers import BufferedDataWriter, DataWriter, DataWriterMetrics +from dlt.common.data_writers import ( + BufferedDataWriter, + DataWriter, + DataWriterMetrics, + FileWriterSpec, +) class DataItemStorage(ABC): - def __init__(self, load_file_type: TLoaderFileFormat, *args: Any) -> None: - self.loader_file_format = load_file_type + def __init__(self, writer_spec: FileWriterSpec, *args: Any) -> None: + self.writer_spec = writer_spec + self.writer_cls = DataWriter.class_factory( + writer_spec.file_format, writer_spec.data_item_format + ) self.buffered_writers: Dict[str, BufferedDataWriter[DataWriter]] = {} super().__init__(*args) @@ -24,7 +31,7 @@ def _get_writer( if not writer: # assign a writer for each table path = self._get_data_item_path_template(load_id, schema_name, table_name) - writer = BufferedDataWriter(self.loader_file_format, path) + writer = BufferedDataWriter(self.writer_spec, path) self.buffered_writers[writer_id] = writer return writer @@ -90,27 +97,6 @@ def remove_closed_files(self, load_id: str) -> None: if name.startswith(load_id): writer.closed_files.clear() - def _write_temp_job_file( - self, - load_id: str, - table_name: str, - table: TTableSchemaColumns, - file_id: str, - rows: Sequence[StrAny], - ) -> str: - """Writes new file into new packages "new_jobs". Intended for testing""" - file_name = ( - self._get_data_item_path_template(load_id, None, table_name) % file_id - + "." - + self.loader_file_format - ) - format_spec = DataWriter.data_format_from_file_format(self.loader_file_format) - mode = "wb" if format_spec.is_binary_format else "w" - with self.storage.open_file(file_name, mode=mode) as f: # type: ignore[attr-defined] - writer = DataWriter.from_file_format(self.loader_file_format, f) - writer.write_all(table, rows) - return Path(file_name).name - @abstractmethod def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: """Returns a file template for item writer. note: use %s for file id to create required template format""" diff --git a/dlt/common/storages/exceptions.py b/dlt/common/storages/exceptions.py index f4288719c1..26a76bb5c0 100644 --- a/dlt/common/storages/exceptions.py +++ b/dlt/common/storages/exceptions.py @@ -2,7 +2,6 @@ from typing import Iterable from dlt.common.exceptions import DltException, TerminalValueError -from dlt.common.destination import TLoaderFileFormat class StorageException(DltException): @@ -63,16 +62,14 @@ class LoadStorageException(StorageException): pass -class JobWithUnsupportedWriterException(LoadStorageException, TerminalValueError): - def __init__( - self, load_id: str, expected_file_formats: Iterable[TLoaderFileFormat], wrong_job: str - ) -> None: +class JobFileFormatUnsupported(LoadStorageException, TerminalValueError): + def __init__(self, load_id: str, supported_formats: Iterable[str], wrong_job: str) -> None: self.load_id = load_id - self.expected_file_formats = expected_file_formats + self.expected_file_formats = supported_formats self.wrong_job = wrong_job super().__init__( - f"Job {wrong_job} for load id {load_id} requires loader file format that is not one of" - f" {expected_file_formats}" + f"Job {wrong_job} for load id {load_id} requires job file format that is not one of" + f" {supported_formats}" ) diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 3b8af424ee..3ca5056d8e 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -49,6 +49,9 @@ ) from typing_extensions import NotRequired +TJobFileFormat = Literal["sql", "reference", TLoaderFileFormat] +"""Loader file formats with internal job types""" + class TLoadPackageState(TVersionedState, total=False): created_at: str @@ -67,7 +70,7 @@ class TLoadPackage(TypedDict, total=False): # allows to upgrade state when restored with a new version of state logic/schema -LOADPACKAGE_STATE_ENGINE_VERSION = 1 +LOAD_PACKAGE_STATE_ENGINE_VERSION = 1 def generate_loadpackage_state_version_hash(state: TLoadPackageState) -> str: @@ -97,7 +100,7 @@ def migrate_load_package_state( def default_load_package_state() -> TLoadPackageState: return { **default_versioned_state(), - "_state_engine_version": LOADPACKAGE_STATE_ENGINE_VERSION, + "_state_engine_version": LOAD_PACKAGE_STATE_ENGINE_VERSION, } @@ -116,7 +119,7 @@ class ParsedLoadJobFileName(NamedTuple): table_name: str file_id: str retry_count: int - file_format: TLoaderFileFormat + file_format: TJobFileFormat def job_id(self) -> str: """Unique identifier of the job""" @@ -138,7 +141,7 @@ def parse(file_name: str) -> "ParsedLoadJobFileName": raise TerminalValueError(parts) return ParsedLoadJobFileName( - parts[0], parts[1], int(parts[2]), cast(TLoaderFileFormat, parts[3]) + parts[0], parts[1], int(parts[2]), cast(TJobFileFormat, parts[3]) ) @staticmethod @@ -471,7 +474,7 @@ def get_load_package_state(self, load_id: str) -> TLoadPackageState: state_dump = self.storage.load(self.get_load_package_state_path(load_id)) state = json.loads(state_dump) return migrate_load_package_state( - state, state["_state_engine_version"], LOADPACKAGE_STATE_ENGINE_VERSION + state, state["_state_engine_version"], LOAD_PACKAGE_STATE_ENGINE_VERSION ) except FileNotFoundError: return default_load_package_state() @@ -594,7 +597,7 @@ def build_job_file_name( FileStorage.validate_file_name_component(table_name) fn = f"{table_name}.{file_id}.{int(retry_count)}" if loader_file_format: - format_spec = DataWriter.data_format_from_file_format(loader_file_format) + format_spec = DataWriter.writer_spec_from_file_format(loader_file_format, "object") return fn + f".{format_spec.file_extension}" return fn diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index ffd55e7f29..8b5109d9e2 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -1,13 +1,12 @@ from os.path import join -from typing import Iterable, Optional, Sequence +from typing import Iterable, List, Optional, Sequence -from dlt.common.typing import DictStrAny +from dlt.common.data_writers.exceptions import DataWriterNotFound from dlt.common import json from dlt.common.configuration import known_sections from dlt.common.configuration.inject import with_config from dlt.common.destination import ALL_SUPPORTED_FILE_FORMATS, TLoaderFileFormat from dlt.common.configuration.accessors import config -from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TSchemaTables from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.configuration import LoadStorageConfiguration @@ -20,11 +19,28 @@ ParsedLoadJobFileName, TJobState, TLoadPackageState, + TJobFileFormat, ) -from dlt.common.storages.exceptions import JobWithUnsupportedWriterException, LoadPackageNotFound +from dlt.common.data_writers import DataWriter, FileWriterSpec, TDataItemFormat +from dlt.common.storages.exceptions import JobFileFormatUnsupported, LoadPackageNotFound -class LoadStorage(DataItemStorage, VersionedStorage): +class LoadItemStorage(DataItemStorage): + def __init__(self, package_storage: PackageStorage, writer_spec: FileWriterSpec) -> None: + """Data item storage using `storage` to manage load packages""" + super().__init__(writer_spec) + self.package_storage = package_storage + + def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: + # implements DataItemStorage._get_data_item_path_template + file_name = PackageStorage.build_job_file_name(table_name, "%s") + file_path = self.package_storage.get_job_file_path( + load_id, PackageStorage.NEW_JOBS_FOLDER, file_name + ) + return self.package_storage.storage.make_full_path(file_path) + + +class LoadStorage(VersionedStorage): STORAGE_VERSION = "1.0.0" NORMALIZED_FOLDER = "normalized" # folder within the volume where load packages are stored LOADED_FOLDER = "loaded" # folder to keep the loads that were completely processed @@ -36,23 +52,14 @@ class LoadStorage(DataItemStorage, VersionedStorage): def __init__( self, is_owner: bool, - preferred_file_format: TLoaderFileFormat, supported_file_formats: Iterable[TLoaderFileFormat], config: LoadStorageConfiguration = config.value, ) -> None: - # puae-jsonl jobs have the extension .jsonl, so cater for this here - if supported_file_formats and "puae-jsonl" in supported_file_formats: - supported_file_formats = list(supported_file_formats) - supported_file_formats.append("jsonl") - - if not LoadStorage.ALL_SUPPORTED_FILE_FORMATS.issuperset(supported_file_formats): - raise TerminalValueError(supported_file_formats) - if preferred_file_format and preferred_file_format not in supported_file_formats: - raise TerminalValueError(preferred_file_format) - self.supported_file_formats = supported_file_formats + self.supported_loader_file_formats = supported_file_formats + # store job file formats to add internal job formats as needed + self.supported_job_file_formats: List[TJobFileFormat] = list(supported_file_formats) self.config = config super().__init__( - preferred_file_format, LoadStorage.STORAGE_VERSION, is_owner, FileStorage(config.load_volume_path, "t", makedirs=is_owner), @@ -75,13 +82,28 @@ def initialize_storage(self) -> None: self.storage.create_folder(LoadStorage.NORMALIZED_FOLDER, exists_ok=True) self.storage.create_folder(LoadStorage.LOADED_FOLDER, exists_ok=True) - def _get_data_item_path_template(self, load_id: str, _: str, table_name: str) -> str: - # implements DataItemStorage._get_data_item_path_template - file_name = PackageStorage.build_job_file_name(table_name, "%s") - file_path = self.new_packages.get_job_file_path( - load_id, PackageStorage.NEW_JOBS_FOLDER, file_name - ) - return self.new_packages.storage.make_full_path(file_path) + def create_item_storage( + self, preferred_format: TLoaderFileFormat, item_format: TDataItemFormat + ) -> DataItemStorage: + """Creates item storage for preferred_format + item_format combination. If not found, it + tries the remaining file formats in supported formats. + """ + try: + return LoadItemStorage( + self.new_packages, + DataWriter.writer_spec_from_file_format(preferred_format, item_format), + ) + except DataWriterNotFound: + for supported_format in self.supported_loader_file_formats: + if supported_format != preferred_format: + try: + return LoadItemStorage( + self.new_packages, + DataWriter.writer_spec_from_file_format(supported_format, item_format), + ) + except DataWriterNotFound: + pass + raise def list_new_jobs(self, load_id: str) -> Sequence[str]: """Lists all jobs in new jobs folder of normalized package storage and checks if file formats are supported""" @@ -91,12 +113,12 @@ def list_new_jobs(self, load_id: str) -> Sequence[str]: ( j for j in new_jobs - if ParsedLoadJobFileName.parse(j).file_format not in self.supported_file_formats + if ParsedLoadJobFileName.parse(j).file_format not in self.supported_job_file_formats ), None, ) if wrong_job is not None: - raise JobWithUnsupportedWriterException(load_id, self.supported_file_formats, wrong_job) + raise JobFileFormatUnsupported(load_id, self.supported_job_file_formats, wrong_job) return new_jobs def commit_new_load_package(self, load_id: str) -> None: diff --git a/dlt/common/versioned_state.py b/dlt/common/versioned_state.py index a051a6660c..6f45df83c4 100644 --- a/dlt/common/versioned_state.py +++ b/dlt/common/versioned_state.py @@ -2,9 +2,8 @@ import hashlib from copy import copy -import datetime # noqa: 251 from dlt.common import json -from typing import TypedDict, Dict, Any, List, Tuple, cast +from typing import TypedDict, List, Tuple class TVersionedState(TypedDict, total=False): diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 776176078e..066855b894 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -124,12 +124,3 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if file_path.endswith("insert_values"): job = InsertValuesLoadJob(table["name"], file_path, self.sql_client) return job - - # # TODO: implement indexes and primary keys for postgres - # def _get_in_table_constraints_sql(self, t: TTableSchema) -> str: - # # get primary key - # pass - - # def _get_out_table_constrains_sql(self, t: TTableSchema) -> str: - # # set non unique indexes - # pass diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 9d872a238e..9b73d7d28c 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -24,7 +24,6 @@ from dlt.destinations.exceptions import ( DestinationConnectionError, LoadClientNotConnected, - DatabaseTerminalException, ) from dlt.destinations.typing import DBApi, TNativeConn, DBApiCursor, DataFrame, DBTransaction diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index e01cf790d2..6677475499 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -27,7 +27,7 @@ from dlt.common import pendulum from dlt.common import logger from dlt.common.runtime.telemetry import with_telemetry -from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.destination import TLoaderFileFormat from dlt.common.schema.typing import TWriteDisposition, TSchemaContract from dlt.common.utils import uniq_id from dlt.common.normalizers.naming.snake_case import NamingConvention as SnakeCaseNamingConvention diff --git a/dlt/load/load.py b/dlt/load/load.py index f02a21f98e..b1f786274e 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -86,16 +86,18 @@ def create_storage(self, is_storage_owner: bool) -> LoadStorage: if self.staging_destination: supported_file_formats = ( self.staging_destination.capabilities().supported_loader_file_formats - + ["reference"] ) - if isinstance(self.get_destination_client(Schema("test")), WithStagingDataset): - supported_file_formats += ["sql"] load_storage = LoadStorage( is_storage_owner, - self.capabilities.preferred_loader_file_format, supported_file_formats, config=self.config._load_storage_config, ) + # add internal job formats + if issubclass(self.destination.client_class, WithStagingDataset): + load_storage.supported_job_file_formats += ["sql"] + if self.staging_destination: + load_storage.supported_job_file_formats += ["reference"] + return load_storage def get_destination_client(self, schema: Schema) -> JobClientBase: @@ -139,7 +141,7 @@ def w_spool_job( else job_client ) as client: job_info = ParsedLoadJobFileName.parse(file_path) - if job_info.file_format not in self.load_storage.supported_file_formats: + if job_info.file_format not in self.load_storage.supported_job_file_formats: raise LoadClientUnsupportedFileFormats( job_info.file_format, self.capabilities.supported_loader_file_formats, @@ -177,6 +179,12 @@ def w_spool_job( # return no job so file stays in new jobs (root) folder logger.exception(f"Temporary problem when adding job {file_path}") job = EmptyLoadJob.from_file_path(file_path, "retry", pretty_format_exception()) + if job is None: + raise DestinationTerminalException( + f"Destination could not create a job for file {file_path}. Typically the file" + " extension could not be associated with job type and that indicates an error in" + " the code." + ) self.load_storage.normalized_packages.start_job(load_id, job.file_name()) return job diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index d7ffca6e89..8c46ed049f 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -4,7 +4,7 @@ from dlt.common.configuration.specs import RunConfiguration, BaseConfiguration from dlt.common.typing import AnyFun, TSecretValue from dlt.common.utils import digest256 -from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.destination import TLoaderFileFormat @configspec diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index c242a26eaa..c1c3326171 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -1,5 +1,5 @@ import contextlib -from typing import Callable, Sequence, Iterable, Optional, Any, List, Dict, Tuple, Union, TypedDict +from typing import Callable, Sequence, Iterable, Optional, Any, List, Dict, Union, TypedDict from itertools import chain from dlt.common.jsonpath import resolve_paths, TAnyJsonPath, compile_paths diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index de1f7afced..b0d04dfbe8 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -48,7 +48,7 @@ ) from dlt.common.schema.utils import normalize_schema_name from dlt.common.storages.exceptions import LoadPackageNotFound -from dlt.common.typing import DictStrAny, TFun, TSecretValue, is_optional_type +from dlt.common.typing import TFun, TSecretValue, is_optional_type from dlt.common.runners import pool_runner as runner from dlt.common.storages import ( LiveSchemaStorage, @@ -63,7 +63,12 @@ LoadJobInfo, LoadPackageInfo, ) -from dlt.common.destination import DestinationCapabilitiesContext, TDestination +from dlt.common.destination import ( + DestinationCapabilitiesContext, + TDestination, + ALL_SUPPORTED_FILE_FORMATS, + TLoaderFileFormat, +) from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, WithStateSync, @@ -75,7 +80,6 @@ DestinationClientStagingConfiguration, DestinationClientDwhWithStagingConfiguration, ) -from dlt.common.destination.capabilities import INTERNAL_LOADER_FILE_FORMATS from dlt.common.pipeline import ( ExtractInfo, LoadInfo, @@ -92,7 +96,6 @@ ) from dlt.common.schema import Schema from dlt.common.utils import is_interactive -from dlt.common.data_writers import TLoaderFileFormat from dlt.common.warnings import deprecated, Dlt04DeprecationWarning from dlt.extract import DltSource @@ -452,8 +455,8 @@ def normalize( if is_interactive(): workers = 1 - if loader_file_format and loader_file_format in INTERNAL_LOADER_FILE_FORMATS: - raise ValueError(f"{loader_file_format} is one of internal dlt file formats.") + if loader_file_format and loader_file_format not in ALL_SUPPORTED_FILE_FORMATS: + raise ValueError(f"{loader_file_format} is unknown.") # check if any schema is present, if not then no data was extracted if not self.default_schema_name: return None @@ -976,7 +979,6 @@ def _get_load_storage(self) -> LoadStorage: caps = self._get_destination_capabilities() return LoadStorage( True, - caps.preferred_loader_file_format, caps.supported_loader_file_formats, self._load_storage_config(), ) @@ -1320,7 +1322,7 @@ def _resolve_loader_file_format( destination, staging, file_format, - set(possible_file_formats) - INTERNAL_LOADER_FILE_FORMATS, + set(possible_file_formats), ) return file_format diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 5d7a2f7eb2..99421e2c60 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -30,9 +30,7 @@ else: PrivateKeyTypes = Any -TApiKeyLocation = Literal[ - "header", "cookie", "query", "param" -] # Alias for scheme "in" field +TApiKeyLocation = Literal["header", "cookie", "query", "param"] # Alias for scheme "in" field class AuthConfigBase(AuthBase, CredentialsConfiguration): @@ -102,7 +100,8 @@ def parse_native_representation(self, value: Any) -> None: raise NativeValueError( type(self), value, - f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}", + "HttpBasicAuth username and password must be a tuple of two strings, got" + f" {type(value)}", ) def __call__(self, request: PreparedRequest) -> PreparedRequest: @@ -147,9 +146,7 @@ class OAuthJWTAuth(BearerTokenAuth): default_token_expiration: int = 3600 def __post_init__(self) -> None: - self.scopes = ( - self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) - ) + self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None @@ -171,9 +168,7 @@ def obtain_token(self) -> None: payload = self.create_jwt_payload() data = { "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", - "assertion": jwt.encode( - payload, self.load_private_key(), algorithm="RS256" - ), + "assertion": jwt.encode(payload, self.load_private_key(), algorithm="RS256"), } logger.debug(f"Obtaining token from {self.auth_endpoint}") @@ -208,8 +203,8 @@ def load_private_key(self) -> "PrivateKeyTypes": private_key_bytes = self.private_key.encode("utf-8") return serialization.load_pem_private_key( private_key_bytes, - password=self.private_key_passphrase.encode("utf-8") - if self.private_key_passphrase - else None, + password=( + self.private_key_passphrase.encode("utf-8") if self.private_key_passphrase else None + ), backend=default_backend(), ) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 4b5625eebe..027afc7cbb 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -135,9 +135,7 @@ def _send_request(self, request: Request) -> Response: return self.session.send(prepared_request) - def request( - self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any - ) -> Response: + def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) -> Response: prepared_request = self._create_request( path=path, method=method, @@ -145,14 +143,10 @@ def request( ) return self._send_request(prepared_request) - def get( - self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Response: + def get(self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response: return self.request(path, method="GET", params=params, **kwargs) - def post( - self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> Response: + def post(self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Response: return self.request(path, method="POST", json=json, **kwargs) def paginate( @@ -224,16 +218,12 @@ def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: paginator.update_request(request) # yield data with context - yield PageData( - data, request=request, response=response, paginator=paginator, auth=auth - ) + yield PageData(data, request=request, response=response, paginator=paginator, auth=auth) if not paginator.has_next_page: break - def extract_response( - self, response: Response, data_selector: jsonpath.TJsonPath - ) -> List[Any]: + def extract_response(self, response: Response, data_selector: jsonpath.TJsonPath) -> List[Any]: if data_selector: # we should compile data_selector data: Any = jsonpath.find_values(data_selector, response.json()) @@ -257,8 +247,6 @@ def detect_paginator(self, response: Response) -> BasePaginator: """ paginator = self.pagination_factory.create_paginator(response) if paginator is None: - raise ValueError( - f"No suitable paginator found for the response at {response.url}" - ) + raise ValueError(f"No suitable paginator found for the response at {response.url}") logger.info(f"Detected paginator: {paginator.__class__.__name__}") return paginator diff --git a/dlt/sources/helpers/rest_client/detector.py b/dlt/sources/helpers/rest_client/detector.py index f3af31bb4d..547162358c 100644 --- a/dlt/sources/helpers/rest_client/detector.py +++ b/dlt/sources/helpers/rest_client/detector.py @@ -80,8 +80,7 @@ def find_records( return next( list_info[2] for list_info in lists - if list_info[1] in RECORD_KEY_PATTERNS - and list_info[1] not in NON_RECORD_KEY_PATTERNS + if list_info[1] in RECORD_KEY_PATTERNS and list_info[1] not in NON_RECORD_KEY_PATTERNS ) except StopIteration: # return the least nested element @@ -142,9 +141,7 @@ def single_page_detector(response: Response) -> Optional[SinglePagePaginator]: class PaginatorFactory: - def __init__( - self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None - ): + def __init__(self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None): if detectors is None: detectors = [ header_links_detector, diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py index c098ea667f..65605b7dee 100644 --- a/dlt/sources/helpers/rest_client/paginators.py +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -83,9 +83,7 @@ def update_state(self, response: Response) -> None: total = values[0] if values else None if total is None: - raise ValueError( - f"Total count not found in response for {self.__class__.__name__}" - ) + raise ValueError(f"Total count not found in response for {self.__class__.__name__}") self.offset += self.limit diff --git a/docs/website/docs/reference/performance.md b/docs/website/docs/reference/performance.md index 39db678ca1..605fad1a7c 100644 --- a/docs/website/docs/reference/performance.md +++ b/docs/website/docs/reference/performance.md @@ -245,6 +245,11 @@ from dlt.common import json **orjson** is fast and available on most platforms. It uses binary streams, not strings to load data natively. - open files as binary, not string to use `load` and `dump` - use `loadb` and `dumpb` methods to work with bytes without decoding strings + +You can switch to **simplejson** at any moment by (1) removing **orjson** dependency or (2) setting the following env variable: +``` +DLT_USE_JSON=simplejson +``` ::: ## Using the built in requests client diff --git a/tests/cases.py b/tests/cases.py index 85caec4b8d..9a0213d837 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,3 +1,4 @@ +import hashlib from typing import Dict, List, Any, Sequence, Tuple, Literal, Union import base64 from hexbytes import HexBytes @@ -101,22 +102,22 @@ {"name": "col7_precision", "data_type": "binary", "precision": 19, "nullable": False}, {"name": "col11_precision", "data_type": "time", "precision": 3, "nullable": False}, ] -TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {t["name"]: t for t in TABLE_UPDATE} +TABLE_UPDATE_COLUMNS_SCHEMA: TTableSchemaColumns = {c["name"]: c for c in TABLE_UPDATE} TABLE_ROW_ALL_DATA_TYPES = { "col1": 989127831, "col2": 898912.821982, "col3": True, "col4": "2022-05-23T13:26:45.176451+00:00", - "col5": "string data \n \r \x8e 🦆", + "col5": "string data \n \r 🦆", "col6": Decimal("2323.34"), - "col7": b"binary data \n \r \x8e", + "col7": b"binary data \n \r ", "col8": 2**56 + 92093890840, "col9": { "complex": [1, 2, 3, "a"], "link": ( "?commen\ntU\nrn=urn%3Ali%3Acomment%3A%28acti\012 \6" - " \\vity%3A69'08444473\n\n551163392%2C6n \r \x8e9085" + " \\vity%3A69'08444473\n\n551163392%2C6n \r 9085" ), }, "col10": "2023-02-27", @@ -134,13 +135,21 @@ "col11_null": None, "col1_precision": 22324, "col4_precision": "2022-05-23T13:26:46.167231+00:00", - "col5_precision": "string data 2 \n \r \x8e 🦆", + "col5_precision": "string data 2 \n \r 🦆", "col6_precision": Decimal("2323.34"), - "col7_precision": b"binary data 2 \n \r \x8e", + "col7_precision": b"binary data 2 \n \r A", "col11_precision": "13:26:45.176451", } +TABLE_ROW_ALL_DATA_TYPES_DATETIMES = deepcopy(TABLE_ROW_ALL_DATA_TYPES) +TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col4"] = ensure_pendulum_datetime(TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col4"]) # type: ignore[arg-type] +TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col10"] = ensure_pendulum_date(TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col10"]) # type: ignore[arg-type] +TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col11"] = pendulum.Time.fromisoformat(TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col11"]) # type: ignore[arg-type] +TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col4_precision"] = ensure_pendulum_datetime(TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col4_precision"]) # type: ignore[arg-type] +TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col11_precision"] = pendulum.Time.fromisoformat(TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col11_precision"]) # type: ignore[arg-type] + + def table_update_and_row( exclude_types: Sequence[TDataType] = None, exclude_columns: Sequence[str] = None ) -> Tuple[TTableSchemaColumns, StrAny]: @@ -259,37 +268,57 @@ def arrow_format_from_pandas( raise ValueError("Unknown item type: " + object_format) +def arrow_item_from_table( + table: Any, + object_format: TArrowFormat, +) -> Any: + from dlt.common.libs.pyarrow import pyarrow as pa + + if object_format == "pandas": + return table.to_pandas() + elif object_format == "table": + return table + elif object_format == "record_batch": + return table.to_batches()[0] + raise ValueError("Unknown item type: " + object_format) + + def arrow_table_all_data_types( object_format: TArrowFormat, include_json: bool = True, include_time: bool = True, + include_binary: bool = True, + include_decimal: bool = True, + include_date: bool = True, include_not_normalized_name: bool = True, include_name_clash: bool = False, num_rows: int = 3, + tz="UTC", ) -> Tuple[Any, List[Dict[str, Any]]]: """Create an arrow object or pandas dataframe with all supported data types. Returns the table and its records in python format """ import pandas as pd - from dlt.common.libs.pyarrow import pyarrow as pa + import numpy as np data = { - "string": [random.choice(ascii_lowercase) for _ in range(num_rows)], + "string": [random.choice(ascii_lowercase) + "\"'\\🦆\n\r" for _ in range(num_rows)], "float": [round(random.uniform(0, 100), 4) for _ in range(num_rows)], "int": [random.randrange(0, 100) for _ in range(num_rows)], - "datetime": pd.date_range("2021-01-01T01:02:03.1234", periods=num_rows, tz="UTC"), - "date": pd.date_range("2021-01-01", periods=num_rows, tz="UTC").date, - "binary": [random.choice(ascii_lowercase).encode() for _ in range(num_rows)], - "decimal": [Decimal(str(round(random.uniform(0, 100), 4))) for _ in range(num_rows)], + "datetime": pd.date_range("2021-01-01T01:02:03.1234", periods=num_rows, tz=tz, unit="us"), "bool": [random.choice([True, False]) for _ in range(num_rows)], "string_null": [random.choice(ascii_lowercase) for _ in range(num_rows - 1)] + [None], + "float_null": [round(random.uniform(0, 100), 5) for _ in range(num_rows - 1)] + [ + None + ], # decrease precision "null": pd.Series([None for _ in range(num_rows)]), } if include_name_clash: data["pre Normalized Column"] = [random.choice(ascii_lowercase) for _ in range(num_rows)] include_not_normalized_name = True + if include_not_normalized_name: data["Pre Normalized Column"] = [random.choice(ascii_lowercase) for _ in range(num_rows)] @@ -299,7 +328,19 @@ def arrow_table_all_data_types( if include_time: data["time"] = pd.date_range("2021-01-01", periods=num_rows, tz="UTC").time + if include_binary: + # "binary": [hashlib.sha3_256(random.choice(ascii_lowercase).encode()).digest() for _ in range(num_rows)], + data["binary"] = [random.choice(ascii_lowercase).encode() for _ in range(num_rows)] + + if include_decimal: + data["decimal"] = [Decimal(str(round(random.uniform(0, 100), 4))) for _ in range(num_rows)] + + if include_date: + data["date"] = pd.date_range("2021-01-01", periods=num_rows, tz=tz).date + df = pd.DataFrame(data) + # None integers/floats are converted to nan, also replaces floats with objects and loses precision + df = df.replace(np.nan, None) # records have normalized identifiers for comparing rows = ( df.rename( @@ -311,3 +352,28 @@ def arrow_table_all_data_types( .to_dict("records") ) return arrow_format_from_pandas(df, object_format), rows + + +def prepare_shuffled_tables() -> Tuple[Any, Any, Any]: + from dlt.common.libs.pyarrow import remove_columns + from dlt.common.libs.pyarrow import pyarrow as pa + + table, _ = arrow_table_all_data_types( + "table", + include_json=False, + include_not_normalized_name=False, + tz="Europe/Berlin", + num_rows=5432, + ) + # remove null column from table (it will be removed in extract) + table = remove_columns(table, "null") + # shuffled_columns = table.schema.names + shuffled_indexes = list(range(len(table.schema.names))) + random.shuffle(shuffled_indexes) + shuffled_table = pa.Table.from_arrays( + [table.column(idx) for idx in shuffled_indexes], + schema=pa.schema([table.schema.field(idx) for idx in shuffled_indexes]), + ) + shuffled_removed_column = remove_columns(shuffled_table, ["binary"]) + assert shuffled_table.schema.names != table.schema.names + return table, shuffled_table, shuffled_removed_column diff --git a/tests/common/data_writers/utils.py b/tests/common/data_writers/utils.py index a02d654728..95e7b8ab64 100644 --- a/tests/common/data_writers/utils.py +++ b/tests/common/data_writers/utils.py @@ -1,35 +1,34 @@ import os -from typing import Set, Literal +from typing import Type - -from dlt.common.data_writers.buffered import BufferedDataWriter, DataWriter -from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.data_writers.buffered import BufferedDataWriter +from dlt.common.data_writers.writers import TWriter, ALL_WRITERS +from dlt.common.destination import DestinationCapabilitiesContext from tests.utils import TEST_STORAGE_ROOT -ALL_WRITERS: Set[Literal[TLoaderFileFormat]] = { - "insert_values", - "jsonl", - "parquet", - "arrow", - "puae-jsonl", -} +ALL_OBJECT_WRITERS = [ + writer for writer in ALL_WRITERS if writer.writer_spec().data_item_format == "object" +] def get_writer( - _format: TLoaderFileFormat = "insert_values", + writer: Type[TWriter], buffer_max_items: int = 10, - file_max_items: int = 5000, + file_max_items: int = 10, + file_max_bytes: int = None, disable_compression: bool = False, -) -> BufferedDataWriter[DataWriter]: +) -> BufferedDataWriter[TWriter]: caps = DestinationCapabilitiesContext.generic_capabilities() - caps.preferred_loader_file_format = _format - file_template = os.path.join(TEST_STORAGE_ROOT, f"{_format}.%s") + writer_spec = writer.writer_spec() + caps.preferred_loader_file_format = writer_spec.file_format + file_template = os.path.join(TEST_STORAGE_ROOT, f"{writer_spec.file_format}.%s") return BufferedDataWriter( - _format, + writer_spec, file_template, buffer_max_items=buffer_max_items, file_max_items=file_max_items, + file_max_bytes=file_max_bytes, disable_compression=disable_compression, _caps=caps, ) diff --git a/tests/common/storages/test_load_package.py b/tests/common/storages/test_load_package.py index d61029c8cf..68396a76c8 100644 --- a/tests/common/storages/test_load_package.py +++ b/tests/common/storages/test_load_package.py @@ -185,7 +185,7 @@ def test_build_parse_job_path(load_storage: LoadStorage) -> None: file_id = ParsedLoadJobFileName.new_file_id() f_n_t = ParsedLoadJobFileName("test_table", file_id, 0, "jsonl") job_f_n = PackageStorage.build_job_file_name( - f_n_t.table_name, file_id, 0, loader_file_format=load_storage.loader_file_format + f_n_t.table_name, file_id, 0, loader_file_format="jsonl" ) # test the exact representation but we should probably not test for that assert job_f_n == f"test_table.{file_id}.0.jsonl" @@ -195,12 +195,8 @@ def test_build_parse_job_path(load_storage: LoadStorage) -> None: # parts cannot contain dots with pytest.raises(ValueError): - PackageStorage.build_job_file_name( - "test.table", file_id, 0, loader_file_format=load_storage.loader_file_format - ) - PackageStorage.build_job_file_name( - "test_table", "f.id", 0, loader_file_format=load_storage.loader_file_format - ) + PackageStorage.build_job_file_name("test.table", file_id, 0, loader_file_format="jsonl") + PackageStorage.build_job_file_name("test_table", "f.id", 0, loader_file_format="jsonl") # parsing requires 4 parts and retry count with pytest.raises(ValueError): diff --git a/tests/common/storages/test_load_storage.py b/tests/common/storages/test_load_storage.py index 0fe112581e..a70242001d 100644 --- a/tests/common/storages/test_load_storage.py +++ b/tests/common/storages/test_load_storage.py @@ -160,19 +160,19 @@ def test_get_unknown_package_info(load_storage: LoadStorage) -> None: def test_full_migration_path() -> None: # create directory structure - s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(True, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) # overwrite known initial version write_version(s.storage, "1.0.0") # must be able to migrate to current version - s = LoadStorage(False, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(False, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) assert s.version == LoadStorage.STORAGE_VERSION def test_unknown_migration_path() -> None: # create directory structure - s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + s = LoadStorage(True, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) # overwrite known initial version write_version(s.storage, "10.0.0") # must be able to migrate to current version with pytest.raises(NoMigrationPathException): - LoadStorage(False, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + LoadStorage(False, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index 3158e84c24..e500f149ed 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -1,3 +1,4 @@ +from pathlib import Path from urllib.parse import urlparse import pytest import gzip @@ -6,7 +7,9 @@ from dlt.common import pendulum, json from dlt.common.configuration.resolve import resolve_configuration +from dlt.common.data_writers import DataWriter from dlt.common.schema import Schema +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.storages import ( LoadStorageConfiguration, FilesystemConfiguration, @@ -14,15 +17,16 @@ TJobState, LoadStorage, ) +from dlt.common.storages import DataItemStorage, FileStorage from dlt.common.storages.fsspec_filesystem import FileItem, FileItemDict -from dlt.common.typing import StrAny +from dlt.common.typing import StrAny, TDataItems from dlt.common.utils import uniq_id @pytest.fixture def load_storage() -> LoadStorage: C = resolve_configuration(LoadStorageConfiguration()) - s = LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS, C) + s = LoadStorage(True, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, C) return s @@ -98,13 +102,40 @@ def assert_sample_files( assert item["encoding"] == "gzip" +def write_temp_job_file( + item_storage: DataItemStorage, + file_storage: FileStorage, + load_id: str, + table_name: str, + table: TTableSchemaColumns, + file_id: str, + rows: TDataItems, +) -> str: + """Writes new file into new packages "new_jobs". Intended for testing""" + file_name = ( + item_storage._get_data_item_path_template(load_id, None, table_name) % file_id + + "." + + item_storage.writer_spec.file_extension + ) + mode = "wb" if item_storage.writer_spec.is_binary_format else "w" + with file_storage.open_file(file_name, mode=mode) as f: + writer = DataWriter.from_file_format( + item_storage.writer_spec.file_format, item_storage.writer_spec.data_item_format, f + ) + writer.write_all(table, rows) + return Path(file_name).name + + def start_loading_file( s: LoadStorage, content: Sequence[StrAny], start_job: bool = True ) -> Tuple[str, str]: load_id = uniq_id() s.new_packages.create_package(load_id) # write test file - file_name = s._write_temp_job_file(load_id, "mock_table", None, uniq_id(), content) + item_storage = s.create_item_storage("jsonl", "object") + file_name = write_temp_job_file( + item_storage, s.storage, load_id, "mock_table", None, uniq_id(), content + ) # write schema and schema update s.new_packages.save_schema(load_id, Schema("mock")) s.new_packages.save_schema_updates(load_id, {}) diff --git a/tests/destinations/test_custom_destination.py b/tests/destinations/test_custom_destination.py index cfefceac88..c8445a94dc 100644 --- a/tests/destinations/test_custom_destination.py +++ b/tests/destinations/test_custom_destination.py @@ -28,7 +28,7 @@ assert_all_data_types_row, ) -SUPPORTED_LOADER_FORMATS = ["parquet", "puae-jsonl"] +SUPPORTED_LOADER_FORMATS = ["parquet", "typed-jsonl"] def _run_through_sink( @@ -539,7 +539,7 @@ def test_sink(items, table): found_dlt_column = True # check actual data items - if loader_file_format == "puae-jsonl": + if loader_file_format == "typed-jsonl": for item in items: for key in item.keys(): if key.startswith("_dlt"): @@ -570,7 +570,7 @@ def test_max_nesting_level(nesting: int) -> None: found_tables = set() - @dlt.destination(loader_file_format="puae-jsonl", max_table_nesting=nesting) + @dlt.destination(loader_file_format="typed-jsonl", max_table_nesting=nesting) def nesting_sink(items, table): nonlocal found_tables found_tables.add(table["name"]) diff --git a/tests/load/pipeline/test_redshift.py b/tests/load/pipeline/test_redshift.py index 44234ec64b..574fb6b356 100644 --- a/tests/load/pipeline/test_redshift.py +++ b/tests/load/pipeline/test_redshift.py @@ -15,7 +15,7 @@ ids=lambda x: x.name, ) def test_redshift_blocks_time_column(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), full_refresh=True) + pipeline = destination_config.setup_pipeline("redshift_" + uniq_id(), full_refresh=True) column_schemas, data_types = table_update_and_row() diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index c5e4f874fc..b25a643624 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -8,7 +8,7 @@ from dlt.common.exceptions import TerminalException, TerminalValueError from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName from dlt.common.storages.load_package import LoadJobInfo -from dlt.common.storages.load_storage import JobWithUnsupportedWriterException +from dlt.common.storages.load_storage import JobFileFormatUnsupported from dlt.common.destination.reference import LoadJob, TDestination from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, @@ -455,7 +455,7 @@ def test_wrong_writer_type() -> None: ], ) with ThreadPoolExecutor() as pool: - with pytest.raises(JobWithUnsupportedWriterException) as exv: + with pytest.raises(JobFileFormatUnsupported) as exv: load.run(pool) assert exv.value.load_id == load_id @@ -812,7 +812,7 @@ def setup_loader( staging = None if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination - client_config = client_config or DummyClientConfiguration(loader_file_format="reference") + client_config = client_config or DummyClientConfiguration(loader_file_format="reference") # type: ignore[arg-type] staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( dataset_name="dummy" ) diff --git a/tests/load/utils.py b/tests/load/utils.py index 8c5eda6d3b..d8daf996e1 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -573,13 +573,15 @@ def write_dataset( rows: Union[List[Dict[str, Any]], List[StrAny]], columns_schema: TTableSchemaColumns, ) -> None: - data_format = DataWriter.data_format_from_file_format( - client.capabilities.preferred_loader_file_format + spec = DataWriter.writer_spec_from_file_format( + client.capabilities.preferred_loader_file_format, "object" ) # adapt bytes stream to text file format - if not data_format.is_binary_format and isinstance(f.read(0), bytes): + if not spec.is_binary_format and isinstance(f.read(0), bytes): f = codecs.getwriter("utf-8")(f) # type: ignore[assignment] - writer = DataWriter.from_destination_capabilities(client.capabilities, f) + writer = DataWriter.from_file_format( + client.capabilities.preferred_loader_file_format, "object", f, client.capabilities + ) # remove None values for idx, row in enumerate(rows): rows[idx] = {k: v for k, v in row.items() if v is not None} diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index 2f2e6b6932..28a5f03fb1 100644 --- a/tests/pipeline/test_schema_contracts.py +++ b/tests/pipeline/test_schema_contracts.py @@ -12,10 +12,10 @@ from tests.load.pipeline.utils import load_table_counts from tests.utils import ( - TDataItemFormat, + TestDataItemFormat, skip_if_not_active, data_to_item_format, - ALL_DATA_ITEM_FORMATS, + ALL_TEST_DATA_ITEM_FORMATS, ) skip_if_not_active("duckdb") @@ -100,7 +100,7 @@ def run_resource( pipeline: Pipeline, resource_fun: Callable[..., DltResource], settings: Any, - item_format: TDataItemFormat = "json", + item_format: TestDataItemFormat = "json", duplicates: int = 1, ) -> None: for item in settings.keys(): @@ -149,9 +149,9 @@ def get_pipeline(): @pytest.mark.parametrize("contract_setting", schema_contract) @pytest.mark.parametrize("setting_location", LOCATIONS) -@pytest.mark.parametrize("item_format", ALL_DATA_ITEM_FORMATS) +@pytest.mark.parametrize("item_format", ALL_TEST_DATA_ITEM_FORMATS) def test_new_tables( - contract_setting: str, setting_location: str, item_format: TDataItemFormat + contract_setting: str, setting_location: str, item_format: TestDataItemFormat ) -> None: pipeline = get_pipeline() @@ -203,9 +203,9 @@ def test_new_tables( @pytest.mark.parametrize("contract_setting", schema_contract) @pytest.mark.parametrize("setting_location", LOCATIONS) -@pytest.mark.parametrize("item_format", ALL_DATA_ITEM_FORMATS) +@pytest.mark.parametrize("item_format", ALL_TEST_DATA_ITEM_FORMATS) def test_new_columns( - contract_setting: str, setting_location: str, item_format: TDataItemFormat + contract_setting: str, setting_location: str, item_format: TestDataItemFormat ) -> None: full_settings = {setting_location: {"columns": contract_setting}} diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py index 7eec090db6..09676bdf37 100644 --- a/tests/sources/helpers/rest_client/conftest.py +++ b/tests/sources/helpers/rest_client/conftest.py @@ -28,9 +28,7 @@ def __init__(self, base_url: str): self.routes: List[Route] = [] self.base_url = base_url - def _add_route( - self, method: str, pattern: str, func: RequestCallback - ) -> RequestCallback: + def _add_route(self, method: str, pattern: str, func: RequestCallback) -> RequestCallback: compiled_pattern = re.compile(f"{self.base_url}{pattern}") self.routes.append(Route(method, compiled_pattern, func)) return func @@ -98,9 +96,7 @@ def paginate_response(request, records, page_size=10, records_key="data"): start_index = (page_number - 1) * 10 end_index = start_index + 10 records_slice = records[start_index:end_index] - return serialize_page( - records_slice, page_number, total_pages, request.url, records_key - ) + return serialize_page(records_slice, page_number, total_pages, request.url, records_key) @pytest.fixture(scope="module") @@ -137,9 +133,7 @@ def post_detail_404(request, context): @router.get(r"/posts_under_a_different_key$") def posts_with_results_key(request, context): - return paginate_response( - request, generate_posts(), records_key="many-results" - ) + return paginate_response(request, generate_posts(), records_key="many-results") @router.get("/protected/posts/basic-auth") def protected_basic_auth(request, context): @@ -199,6 +193,4 @@ def refresh_token(request, context): def assert_pagination(pages, expected_start=0, page_size=10): for i, page in enumerate(pages): - assert page == [ - {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) - ] + assert page == [{"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10)] diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 7a4c55f9a6..b1038bced0 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -137,9 +137,7 @@ def test_bearer_token_auth_success(self, rest_client: RESTClient): def test_api_key_auth_success(self, rest_client: RESTClient): response = rest_client.get( "/protected/posts/api-key", - auth=APIKeyAuth( - name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key") - ), + auth=APIKeyAuth(name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key")), ) assert response.status_code == 200 assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} diff --git a/tests/sources/helpers/rest_client/test_detector.py b/tests/sources/helpers/rest_client/test_detector.py index a9af1d36a4..933c9be9cc 100644 --- a/tests/sources/helpers/rest_client/test_detector.py +++ b/tests/sources/helpers/rest_client/test_detector.py @@ -101,9 +101,7 @@ }, { "response": { - "_embedded": { - "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] - }, + "_embedded": {"items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}]}, "_links": { "first": {"href": "http://api.example.com/items?page=0&size=2"}, "self": {"href": "http://api.example.com/items?page=1&size=2"}, @@ -315,9 +313,7 @@ def test_find_records(test_case): @pytest.mark.parametrize("test_case", TEST_RESPONSES) def test_find_next_page_key(test_case): response = test_case["response"] - expected = test_case.get("expected").get( - "next_path", None - ) # Some cases may not have next_path + expected = test_case.get("expected").get("next_path", None) # Some cases may not have next_path assert find_next_page_path(response) == expected diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py index cc4dea65dc..258099292b 100644 --- a/tests/sources/helpers/rest_client/test_paginators.py +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -31,9 +31,7 @@ def test_update_state_without_next(self): class TestJSONResponsePaginator: def test_update_state_with_next(self): paginator = JSONResponsePaginator() - response = Mock( - Response, json=lambda: {"next": "http://example.com/next", "results": []} - ) + response = Mock(Response, json=lambda: {"next": "http://example.com/next", "results": []}) paginator.update_state(response) assert paginator.next_reference == "http://example.com/next" assert paginator.has_next_page is True @@ -55,9 +53,7 @@ def test_update_state(self): def test_update_state_with_next(self): paginator = SinglePagePaginator() - response = Mock( - Response, json=lambda: {"next": "http://example.com/next", "results": []} - ) + response = Mock(Response, json=lambda: {"next": "http://example.com/next", "results": []}) response.links = {"next": {"url": "http://example.com/next"}} paginator.update_state(response) assert paginator.has_next_page is False diff --git a/tests/utils.py b/tests/utils.py index 00523486ea..73e99c3fcd 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -78,8 +78,8 @@ # possible TDataItem types -TDataItemFormat = Literal["json", "pandas", "arrow", "arrow-batch"] -ALL_DATA_ITEM_FORMATS = get_args(TDataItemFormat) +TestDataItemFormat = Literal["json", "pandas", "arrow", "arrow-batch"] +ALL_TEST_DATA_ITEM_FORMATS = get_args(TestDataItemFormat) """List with TDataItem formats: json, arrow table/batch / pandas""" @@ -185,7 +185,7 @@ def wipe_pipeline() -> Iterator[None]: def data_to_item_format( - item_format: TDataItemFormat, data: Union[Iterator[TDataItem], Iterable[TDataItem]] + item_format: TestDataItemFormat, data: Union[Iterator[TDataItem], Iterable[TDataItem]] ) -> Any: """Return the given data in the form of pandas, arrow table/batch or json items""" if item_format == "json": @@ -251,7 +251,7 @@ def clean_test_storage( if init_loader: from dlt.common.storages import LoadStorage - LoadStorage(True, "jsonl", LoadStorage.ALL_SUPPORTED_FILE_FORMATS) + LoadStorage(True, LoadStorage.ALL_SUPPORTED_FILE_FORMATS) return storage From 3bef4daa5f6f536cff08f4730308963dbd57551f Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 00:57:32 +0200 Subject: [PATCH 09/22] adds simple csv and postgres docs --- .../dlt-ecosystem/destinations/filesystem.md | 1 + .../dlt-ecosystem/destinations/postgres.md | 10 +++++ .../docs/dlt-ecosystem/file-formats/csv.md | 38 +++++++++++++++++++ .../dlt-ecosystem/file-formats/parquet.md | 6 +-- docs/website/docs/reference/performance.md | 2 +- 5 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 docs/website/docs/dlt-ecosystem/file-formats/csv.md diff --git a/docs/website/docs/dlt-ecosystem/destinations/filesystem.md b/docs/website/docs/dlt-ecosystem/destinations/filesystem.md index 08de1202b1..4f7a924be1 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/filesystem.md +++ b/docs/website/docs/dlt-ecosystem/destinations/filesystem.md @@ -245,6 +245,7 @@ Please note: You can choose the following file formats: * [jsonl](../file-formats/jsonl.md) is used by default * [parquet](../file-formats/parquet.md) is supported +* [csv](../file-formats/csv.md) is supported ## Syncing of `dlt` state diff --git a/docs/website/docs/dlt-ecosystem/destinations/postgres.md b/docs/website/docs/dlt-ecosystem/destinations/postgres.md index b621cc902b..b806ba78fe 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/postgres.md +++ b/docs/website/docs/dlt-ecosystem/destinations/postgres.md @@ -78,8 +78,18 @@ If you set the [`replace` strategy](../../general-usage/full-loading.md) to `sta ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). +### Fast loading with arrow tables and csv +You can use [arrow tables](../verified-sources/arrow-pandas.md) and [csv](../file-formats/csv.md) to quickly load tabular data. Pick the `csv` loader file format +like below +```py +info = pipeline.run(arrow_table, loader_file_format="csv") +``` +In the example above `arrow_table` will be converted to csv with **pyarrow** and then streamed into **postgres** with COPY command. This method skips the regular +`dlt` normalizer used for Python objects and is several times faster. + ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default. +* [csv](../file-formats/csv.md) is supported ## Supported column hints `postgres` will create unique indexes for all columns with `unique` hints. This behavior **may be disabled**. diff --git a/docs/website/docs/dlt-ecosystem/file-formats/csv.md b/docs/website/docs/dlt-ecosystem/file-formats/csv.md new file mode 100644 index 0000000000..d72ef982af --- /dev/null +++ b/docs/website/docs/dlt-ecosystem/file-formats/csv.md @@ -0,0 +1,38 @@ +--- +title: csv +description: The csv file format +keywords: [csv, file formats] +--- + +# CSV File Format + +**csv** is the most basic file format to store tabular data, where all the values are strings and are separated by a delimiter (typically comma). +`dlt` uses it for specific use cases - mostly for the performance and compatibility reasons. + +Internally we use two implementations: +- **pyarrow** csv writer - very fast, multithreaded writer for the [arrow tables](../verified-sources/arrow-pandas.md) +- **python stdlib writer** - a csv writer included in the Python standard library for Python objects + + +## Supported Destinations + +Supported by: **Postgres**, **Filesystem** + +By setting the `loader_file_format` argument to `csv` in the run command, the pipeline will store your data in the csv format at the destination: + +```py +info = pipeline.run(some_source(), loader_file_format="csv") +``` + +## Default Settings +`dlt` attempts to make both writers to generate similarly looking files +* separators are commas +* quotes are **"** and are escaped as **""** +* `NULL` values are empty strings +* UNIX new lines are used +* dates are represented as ISO 8601 + +## Limitations + +* binary columns are supported only if they contain valid UTF-8 characters +* complex (nested, struct) types are not supported diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 94aaaf4884..6be4f8799e 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -16,7 +16,7 @@ pip install dlt[parquet] ## Supported Destinations -Supported by: **BigQuery**, **DuckDB**, **Snowflake**, **filesystem**, **Athena** +Supported by: **BigQuery**, **DuckDB**, **Snowflake**, **filesystem**, **Athena**, **Databricks**, **Synapse** By setting the `loader_file_format` argument to `parquet` in the run command, the pipeline will store your data in the parquet format at the destination: @@ -33,9 +33,9 @@ info = pipeline.run(some_source(), loader_file_format="parquet") Under the hood, `dlt` uses the [pyarrow parquet writer](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html) to create the files. The following options can be used to change the behavior of the writer: -- `flavor`: Sanitize schema or set other compatibility options to work with various target systems. Defaults to "spark". +- `flavor`: Sanitize schema or set other compatibility options to work with various target systems. Defaults to None which is **pyarrow** default. - `version`: Determine which Parquet logical types are available for use, whether the reduced set from the Parquet 1.x.x format or the expanded logical types added in later format versions. Defaults to "2.4". -- `data_page_size`: Set a target threshold for the approximate encoded size of data pages within a column chunk (in bytes). Defaults to "1048576". +- `data_page_size`: Set a target threshold for the approximate encoded size of data pages within a column chunk (in bytes). Defaults to None which is **pyarrow** default. - `timestamp_timezone`: A string specifying timezone, default is UTC. Read the [pyarrow parquet docs](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html) to learn more about these settings. diff --git a/docs/website/docs/reference/performance.md b/docs/website/docs/reference/performance.md index 605fad1a7c..657f25a793 100644 --- a/docs/website/docs/reference/performance.md +++ b/docs/website/docs/reference/performance.md @@ -247,7 +247,7 @@ from dlt.common import json - use `loadb` and `dumpb` methods to work with bytes without decoding strings You can switch to **simplejson** at any moment by (1) removing **orjson** dependency or (2) setting the following env variable: -``` +```sh DLT_USE_JSON=simplejson ``` ::: From 912fb828341fa59066f2f6d37a9dc86f4d4af564 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 00:58:54 +0200 Subject: [PATCH 10/22] closes writers on exceptions, passes metrics on exceptions, fixes some edge cases with empty arrow files --- dlt/common/data_writers/buffered.py | 35 +++++--- dlt/common/data_writers/writers.py | 29 +++++-- dlt/common/storages/data_item_storage.py | 8 +- dlt/destinations/impl/bigquery/bigquery.py | 11 +-- .../impl/bigquery/bigquery_adapter.py | 6 +- dlt/extract/extract.py | 85 +++++++++++-------- dlt/extract/extractors.py | 5 +- dlt/extract/storage.py | 4 +- dlt/normalize/exceptions.py | 6 +- dlt/normalize/items_normalizers.py | 7 +- dlt/normalize/normalize.py | 33 ++++--- dlt/pipeline/pipeline.py | 4 +- tests/cases.py | 6 +- tests/common/storages/utils.py | 1 + tests/extract/test_extract.py | 69 ++++++++++++++- tests/libs/test_arrow_csv_writer.py | 6 +- tests/libs/test_parquet_writer.py | 3 +- tests/load/pipeline/test_arrow_loading.py | 2 +- .../load/pipeline/test_filesystem_pipeline.py | 2 +- tests/load/pipeline/test_postgres.py | 3 +- tests/load/utils.py | 1 + tests/normalize/test_normalize.py | 16 +++- tests/pipeline/test_arrow_sources.py | 67 ++++++++------- 23 files changed, 275 insertions(+), 134 deletions(-) diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 1db18b065e..e358919c7a 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -164,9 +164,10 @@ def import_file(self, file_path: str, metrics: DataWriterMetrics) -> DataWriterM self._rotate_file() return metrics - def close(self) -> None: + def close(self, skip_flush: bool = False) -> None: + """Flushes the data, writes footer (skip_flush is True), collects metrics and closes the underlying file.""" self._ensure_open() - self._flush_and_close_file() + self._flush_and_close_file(skip_flush=skip_flush) self._closed = True @property @@ -177,7 +178,8 @@ def __enter__(self) -> "BufferedDataWriter[TWriter]": return self def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None: - self.close() + # skip flush if we had exception + self.close(skip_flush=exc_val is not None) def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: metrics = self._flush_and_close_file(allow_empty_file) @@ -188,7 +190,7 @@ def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: return metrics def _flush_items(self, allow_empty_file: bool = False) -> None: - if self._buffered_items_count > 0 or allow_empty_file: + if self._buffered_items or allow_empty_file: # we only open a writer when there are any items in the buffer and first flush is requested if not self._writer: # create new writer and write header @@ -205,15 +207,22 @@ def _flush_items(self, allow_empty_file: bool = False) -> None: self._buffered_items.clear() self._buffered_items_count = 0 - def _flush_and_close_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: - # if any buffered items exist, flush them - self._flush_items(allow_empty_file) - # if writer exists then close it - if not self._writer: - return None - # write the footer of a file - self._writer.write_footer() - self._file.flush() + def _flush_and_close_file( + self, allow_empty_file: bool = False, skip_flush: bool = False + ) -> DataWriterMetrics: + if not skip_flush: + # if any buffered items exist, flush them + self._flush_items(allow_empty_file) + # if writer exists then close it + if not self._writer: + return None + # write the footer of a file + self._writer.write_footer() + self._file.flush() + else: + if not self._writer: + return None + self._writer.close() # add file written to the list so we can commit all the files later metrics = DataWriterMetrics( self._file_name, diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 67e5466d39..9936a6844d 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -84,6 +84,9 @@ def write_data(self, rows: Sequence[Any]) -> None: def write_footer(self) -> None: # noqa pass + def close(self) -> None: # noqa + pass + def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None: self.write_header(columns_schema) self.write_data(rows) @@ -321,9 +324,10 @@ def write_data(self, rows: Sequence[Any]) -> None: # Write self.writer.write_table(table, row_group_size=self.parquet_row_group_size) - def write_footer(self) -> None: - self.writer.close() - self.writer = None + def close(self) -> None: # noqa + if self.writer: + self.writer.close() + self.writer = None @classmethod def writer_spec(cls) -> FileWriterSpec: @@ -362,10 +366,9 @@ def write_data(self, rows: Sequence[Any]) -> None: # count rows that got written self.items_count += sum(len(row) for row in rows) - def write_footer(self) -> None: - if self.writer is None: - self.writer = None - self._first_schema = None + def close(self) -> None: + self.writer = None + self._first_schema = None @classmethod def writer_spec(cls) -> FileWriterSpec: @@ -405,6 +408,9 @@ def write_footer(self) -> None: raise NotImplementedError("Arrow Writer does not support writing empty files") return super().write_footer() + def close(self) -> None: + return super().close() + @classmethod def writer_spec(cls) -> FileWriterSpec: return FileWriterSpec( @@ -488,10 +494,15 @@ def write_footer(self) -> None: # write empty file self._f.write( self.delimiter.join( - [col["name"].encode("utf-8") for col in self._columns_schema.values()] + [ + b'"' + col["name"].encode("utf-8") + b'"' + for col in self._columns_schema.values() + ] ) ) - else: + + def close(self) -> None: + if self.writer: self.writer.close() self.writer = None self._first_schema = None diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index 5b1e360789..ab15c3ad5b 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -72,15 +72,17 @@ def import_items_file( writer = self._get_writer(load_id, schema_name, table_name) return writer.import_file(file_path, metrics) - def close_writers(self, load_id: str) -> None: - # flush and close all files + def close_writers(self, load_id: str, skip_flush: bool = False) -> None: + """Flush, write footers (skip_flush), write metrics and close files in all + writers belonging to `load_id` package + """ for name, writer in self.buffered_writers.items(): if name.startswith(load_id) and not writer.closed: logger.debug( f"Closing writer for {name} with file {writer._file} and actual name" f" {writer._file_name}" ) - writer.close() + writer.close(skip_flush=skip_flush) def closed_files(self, load_id: str) -> List[DataWriterMetrics]: """Return metrics for all fully processed (closed) files""" diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 279917d3a0..b2e53f9734 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -232,10 +232,9 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> if insert_api == "streaming": if table["write_disposition"] != "append": raise DestinationTerminalException( - ( - "BigQuery streaming insert can only be used with `append` write_disposition, while " - f'the given resource has `{table["write_disposition"]}`.' - ) + "BigQuery streaming insert can only be used with `append`" + " write_disposition, while the given resource has" + f" `{table['write_disposition']}`." ) if file_path.endswith(".jsonl"): job_cls = DestinationJsonlLoadJob @@ -364,7 +363,9 @@ def prepare_load_table( def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: name = self.capabilities.escape_identifier(column["name"]) - column_def_sql = f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + column_def_sql = ( + f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + ) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" if column.get(ROUND_HALF_AWAY_FROM_ZERO_HINT, False): diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 8943b0da79..6b3ef32b0f 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -153,10 +153,8 @@ def bigquery_adapter( if insert_api is not None: if insert_api == "streaming" and data.write_disposition != "append": raise ValueError( - ( - "BigQuery streaming insert can only be used with `append` write_disposition, while " - f"the given resource has `{data.write_disposition}`." - ) + "BigQuery streaming insert can only be used with `append` write_disposition, while " + f"the given resource has `{data.write_disposition}`." ) additional_table_hints |= {"x-insert-api": insert_api} # type: ignore[operator] diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 75f22bb802..02dd06eaf3 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -2,7 +2,7 @@ from collections.abc import Sequence as C_Sequence from copy import copy import itertools -from typing import List, Dict, Any +from typing import Iterator, List, Dict, Any import yaml from dlt.common.configuration.container import Container @@ -304,41 +304,58 @@ def _extract_single_source( load_id, self.extract_storage.item_storages["arrow"], schema, collector=collector ), } - + # make sure we close storage on exception with collector(f"Extract {source.name}"): - self._step_info_start_load_id(load_id) - # yield from all selected pipes - with PipeIterator.from_pipes( - source.resources.selected_pipes, - max_parallel_items=max_parallel_items, - workers=workers, - futures_poll_interval=futures_poll_interval, - ) as pipes: - left_gens = total_gens = len(pipes._sources) - collector.update("Resources", 0, total_gens) - for pipe_item in pipes: - curr_gens = len(pipes._sources) - if left_gens > curr_gens: - delta = left_gens - curr_gens - left_gens -= delta - collector.update("Resources", delta) - signals.raise_if_signalled() - resource = source.resources[pipe_item.pipe.name] - item_format = get_data_item_format(pipe_item.item) - extractors[item_format].write_items(resource, pipe_item.item, pipe_item.meta) - - self._write_empty_files(source, extractors) - if left_gens > 0: - # go to 100% - collector.update("Resources", left_gens) - - # flush all buffered writers + with self.manage_writers(load_id, source): + # yield from all selected pipes + with PipeIterator.from_pipes( + source.resources.selected_pipes, + max_parallel_items=max_parallel_items, + workers=workers, + futures_poll_interval=futures_poll_interval, + ) as pipes: + left_gens = total_gens = len(pipes._sources) + collector.update("Resources", 0, total_gens) + for pipe_item in pipes: + curr_gens = len(pipes._sources) + if left_gens > curr_gens: + delta = left_gens - curr_gens + left_gens -= delta + collector.update("Resources", delta) + signals.raise_if_signalled() + resource = source.resources[pipe_item.pipe.name] + item_format = get_data_item_format(pipe_item.item) + extractors[item_format].write_items( + resource, pipe_item.item, pipe_item.meta + ) + + self._write_empty_files(source, extractors) + if left_gens > 0: + # go to 100% + collector.update("Resources", left_gens) + + @contextlib.contextmanager + def manage_writers(self, load_id: str, source: DltSource) -> Iterator[ExtractStorage]: + self._step_info_start_load_id(load_id) + # self.current_source = source + try: + yield self.extract_storage + except Exception: + # kill writers without flushing the content + self.extract_storage.close_writers(load_id, skip_flush=True) + raise + else: self.extract_storage.close_writers(load_id) - # gather metrics - self._step_info_complete_load_id(load_id, self._compute_metrics(load_id, source)) - # remove the metrics of files processed in this extract run - # NOTE: there may be more than one extract run per load id: ie. the resource and then dlt state - self.extract_storage.remove_closed_files(load_id) + finally: + # gather metrics when storage is closed + self.gather_metrics(load_id, source) + + def gather_metrics(self, load_id: str, source: DltSource) -> None: + # gather metrics + self._step_info_complete_load_id(load_id, self._compute_metrics(load_id, source)) + # remove the metrics of files processed in this extract run + # NOTE: there may be more than one extract run per load id: ie. the resource and then dlt state + self.extract_storage.remove_closed_files(load_id) def extract( self, diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index c4b7653164..421250951e 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -122,10 +122,11 @@ def _write_item( self.load_id, self.schema.name, table_name, items, columns ) self.collector.update(table_name, inc=new_rows_count) - if new_rows_count > 0: + # if there were rows or item was empty arrow table + if new_rows_count > 0 or self.__class__ is ArrowExtractor: self.resources_with_items.add(resource_name) else: - if isinstance(items, MaterializedEmptyList) or self.__class__ is ArrowExtractor: + if isinstance(items, MaterializedEmptyList): self.resources_with_empty.add(resource_name) def _write_to_dynamic_table(self, resource: DltResource, items: TDataItems) -> None: diff --git a/dlt/extract/storage.py b/dlt/extract/storage.py index b76822a4f2..3e01a020ba 100644 --- a/dlt/extract/storage.py +++ b/dlt/extract/storage.py @@ -74,9 +74,9 @@ def create_load_package(self, schema: Schema, reuse_exiting_package: bool = True self.new_packages.save_schema(load_id, schema) return load_id - def close_writers(self, load_id: str) -> None: + def close_writers(self, load_id: str, skip_flush: bool = False) -> None: for storage in self.item_storages.values(): - storage.close_writers(load_id) + storage.close_writers(load_id, skip_flush=skip_flush) def closed_files(self, load_id: str) -> List[DataWriterMetrics]: files = [] diff --git a/dlt/normalize/exceptions.py b/dlt/normalize/exceptions.py index a172196899..7bc305fcbe 100644 --- a/dlt/normalize/exceptions.py +++ b/dlt/normalize/exceptions.py @@ -1,3 +1,4 @@ +from typing import Any, List from dlt.common.exceptions import DltException @@ -7,10 +8,13 @@ def __init__(self, msg: str) -> None: class NormalizeJobFailed(NormalizeException): - def __init__(self, load_id: str, job_id: str, failed_message: str) -> None: + def __init__( + self, load_id: str, job_id: str, failed_message: str, writer_metrics: List[Any] + ) -> None: self.load_id = load_id self.job_id = job_id self.failed_message = failed_message + self.writer_metrics = writer_metrics super().__init__( f"Job for {job_id} failed terminally in load {load_id} with message {failed_message}." ) diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index bf4073ddbf..1e4e55effd 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -283,7 +283,7 @@ def _write_with_dlt_columns( items_count = 0 columns_schema = schema.get_table_columns(root_table_name) # if we use adapter to convert arrow to dicts, then normalization is not necessary - may_normalize = not issubclass(self.item_storage.writer_cls, ArrowToObjectAdapter) + is_native_arrow_writer = not issubclass(self.item_storage.writer_cls, ArrowToObjectAdapter) should_normalize: bool = None with self.normalize_storage.extracted_packages.storage.open_file( extracted_items_file, "rb" @@ -293,7 +293,7 @@ def _write_with_dlt_columns( ): items_count += batch.num_rows # we may need to normalize - if may_normalize and should_normalize is None: + if is_native_arrow_writer and should_normalize is None: should_normalize, _, _, _ = pyarrow.should_normalize_arrow_schema( batch.schema, columns_schema, schema.naming ) @@ -315,7 +315,8 @@ def _write_with_dlt_columns( batch, columns_schema, ) - if items_count == 0: + # TODO: better to check if anything is in the buffer and skip writing file + if items_count == 0 and not is_native_arrow_writer: self.item_storage.write_empty_items_file( load_id, schema.name, diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 28c2c81571..47d0cd9898 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -142,6 +142,19 @@ def _get_items_normalizer(item_format: TDataItemFormat) -> ItemsNormalizer: ) return norm + def _gather_metrics_and_close(skip_flush: bool) -> List[DataWriterMetrics]: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=skip_flush) + + writer_metrics: List[DataWriterMetrics] = [] + for normalizer in item_normalizers.values(): + norm_metrics = normalizer.item_storage.closed_files(load_id) + writer_metrics.extend(norm_metrics) + + for normalizer in item_normalizers.values(): + normalizer.item_storage.remove_closed_files(load_id) + return writer_metrics + parsed_file_name: ParsedLoadJobFileName = None try: root_tables: Set[str] = set() @@ -165,15 +178,11 @@ def _get_items_normalizer(item_format: TDataItemFormat) -> ItemsNormalizer: logger.debug(f"Processed file {extracted_items_file}") except Exception as exc: job_id = parsed_file_name.job_id() if parsed_file_name else "" - raise NormalizeJobFailed(load_id, job_id, str(exc)) from exc - finally: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id) + writer_metrics = _gather_metrics_and_close(skip_flush=True) + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc + else: + writer_metrics = _gather_metrics_and_close(skip_flush=False) - writer_metrics: List[DataWriterMetrics] = [] - for normalizer in item_normalizers.values(): - norm_metrics = normalizer.item_storage.closed_files(load_id) - writer_metrics.extend(norm_metrics) logger.info(f"Processed all items in {len(extracted_items_files)} files") return TWorkerRV(schema_updates, writer_metrics) @@ -233,9 +242,11 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TW for task in list(tasks): pending, params = task if pending.done(): - result: TWorkerRV = ( - pending.result() - ) # Exception in task (if any) is raised here + # collect metrics from the exception (if any) + if isinstance(pending.exception(), NormalizeJobFailed): + summary.file_metrics.extend(pending.exception().writer_metrics) # type: ignore[attr-defined] + # Exception in task (if any) is raised here + result: TWorkerRV = pending.result() try: # gather schema from all manifests, validate consistency and combine self.update_table(schema, result[0]) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index b0d04dfbe8..683251c2a8 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -436,11 +436,13 @@ def extract( extract_step.commit_packages() return self._get_step_info(extract_step) except Exception as exc: + # emit step info step_info = self._get_step_info(extract_step) + current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None raise PipelineStepFailed( self, "extract", - extract_step.current_load_id, + current_load_id, exc, step_info, ) from exc diff --git a/tests/cases.py b/tests/cases.py index 9a0213d837..b598f1169e 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -294,7 +294,7 @@ def arrow_table_all_data_types( include_name_clash: bool = False, num_rows: int = 3, tz="UTC", -) -> Tuple[Any, List[Dict[str, Any]]]: +) -> Tuple[Any, List[Dict[str, Any]], Dict[str, List[Any]]]: """Create an arrow object or pandas dataframe with all supported data types. Returns the table and its records in python format @@ -351,14 +351,14 @@ def arrow_table_all_data_types( .drop(columns=["null"]) .to_dict("records") ) - return arrow_format_from_pandas(df, object_format), rows + return arrow_format_from_pandas(df, object_format), rows, data def prepare_shuffled_tables() -> Tuple[Any, Any, Any]: from dlt.common.libs.pyarrow import remove_columns from dlt.common.libs.pyarrow import pyarrow as pa - table, _ = arrow_table_all_data_types( + table, _, _ = arrow_table_all_data_types( "table", include_json=False, include_not_normalized_name=False, diff --git a/tests/common/storages/utils.py b/tests/common/storages/utils.py index e500f149ed..13ec253e2f 100644 --- a/tests/common/storages/utils.py +++ b/tests/common/storages/utils.py @@ -123,6 +123,7 @@ def write_temp_job_file( item_storage.writer_spec.file_format, item_storage.writer_spec.data_item_format, f ) writer.write_all(table, rows) + writer.close() return Path(file_name).name diff --git a/tests/extract/test_extract.py b/tests/extract/test_extract.py index 1879eaa9eb..9620e7fdfb 100644 --- a/tests/extract/test_extract.py +++ b/tests/extract/test_extract.py @@ -11,12 +11,12 @@ from dlt.common.storages.schema_storage import SchemaStorage from dlt.extract import DltResource, DltSource -from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints +from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints, ResourceExtractionError from dlt.extract.extract import ExtractStorage, Extract from dlt.extract.hints import make_hints from dlt.extract.items import TableNameMeta -from tests.utils import clean_test_storage, TEST_STORAGE_ROOT +from tests.utils import MockPipeline, clean_test_storage, TEST_STORAGE_ROOT from tests.extract.utils import expect_extracted_file @@ -211,6 +211,71 @@ def with_table_hints(): extract_step.extract(source, 20, 1) +def test_extract_metrics_on_exception_no_flush(extract_step: Extract) -> None: + @dlt.resource + def letters(): + # extract 7 items + yield from "ABCDEFG" + # then fail + raise RuntimeError() + yield from "HI" + + source = DltSource(dlt.Schema("letters"), "module", [letters]) + with pytest.raises(ResourceExtractionError): + extract_step.extract(source, 20, 1) + step_info = extract_step.get_step_info(MockPipeline("buba", first_run=False)) # type: ignore[abstract] + # no jobs were created + assert len(step_info.load_packages[0].jobs["new_jobs"]) == 0 + # make sure all writers are closed but not yet removed + current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None + # get buffered writers + writers = extract_step.extract_storage.item_storages["object"].buffered_writers + assert len(writers) == 1 + for name, writer in writers.items(): + assert name.startswith(current_load_id) + assert writer._file is None + + +def test_extract_metrics_on_exception_without_flush(extract_step: Extract) -> None: + @dlt.resource + def letters(): + # extract 7 items + yield from "ABCDEFG" + # then fail + raise RuntimeError() + yield from "HI" + + # flush buffer + os.environ["DATA_WRITER__BUFFER_MAX_ITEMS"] = "4" + source = DltSource(dlt.Schema("letters"), "module", [letters]) + with pytest.raises(ResourceExtractionError): + extract_step.extract(source, 20, 1) + step_info = extract_step.get_step_info(MockPipeline("buba", first_run=False)) # type: ignore[abstract] + # one job created because the file was flushed + jobs = step_info.load_packages[0].jobs["new_jobs"] + # print(jobs[0].job_file_info.job_id()) + assert len(jobs) == 1 + current_load_id = step_info.loads_ids[-1] if len(step_info.loads_ids) > 0 else None + # 7 items were extracted + assert ( + step_info.metrics[current_load_id][0]["job_metrics"][ + jobs[0].job_file_info.job_id() + ].items_count + == 4 + ) + # get buffered writers + writers = extract_step.extract_storage.item_storages["object"].buffered_writers + assert len(writers) == 1 + for name, writer in writers.items(): + assert name.startswith(current_load_id) + assert writer._file is None + + +def test_extract_empty_metrics(extract_step: Extract) -> None: + step_info = extract_step.get_step_info(MockPipeline("buba", first_run=False)) # type: ignore[abstract] + assert step_info.load_packages == step_info.loads_ids == [] + + # def test_extract_pipe_from_unknown_resource(): # pass diff --git a/tests/libs/test_arrow_csv_writer.py b/tests/libs/test_arrow_csv_writer.py index 91038f01c4..85a15cc169 100644 --- a/tests/libs/test_arrow_csv_writer.py +++ b/tests/libs/test_arrow_csv_writer.py @@ -18,7 +18,7 @@ def test_csv_writer_all_data_fields() -> None: - data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) # write parquet and read it with get_writer(ParquetDataWriter) as pq_writer: @@ -94,7 +94,7 @@ def test_csv_writer_all_data_fields() -> None: def test_non_utf8_binary() -> None: - data = TABLE_ROW_ALL_DATA_TYPES_DATETIMES + data = copy(TABLE_ROW_ALL_DATA_TYPES_DATETIMES) data["col7"] += b"\x8e" # type: ignore[operator] # write parquet and read it @@ -110,7 +110,7 @@ def test_non_utf8_binary() -> None: def test_arrow_struct() -> None: - item, _ = arrow_table_all_data_types("table", include_json=True, include_time=False) + item, _, _ = arrow_table_all_data_types("table", include_json=True, include_time=False) with pytest.raises(InvalidDataItem): with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: writer.write_data_item(item, TABLE_UPDATE_COLUMNS_SCHEMA) diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index 3b4239f2b0..786617ef55 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -128,8 +128,7 @@ def test_parquet_writer_all_data_fields() -> None: assert actual == value assert table.schema.field("col1_precision").type == pa.int16() - # flavor=spark only writes ns precision timestamp, so this is expected - assert table.schema.field("col4_precision").type == pa.timestamp("ns") + assert table.schema.field("col4_precision").type == pa.timestamp("ms", tz="UTC") assert table.schema.field("col5_precision").type == pa.string() assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) assert table.schema.field("col7_precision").type == pa.binary(19) diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 2c649c18de..98f44b1c8a 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -52,7 +52,7 @@ def test_load_arrow_item( destination_config.destination == "databricks" and destination_config.file_format == "jsonl" ) - item, records = arrow_table_all_data_types( + item, records, _ = arrow_table_all_data_types( item_type, include_json=False, include_time=include_time, diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 6d33b477fc..8401f9d3af 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -109,7 +109,7 @@ def test_pipeline_csv_filesystem_destination() -> None: dataset_name="parquet_test_" + uniq_id(), ) - item, _ = arrow_table_all_data_types("table", include_json=False, include_time=True) + item, _, _ = arrow_table_all_data_types("table", include_json=False, include_time=True) info = pipeline.run(item, table_name="table", loader_file_format="csv") info.raise_on_failed_jobs() job = info.load_packages[0].jobs["completed_jobs"][0].file_path diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index bf57bb0c4e..50c14e9cda 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -65,12 +65,13 @@ def test_postgres_empty_csv_from_arrow(destination_config: DestinationTestConfig os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" os.environ["RESTORE_FROM_DESTINATION"] = "False" pipeline = destination_config.setup_pipeline("postgres_" + uniq_id(), full_refresh=True) - table, _ = arrow_table_all_data_types("table", include_json=False) + table, _, _ = arrow_table_all_data_types("table", include_json=False) load_info = pipeline.run( table.schema.empty_table(), table_name="table", loader_file_format="csv" ) assert_load_info(load_info) + assert len(load_info.load_packages[0].jobs["completed_jobs"]) == 1 job = load_info.load_packages[0].jobs["completed_jobs"][0].file_path assert job.endswith("csv") assert_data_table_counts(pipeline, {"table": 0}) diff --git a/tests/load/utils.py b/tests/load/utils.py index d8daf996e1..3972f3ad95 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -586,6 +586,7 @@ def write_dataset( for idx, row in enumerate(rows): rows[idx] = {k: v for k, v in row.items() if v is not None} writer.write_all(columns_schema, rows) + writer.close() def prepare_load_package( diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index ad31e6240e..91997a921e 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -472,7 +472,7 @@ def test_normalize_retry(raw_normalize: Normalize) -> None: schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_id) schema.set_schema_contract("freeze") raw_normalize.normalize_storage.extracted_packages.save_schema(load_id, schema) - # will fail on contract violatiom + # will fail on contract violation with pytest.raises(NormalizeJobFailed): raw_normalize.run(None) @@ -492,6 +492,20 @@ def test_normalize_retry(raw_normalize: Normalize) -> None: assert len(table_files["issues"]) == 1 +def test_collect_metrics_on_exception(raw_normalize: Normalize) -> None: + load_id = extract_cases(raw_normalize, ["github.issues.load_page_5_duck"]) + schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_id) + schema.set_schema_contract("freeze") + raw_normalize.normalize_storage.extracted_packages.save_schema(load_id, schema) + # will fail on contract violation + with pytest.raises(NormalizeJobFailed) as job_ex: + raw_normalize.run(None) + # we excepted on a first row so nothing was written + # TODO: improve this test to write some rows in buffered writer + assert len(job_ex.value.writer_metrics) == 0 + raw_normalize.get_step_info(MockPipeline("multiprocessing_pipeline", True)) # type: ignore[abstract] + + def test_group_worker_files() -> None: files = ["f%03d" % idx for idx in range(0, 100)] diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py index 96159648ea..b16da73868 100644 --- a/tests/pipeline/test_arrow_sources.py +++ b/tests/pipeline/test_arrow_sources.py @@ -9,7 +9,7 @@ import dlt from dlt.common import json, Decimal from dlt.common.utils import uniq_id -from dlt.common.libs.pyarrow import NameNormalizationClash +from dlt.common.libs.pyarrow import NameNormalizationClash, remove_columns, normalize_py_arrow_item from dlt.pipeline.exceptions import PipelineStepFailed @@ -35,7 +35,7 @@ ], ) def test_extract_and_normalize(item_type: TArrowFormat, is_list: bool): - item, records = arrow_table_all_data_types(item_type) + item, records, data = arrow_table_all_data_types(item_type) pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="filesystem") @@ -72,19 +72,25 @@ def some_data(): assert normalized_bytes == extracted_bytes f.seek(0) - pq = pa.parquet.ParquetFile(f) - tbl = pq.read() - - # To make tables comparable exactly write the expected data to parquet and read it back - # The spark parquet writer loses timezone info - tbl_expected = pa.Table.from_pandas(pd.DataFrame(records)) - with io.BytesIO() as f: - pa.parquet.write_table(tbl_expected, f, flavor="spark") - f.seek(0) - tbl_expected = pa.parquet.read_table(f) - df_tbl = tbl_expected.to_pandas(ignore_metadata=True) + with pa.parquet.ParquetFile(f) as pq: + tbl = pq.read() + + # use original data to create data frame to preserve timestamp precision, timezones etc. + tbl_expected = pa.Table.from_pandas(pd.DataFrame(data)) + # null is removed by dlt + tbl_expected = remove_columns(tbl_expected, ["null"]) + # we want to normalize column names + tbl_expected = normalize_py_arrow_item( + tbl_expected, + pipeline.default_schema.get_table_columns("some_data"), + pipeline.default_schema.naming, + None, + ) + assert tbl_expected.schema.equals(tbl.schema) + + df_tbl = tbl_expected.to_pandas(ignore_metadata=False) # Data is identical to the original dataframe - df_result = tbl.to_pandas(ignore_metadata=True) + df_result = tbl.to_pandas(ignore_metadata=False) assert df_result.equals(df_tbl) schema = pipeline.default_schema @@ -116,7 +122,7 @@ def some_data(): def test_normalize_jsonl(item_type: TArrowFormat, is_list: bool): os.environ["DUMMY__LOADER_FILE_FORMAT"] = "jsonl" - item, records = arrow_table_all_data_types(item_type) + item, records, _ = arrow_table_all_data_types(item_type, tz="Europe/Berlin") pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy") @@ -136,21 +142,18 @@ def some_data(): job = [j for j in jobs if "some_data" in j][0] with storage.normalized_packages.storage.open_file(job, "r") as f: result = [json.loads(line) for line in f] - for row in result: - row["decimal"] = Decimal(row["decimal"]) - - for record in records: - record["datetime"] = record["datetime"].replace(tzinfo=None) expected = json.loads(json.dumps(records)) - for record in expected: - record["decimal"] = Decimal(record["decimal"]) - assert result == expected + assert len(result) == len(expected) + for res_item, exp_item in zip(result, expected): + res_item["decimal"] = Decimal(res_item["decimal"]) + exp_item["decimal"] = Decimal(exp_item["decimal"]) + assert res_item == exp_item @pytest.mark.parametrize("item_type", ["table", "record_batch"]) def test_add_map(item_type: TArrowFormat): - item, records = arrow_table_all_data_types(item_type, num_rows=200) + item, _, _ = arrow_table_all_data_types(item_type, num_rows=200) @dlt.resource def some_data(): @@ -180,7 +183,7 @@ def test_extract_normalize_file_rotation(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") - item, rows = arrow_table_all_data_types(item_type) + item, rows, _ = arrow_table_all_data_types(item_type) @dlt.resource def data_frames(): @@ -209,7 +212,7 @@ def test_arrow_clashing_names(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="dummy") - item, _ = arrow_table_all_data_types(item_type, include_name_clash=True) + item, _, _ = arrow_table_all_data_types(item_type, include_name_clash=True) @dlt.resource def data_frames(): @@ -226,10 +229,10 @@ def test_load_arrow_vary_schema(item_type: TArrowFormat) -> None: pipeline_name = "arrow_" + uniq_id() pipeline = dlt.pipeline(pipeline_name=pipeline_name, destination="duckdb") - item, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) + item, _, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) pipeline.run(item, table_name="data").raise_on_failed_jobs() - item, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) + item, _, _ = arrow_table_all_data_types(item_type, include_not_normalized_name=False) # remove int column try: item = item.drop("int") @@ -245,7 +248,7 @@ def test_arrow_as_data_loading(item_type: TArrowFormat) -> None: os.environ["RESTORE_FROM_DESTINATION"] = "False" os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" - item, rows = arrow_table_all_data_types(item_type) + item, rows, _ = arrow_table_all_data_types(item_type) item_resource = dlt.resource(item, name="item") assert id(item) == id(list(item_resource)[0]) @@ -260,7 +263,7 @@ def test_arrow_as_data_loading(item_type: TArrowFormat) -> None: @pytest.mark.parametrize("item_type", ["table"]) # , "pandas", "record_batch" def test_normalize_with_dlt_columns(item_type: TArrowFormat): - item, records = arrow_table_all_data_types(item_type, num_rows=5432) + item, records, _ = arrow_table_all_data_types(item_type, num_rows=5432) os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "True" os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_ID"] = "True" # Test with buffer smaller than the number of batches to be written @@ -316,7 +319,7 @@ def some_data(): pipeline.run(item, table_name="some_data").raise_on_failed_jobs() # should be able to load arrow with a new column - item, records = arrow_table_all_data_types(item_type, num_rows=200) + item, records, _ = arrow_table_all_data_types(item_type, num_rows=200) item = item.append_column("static_int", [[0] * 200]) pipeline.run(item, table_name="some_data").raise_on_failed_jobs() @@ -475,7 +478,7 @@ def test_empty_arrow(item_type: TArrowFormat) -> None: os.environ["DESTINATION__LOADER_FILE_FORMAT"] = "parquet" # always return pandas - item, _ = arrow_table_all_data_types("pandas", num_rows=1) + item, _, _ = arrow_table_all_data_types("pandas", num_rows=1) item_resource = dlt.resource(item, name="items", write_disposition="replace") pipeline_name = "arrow_" + uniq_id() From f94e1e5ab750df6084bb6462a163e779e8d8d596 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 10:44:15 +0200 Subject: [PATCH 11/22] fixes empty tables writer tests and bugs --- dlt/common/data_writers/writers.py | 4 + tests/common/data_writers/utils.py | 3 + tests/libs/test_buffered_writer_arrow,py | 50 ---------- tests/libs/test_buffered_writers.py | 115 +++++++++++++++++++++++ tests/pipeline/test_pipeline_trace.py | 6 +- 5 files changed, 125 insertions(+), 53 deletions(-) delete mode 100644 tests/libs/test_buffered_writer_arrow,py create mode 100644 tests/libs/test_buffered_writers.py diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 9936a6844d..ae18fc03ab 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -197,6 +197,10 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: def write_data(self, rows: Sequence[Any]) -> None: super().write_data(rows) + # do not write empty rows, such things may be produced by Arrow adapters + if len(rows) == 0: + return + def write_row(row: StrAny, last_row: bool = False) -> None: output = ["NULL"] * len(self._headers_lookup) for n, v in row.items(): diff --git a/tests/common/data_writers/utils.py b/tests/common/data_writers/utils.py index 95e7b8ab64..3c584211e3 100644 --- a/tests/common/data_writers/utils.py +++ b/tests/common/data_writers/utils.py @@ -10,6 +10,9 @@ ALL_OBJECT_WRITERS = [ writer for writer in ALL_WRITERS if writer.writer_spec().data_item_format == "object" ] +ALL_ARROW_WRITERS = [ + writer for writer in ALL_WRITERS if writer.writer_spec().data_item_format == "arrow" +] def get_writer( diff --git a/tests/libs/test_buffered_writer_arrow,py b/tests/libs/test_buffered_writer_arrow,py deleted file mode 100644 index f0f0968942..0000000000 --- a/tests/libs/test_buffered_writer_arrow,py +++ /dev/null @@ -1,50 +0,0 @@ -import pytest - -from dlt.common.destination import TLoaderFileFormat -from dlt.common.schema.utils import new_column - -from tests.common.data_writers.utils import get_writer, ALL_WRITERS - - -@pytest.mark.parametrize("writer_format", ALL_WRITERS - {"arrow"}) -def test_writer_items_count(writer_format: TLoaderFileFormat) -> None: - c1 = {"col1": new_column("col1", "bigint")} - with get_writer(_format=writer_format) as writer: - assert writer._buffered_items_count == 0 - # single item - writer.write_data_item({"col1": 1}, columns=c1) - assert writer._buffered_items_count == 1 - # list - writer.write_data_item([{"col1": 1}, {"col1": 2}], columns=c1) - assert writer._buffered_items_count == 3 - writer._flush_items() - assert writer._buffered_items_count == 0 - assert writer._writer.items_count == 3 - - -def test_writer_items_count_arrow() -> None: - import pyarrow as pa - c1 = {"col1": new_column("col1", "bigint")} - with get_writer(_format="arrow") as writer: - assert writer._buffered_items_count == 0 - # single item - writer.write_data_item(pa.Table.from_pylist([{"col1": 1}]), columns=c1) - assert writer._buffered_items_count == 1 - # single item with many rows - writer.write_data_item(pa.Table.from_pylist([{"col1": 1}, {"col1": 2}]), columns=c1) - assert writer._buffered_items_count == 3 - # empty list - writer.write_data_item([], columns=c1) - assert writer._buffered_items_count == 3 - # list with one item - writer.write_data_item([pa.Table.from_pylist([{"col1": 1}])], columns=c1) - assert writer._buffered_items_count == 4 - # list with many items - writer.write_data_item( - [pa.Table.from_pylist([{"col1": 1}]), pa.Table.from_pylist([{"col1": 1}, {"col1": 2}])], - columns=c1 - ) - assert writer._buffered_items_count == 7 - writer._flush_items() - assert writer._buffered_items_count == 0 - assert writer._writer.items_count == 7 diff --git a/tests/libs/test_buffered_writers.py b/tests/libs/test_buffered_writers.py new file mode 100644 index 0000000000..728c86b81b --- /dev/null +++ b/tests/libs/test_buffered_writers.py @@ -0,0 +1,115 @@ +from typing import Type +import pytest + +from dlt.common.schema.utils import new_column +from dlt.common.data_writers import DataWriter + +from tests.common.data_writers.utils import get_writer, ALL_OBJECT_WRITERS, ALL_ARROW_WRITERS + + +@pytest.mark.parametrize("writer_type", ALL_OBJECT_WRITERS) +def test_writer_items_count(writer_type: Type[DataWriter]) -> None: + c1 = {"col1": new_column("col1", "bigint")} + with get_writer(writer_type) as writer: + assert writer._buffered_items_count == 0 + # write empty list + writer.write_data_item([], columns=c1) + assert writer._buffered_items_count == 0 + writer._flush_items(allow_empty_file=True) + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 0 + assert writer.closed_files[0].items_count == 0 + + # single item + with get_writer(writer_type) as writer: + assert writer._buffered_items_count == 0 + writer.write_data_item({"col1": 1}, columns=c1) + assert writer._buffered_items_count == 1 + writer._flush_items() + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 1 + assert writer.closed_files[0].items_count == 1 + + # list + with get_writer(writer_type) as writer: + assert writer._buffered_items_count == 0 + writer.write_data_item([{"col1": 1}, {"col1": 2}], columns=c1) + assert writer._buffered_items_count == 2 + writer._flush_items() + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 2 + assert writer.closed_files[0].items_count == 2 + + +@pytest.mark.parametrize("writer_type", ALL_ARROW_WRITERS) +def test_writer_items_count_arrow(writer_type: Type[DataWriter]) -> None: + import pyarrow as pa + + c1 = {"col1": new_column("col1", "bigint")} + single_elem_table = pa.Table.from_pylist([{"col1": 1}]) + + # empty frame + with get_writer(writer_type) as writer: + assert writer._buffered_items_count == 0 + writer.write_data_item(single_elem_table.schema.empty_table(), columns=c1) + assert writer._buffered_items_count == 0 + # there's an empty frame to be written + assert len(writer._buffered_items) == 1 + # we flush empty frame + writer._flush_items() + assert writer._buffered_items_count == 0 + # no items were written + assert writer._writer.items_count == 0 + # but file was created + assert writer.closed_files[0].items_count == 0 + + # single item + with get_writer(writer_type) as writer: + assert writer._buffered_items_count == 0 + writer.write_data_item(pa.Table.from_pylist([{"col1": 1}]), columns=c1) + assert writer._buffered_items_count == 1 + writer._flush_items() + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 1 + assert writer.closed_files[0].items_count == 1 + + # single item with many rows + with get_writer(writer_type) as writer: + writer.write_data_item(pa.Table.from_pylist([{"col1": 1}, {"col1": 2}]), columns=c1) + assert writer._buffered_items_count == 2 + writer._flush_items() + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 2 + assert writer.closed_files[0].items_count == 2 + + # empty list + with get_writer(writer_type) as writer: + writer.write_data_item([], columns=c1) + assert writer._buffered_items_count == 0 + writer._flush_items() + assert writer._buffered_items_count == 0 + # no file was created + assert writer._file is None + assert writer._writer is None + assert len(writer.closed_files) == 0 + + # list with one item + with get_writer(writer_type) as writer: + writer.write_data_item([pa.Table.from_pylist([{"col1": 1}])], columns=c1) + assert writer._buffered_items_count == 1 + writer._flush_items() + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 1 + assert writer.closed_files[0].items_count == 1 + + # list with many items + with get_writer(writer_type) as writer: + writer.write_data_item( + [pa.Table.from_pylist([{"col1": 1}]), pa.Table.from_pylist([{"col1": 1}, {"col1": 2}])], + columns=c1, + ) + assert writer._buffered_items_count == 3 + writer._flush_items() + assert writer._buffered_items_count == 0 + assert writer._writer.items_count == 3 + assert writer.closed_files[0].items_count == 3 diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index cec578cb7b..05268f09b3 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -160,10 +160,10 @@ def data(): load_id = extract_info.loads_ids[0] package = extract_info.load_packages[0] assert package.state == "new" - # no jobs + # no jobs - exceptions happened before save assert len(package.jobs["new_jobs"]) == 0 - # no metrics - exception happened first - assert len(extract_info.metrics[load_id]) == 0 + # metrics should be collected + assert len(extract_info.metrics[load_id]) == 1 # normalize norm_info = p.normalize() From 4b19c49cafae3625cebb0e2cf2d9e09c9049f4a4 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 11:53:26 +0200 Subject: [PATCH 12/22] fixes closing writers when exception during flush, missing tzdata on windows handling --- dlt/common/data_writers/buffered.py | 18 ++++++++++--- dlt/common/data_writers/writers.py | 9 +++++++ dlt/normalize/normalize.py | 42 ++++++++++++++++++++--------- tests/libs/test_arrow_csv_writer.py | 3 ++- 4 files changed, 55 insertions(+), 17 deletions(-) diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index e358919c7a..e61ee0edf1 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -166,9 +166,10 @@ def import_file(self, file_path: str, metrics: DataWriterMetrics) -> DataWriterM def close(self, skip_flush: bool = False) -> None: """Flushes the data, writes footer (skip_flush is True), collects metrics and closes the underlying file.""" - self._ensure_open() - self._flush_and_close_file(skip_flush=skip_flush) - self._closed = True + # like regular files, we do not except on double close + if not self._closed: + self._flush_and_close_file(skip_flush=skip_flush) + self._closed = True @property def closed(self) -> bool: @@ -179,7 +180,16 @@ def __enter__(self) -> "BufferedDataWriter[TWriter]": def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: Any) -> None: # skip flush if we had exception - self.close(skip_flush=exc_val is not None) + in_exception = exc_val is not None + try: + self.close(skip_flush=in_exception) + except Exception: + if not in_exception: + # close again but without flush + self.close(skip_flush=True) + # raise the on close exception if we are not handling another exception + if not in_exception: + raise def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics: metrics = self._flush_and_close_file(allow_empty_file) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index ae18fc03ab..786833fb68 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -487,6 +487,15 @@ def write_data(self, rows: Sequence[Any]) -> None: " characters. Remove binary columns or replace their content with a hex" " representation: \\x... while keeping data type as binary.", ) + if "Timezone database not found" in str(inv_ex): + raise InvalidDataItem( + "csv", + "arrow", + str(inv_ex) + + ". Arrow does not ship with tzdata on Windows. You need to install it" + " yourself:" + " https://arrow.apache.org/docs/cpp/build_system.html#runtime-dependencies", + ) raise else: raise ValueError(f"Unsupported type {type(row)}") diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 47d0cd9898..0125d5a525 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -142,17 +142,35 @@ def _get_items_normalizer(item_format: TDataItemFormat) -> ItemsNormalizer: ) return norm - def _gather_metrics_and_close(skip_flush: bool) -> List[DataWriterMetrics]: - for normalizer in item_normalizers.values(): - normalizer.item_storage.close_writers(load_id, skip_flush=skip_flush) - + def _gather_metrics_and_close( + parsed_fn: ParsedLoadJobFileName, in_exception: bool + ) -> List[DataWriterMetrics]: writer_metrics: List[DataWriterMetrics] = [] - for normalizer in item_normalizers.values(): - norm_metrics = normalizer.item_storage.closed_files(load_id) - writer_metrics.extend(norm_metrics) - - for normalizer in item_normalizers.values(): - normalizer.item_storage.remove_closed_files(load_id) + try: + try: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=in_exception) + except Exception: + # if we had exception during flushing the writers, close them without flushing + if not in_exception: + for normalizer in item_normalizers.values(): + normalizer.item_storage.close_writers(load_id, skip_flush=True) + raise + finally: + # always gather metrics + for normalizer in item_normalizers.values(): + norm_metrics = normalizer.item_storage.closed_files(load_id) + writer_metrics.extend(norm_metrics) + for normalizer in item_normalizers.values(): + normalizer.item_storage.remove_closed_files(load_id) + except Exception as exc: + if in_exception: + # swallow exception if we already handle exceptions + return writer_metrics + else: + # enclose the exception during the closing in job failed exception + job_id = parsed_fn.job_id() if parsed_fn else "" + raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) return writer_metrics parsed_file_name: ParsedLoadJobFileName = None @@ -178,10 +196,10 @@ def _gather_metrics_and_close(skip_flush: bool) -> List[DataWriterMetrics]: logger.debug(f"Processed file {extracted_items_file}") except Exception as exc: job_id = parsed_file_name.job_id() if parsed_file_name else "" - writer_metrics = _gather_metrics_and_close(skip_flush=True) + writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=True) raise NormalizeJobFailed(load_id, job_id, str(exc), writer_metrics) from exc else: - writer_metrics = _gather_metrics_and_close(skip_flush=False) + writer_metrics = _gather_metrics_and_close(parsed_file_name, in_exception=False) logger.info(f"Processed all items in {len(extracted_items_files)} files") return TWorkerRV(schema_updates, writer_metrics) diff --git a/tests/libs/test_arrow_csv_writer.py b/tests/libs/test_arrow_csv_writer.py index 85a15cc169..507fe4e946 100644 --- a/tests/libs/test_arrow_csv_writer.py +++ b/tests/libs/test_arrow_csv_writer.py @@ -104,9 +104,10 @@ def test_non_utf8_binary() -> None: with open(pq_writer.closed_files[0].file_path, "rb") as f: table = pq.read_table(f) - with pytest.raises(InvalidDataItem): + with pytest.raises(InvalidDataItem) as inv_ex: with get_writer(ArrowToCsvWriter, disable_compression=True) as writer: writer.write_data_item(table, TABLE_UPDATE_COLUMNS_SCHEMA) + assert "Arrow data contains string or binary columns" in str(inv_ex.value) def test_arrow_struct() -> None: From 76403c184793a584ef79fafd27c2cf1ade246d4a Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 12:48:49 +0200 Subject: [PATCH 13/22] installs tzdata on windows ci --- .github/workflows/test_common.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index 2d96d2eb95..c1dfd63004 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -111,6 +111,16 @@ jobs: - name: Install pipeline dependencies run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk --with pipeline + - name: Install tzdata on windows + run: | + curl https://data.iana.org/time-zones/releases/tzdata2021e.tar.gz --output tzdata.tar.gz + mkdir tzdata + tar --extract --file tzdata.tar.gz --directory tzdata + move tzdata %USERPROFILE%\Downloads\tzdata + curl https://raw.githubusercontent.com/unicode-org/cldr/master/common/supplemental/windowsZones.xml --output %USERPROFILE%\Downloads\tzdata\windowsZones.xml + if: runner.os == 'Windows' + shell: cmd + - run: | poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations if: runner.os != 'Windows' From 0adb113be705b4f623c8b0a53bfca748f54cb970 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Fri, 5 Apr 2024 13:03:31 +0200 Subject: [PATCH 14/22] adds csv to docs index --- .github/workflows/test_common.yml | 24 ++++++++++++++---------- docs/website/sidebars.js | 1 + 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index c1dfd63004..5561811d73 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -52,6 +52,20 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install tzdata on windows + run: | + cd %USERPROFILE% + curl https://data.iana.org/time-zones/releases/tzdata2021e.tar.gz --output tzdata.tar.gz + mkdir tzdata + tar --extract --file tzdata.tar.gz --directory tzdata + dir %USERPROFILE%\Downloads\ + mkdir %USERPROFILE%\Downloads\tzdata + dir %USERPROFILE%\Downloads\ + copy tzdata %USERPROFILE%\Downloads\tzdata + curl https://raw.githubusercontent.com/unicode-org/cldr/master/common/supplemental/windowsZones.xml --output %USERPROFILE%\Downloads\tzdata\windowsZones.xml + if: runner.os == 'Windows' + shell: cmd + - name: Install Poetry # https://github.com/snok/install-poetry#running-on-windows uses: snok/install-poetry@v1.3.2 @@ -111,16 +125,6 @@ jobs: - name: Install pipeline dependencies run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk --with pipeline - - name: Install tzdata on windows - run: | - curl https://data.iana.org/time-zones/releases/tzdata2021e.tar.gz --output tzdata.tar.gz - mkdir tzdata - tar --extract --file tzdata.tar.gz --directory tzdata - move tzdata %USERPROFILE%\Downloads\tzdata - curl https://raw.githubusercontent.com/unicode-org/cldr/master/common/supplemental/windowsZones.xml --output %USERPROFILE%\Downloads\tzdata\windowsZones.xml - if: runner.os == 'Windows' - shell: cmd - - run: | poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations if: runner.os != 'Windows' diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index bc8d16d05a..9776de0818 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -170,6 +170,7 @@ const sidebars = { items: [ 'dlt-ecosystem/file-formats/jsonl', 'dlt-ecosystem/file-formats/parquet', + 'dlt-ecosystem/file-formats/csv', 'dlt-ecosystem/file-formats/insert-format', ] }, From 7eb6ca84244158b57fb0b86e22658903562eae82 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 6 Apr 2024 00:44:36 +0200 Subject: [PATCH 15/22] fixes athena sql job client tests setup --- tests/load/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/load/utils.py b/tests/load/utils.py index 3972f3ad95..17dd6a24fd 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -504,11 +504,10 @@ def yield_client( # athena requires staging config to be present, so stick this in there here if destination_type == "athena": staging_config = DestinationClientStagingConfiguration( - destination_type="fake-stage", # type: ignore - dataset_name=dest_config.dataset_name, - default_schema_name=dest_config.default_schema_name, bucket_url=AWS_BUCKET, - ) + )._bind_dataset_name(dataset_name=dest_config.dataset_name) + staging_config.destination_type = "filesystem" # type: ignore[misc] + staging_config.resolve() dest_config.staging_config = staging_config # type: ignore[attr-defined] # lookup for credentials in the section that is destination name From d30d17b1964b86bf61aaec25bda4333d71ef797e Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 6 Apr 2024 00:45:33 +0200 Subject: [PATCH 16/22] adjusts for timezone for the preferred precision, all other precision use timestamp w/o tz --- .github/workflows/test_common.yml | 2 - dlt/common/data_writers/writers.py | 15 +++- dlt/common/libs/pyarrow.py | 19 +++-- dlt/common/schema/utils.py | 6 ++ .../dlt-ecosystem/file-formats/parquet.md | 14 +++- tests/common/data_writers/utils.py | 3 +- tests/libs/test_parquet_writer.py | 79 +++++++++++++++++-- .../load/pipeline/test_filesystem_pipeline.py | 1 + 8 files changed, 120 insertions(+), 19 deletions(-) diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index 5561811d73..dab9e92a6f 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -58,9 +58,7 @@ jobs: curl https://data.iana.org/time-zones/releases/tzdata2021e.tar.gz --output tzdata.tar.gz mkdir tzdata tar --extract --file tzdata.tar.gz --directory tzdata - dir %USERPROFILE%\Downloads\ mkdir %USERPROFILE%\Downloads\tzdata - dir %USERPROFILE%\Downloads\ copy tzdata %USERPROFILE%\Downloads\tzdata curl https://raw.githubusercontent.com/unicode-org/cldr/master/common/supplemental/windowsZones.xml --output %USERPROFILE%\Downloads\tzdata\windowsZones.xml if: runner.os == 'Windows' diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 786833fb68..b3ece70c8c 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -254,6 +254,8 @@ class ParquetDataWriterConfiguration(BaseConfiguration): timestamp_precision: str = "us" timestamp_timezone: str = "UTC" row_group_size: Optional[int] = None + coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None + allow_truncated_timestamps: bool = False __section__: ClassVar[str] = known_sections.DATA_WRITER @@ -270,6 +272,8 @@ def __init__( data_page_size: Optional[int] = None, timestamp_timezone: str = "UTC", row_group_size: Optional[int] = None, + coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None, + allow_truncated_timestamps: bool = False, ) -> None: super().__init__(f, caps) from dlt.common.libs.pyarrow import pyarrow @@ -282,6 +286,8 @@ def __init__( self.parquet_data_page_size = data_page_size self.timestamp_timezone = timestamp_timezone self.parquet_row_group_size = row_group_size + self.coerce_timestamps = coerce_timestamps + self.allow_truncated_timestamps = allow_truncated_timestamps def _create_writer(self, schema: "pa.Schema") -> "pa.parquet.ParquetWriter": from dlt.common.libs.pyarrow import pyarrow @@ -292,6 +298,8 @@ def _create_writer(self, schema: "pa.Schema") -> "pa.parquet.ParquetWriter": flavor=self.parquet_flavor, version=self.parquet_version, data_page_size=self.parquet_data_page_size, + coerce_timestamps=self.coerce_timestamps, + allow_truncated_timestamps=self.allow_truncated_timestamps, ) def write_header(self, columns_schema: TTableSchemaColumns) -> None: @@ -302,7 +310,12 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: [ pyarrow.field( name, - get_py_arrow_datatype(schema_item, self._caps, self.timestamp_timezone), + get_py_arrow_datatype( + schema_item, + self._caps, + self.timestamp_timezone, + self.coerce_timestamps or "us", + ), nullable=schema_item.get("nullable", True), ) for name, schema_item in columns_schema.items() diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index cb19c8c00a..c3991b9a5b 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -39,7 +39,10 @@ def get_py_arrow_datatype( - column: TColumnType, caps: DestinationCapabilitiesContext, tz: str + column: TColumnType, + caps: DestinationCapabilitiesContext, + tz: str, + default_ts_precision: str = "us", ) -> Any: column_type = column["data_type"] if column_type == "text": @@ -49,7 +52,9 @@ def get_py_arrow_datatype( elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) + return get_py_arrow_timestamp( + column.get("precision") or caps.timestamp_precision, tz, default_ts_precision + ) elif column_type == "bigint": return get_pyarrow_int(column.get("precision")) elif column_type == "binary": @@ -75,14 +80,14 @@ def get_py_arrow_datatype( raise ValueError(column_type) -def get_py_arrow_timestamp(precision: int, tz: str) -> Any: +def get_py_arrow_timestamp(precision: int, tz: str, default_ts_precision: str = "us") -> Any: if precision == 0: - return pyarrow.timestamp("s", tz=tz) + return pyarrow.timestamp("s", tz=tz if default_ts_precision == "s" else None) if precision <= 3: - return pyarrow.timestamp("ms", tz=tz) + return pyarrow.timestamp("ms", tz=tz if default_ts_precision == "ms" else None) if precision <= 6: - return pyarrow.timestamp("us", tz=tz) - return pyarrow.timestamp("ns", tz=tz) + return pyarrow.timestamp("us", tz=tz if default_ts_precision == "us" else None) + return pyarrow.timestamp("ns", tz=tz if default_ts_precision == "ns" else None) def get_py_arrow_time(precision: int) -> Any: diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 0a4e00759d..4c1071a8a9 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -713,11 +713,17 @@ def new_column( column_name: str, data_type: TDataType = None, nullable: bool = True, + precision: int = None, + scale: int = None, validate_schema: bool = False, ) -> TColumnSchema: column: TColumnSchema = {"name": column_name, "nullable": nullable} if data_type: column["data_type"] = data_type + if precision is not None: + column["precision"] = precision + if scale is not None: + column["scale"] = scale if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 6be4f8799e..9514f3e5c9 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -29,14 +29,21 @@ info = pipeline.run(some_source(), loader_file_format="parquet") * It uses decimal and wei precision to pick the right **decimal type** and sets precision and scale. * It uses timestamp precision to pick the right **timestamp type** resolution (seconds, micro, or nano). -## Options +## Writer settings Under the hood, `dlt` uses the [pyarrow parquet writer](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html) to create the files. The following options can be used to change the behavior of the writer: - `flavor`: Sanitize schema or set other compatibility options to work with various target systems. Defaults to None which is **pyarrow** default. -- `version`: Determine which Parquet logical types are available for use, whether the reduced set from the Parquet 1.x.x format or the expanded logical types added in later format versions. Defaults to "2.4". +- `version`: Determine which Parquet logical types are available for use, whether the reduced set from the Parquet 1.x.x format or the expanded logical types added in later format versions. Defaults to "2.6". - `data_page_size`: Set a target threshold for the approximate encoded size of data pages within a column chunk (in bytes). Defaults to None which is **pyarrow** default. - `timestamp_timezone`: A string specifying timezone, default is UTC. +- `coerce_timestamps`: resolution to which coerce timestamps, choose from **s**, **ms**, **us**, **ns** +- `allow_truncated_timestamps` - will raise if precision is lost on truncated timestamp. + +:::tip +Default parquet version used by `dlt` is 2.4. It coerces timestamps to microseconds and truncates nanoseconds silently. Such setting +provides best interoperability with database systems, including loading panda frames which have nanosecond resolution by default +::: Read the [pyarrow parquet docs](https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html) to learn more about these settings. @@ -59,3 +66,6 @@ NORMALIZE__DATA_WRITER__VERSION NORMALIZE__DATA_WRITER__DATA_PAGE_SIZE NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE ``` + +### Timestamps and timezones +`dlt` adds timezone (UTC adjustment) to timestamps with microseconds precision. All other timestamp (second, millisecond, nanosecond) are stored without the adjustment. diff --git a/tests/common/data_writers/utils.py b/tests/common/data_writers/utils.py index 3c584211e3..2cb440bde1 100644 --- a/tests/common/data_writers/utils.py +++ b/tests/common/data_writers/utils.py @@ -21,8 +21,9 @@ def get_writer( file_max_items: int = 10, file_max_bytes: int = None, disable_compression: bool = False, + caps: DestinationCapabilitiesContext = None, ) -> BufferedDataWriter[TWriter]: - caps = DestinationCapabilitiesContext.generic_capabilities() + caps = caps or DestinationCapabilitiesContext.generic_capabilities() writer_spec = writer.writer_spec() caps.preferred_loader_file_format = writer_spec.file_format file_template = os.path.join(TEST_STORAGE_ROOT, f"{writer_spec.file_format}.%s") diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index 786617ef55..6c7e1383c5 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -2,14 +2,16 @@ import pyarrow as pa import pyarrow.parquet as pq import datetime # noqa: 251 +import time from dlt.common import pendulum, Decimal, json from dlt.common.configuration import inject_section from dlt.common.data_writers.writers import ParquetDataWriter -from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema.utils import new_column from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.time import ensure_pendulum_date, ensure_pendulum_datetime +from dlt.common.time import ensure_pendulum_datetime +from dlt.common.libs.pyarrow import from_arrow_scalar from tests.common.data_writers.utils import get_writer from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES_DATETIMES @@ -128,7 +130,7 @@ def test_parquet_writer_all_data_fields() -> None: assert actual == value assert table.schema.field("col1_precision").type == pa.int16() - assert table.schema.field("col4_precision").type == pa.timestamp("ms", tz="UTC") + assert table.schema.field("col4_precision").type == pa.timestamp("ms") assert table.schema.field("col5_precision").type == pa.string() assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) assert table.schema.field("col7_precision").type == pa.binary(19) @@ -188,15 +190,49 @@ def test_parquet_writer_config() -> None: # tz can column_type = writer._writer.schema.field("col2").type assert column_type.tz == "America/New York" + # read parquet back and check + with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader: + # parquet schema is utc adjusted + col2_info = json.loads(reader.metadata.schema.column(1).logical_type.to_json()) + assert col2_info["isAdjustedToUTC"] is True + assert col2_info["timeUnit"] == "microseconds" + assert reader.schema_arrow.field(1).type.tz == "America/New York" + + +def test_parquet_writer_config_spark() -> None: + os.environ["NORMALIZE__DATA_WRITER__FLAVOR"] = "spark" + os.environ["NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE"] = "Europe/Berlin" + + now = pendulum.now(tz="Europe/Berlin") + with inject_section(ConfigSectionContext(pipeline_name=None, sections=("normalize",))): + with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: + for i in range(0, 5): + writer.write_data_item( + [{"col1": i, "col2": now}], + {"col1": new_column("col1", "bigint"), "col2": new_column("col2", "timestamp")}, + ) + # force the parquet writer to be created + writer._flush_items() + with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader: + # no logical type for timestamp + col2_info = json.loads(reader.metadata.schema.column(1).logical_type.to_json()) + assert col2_info == {"Type": "None"} + table = reader.read() + # when compared as naive UTC adjusted timestamps it works + assert table.column(1)[0].as_py() == now.in_timezone(tz="UTC").replace(tzinfo=None) def test_parquet_writer_schema_from_caps() -> None: + # store nanoseconds + os.environ["DATA_WRITER__VERSION"] = "2.6" caps = DestinationCapabilitiesContext.generic_capabilities() caps.decimal_precision = (18, 9) caps.wei_precision = (156, 78) # will be trimmed to dec256 caps.timestamp_precision = 9 # nanoseconds - with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: + with get_writer( + ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2, caps=caps + ) as writer: for _ in range(0, 5): writer.write_data_item( [{"col1": Decimal("2617.27"), "col2": pendulum.now(), "col3": Decimal(2**250)}], @@ -210,13 +246,44 @@ def test_parquet_writer_schema_from_caps() -> None: writer._flush_items() column_type = writer._writer.schema.field("col2").type - assert column_type.tz == "UTC" + assert column_type == pa.timestamp("ns") + assert column_type.tz is None column_type = writer._writer.schema.field("col1").type assert isinstance(column_type, pa.Decimal128Type) - assert column_type.precision == 38 + assert column_type.precision == 18 assert column_type.scale == 9 column_type = writer._writer.schema.field("col3").type assert isinstance(column_type, pa.Decimal256Type) # got scaled down to maximum assert column_type.precision == 76 assert column_type.scale == 0 + + with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader: + col2_info = json.loads(reader.metadata.schema.column(1).logical_type.to_json()) + assert col2_info["isAdjustedToUTC"] is False + assert col2_info["timeUnit"] == "nanoseconds" + + +def test_parquet_writer_timestamp_precision() -> None: + now = pendulum.now() + now_ns = time.time_ns() + + # store nanoseconds + os.environ["DATA_WRITER__VERSION"] = "2.6" + + with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: + for _ in range(0, 5): + writer.write_data_item( + [{"col1": now, "col2": now, "col3": now, "col4": now_ns}], + { + "col1": new_column("col1", "timestamp", precision=0), + "col2": new_column("col2", "timestamp", precision=3), + "col3": new_column("col2", "timestamp", precision=6), + "col4": new_column("col2", "timestamp", precision=9), + }, + ) + # force the parquet writer to be created + writer._flush_items() + print(writer._writer.schema) + with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader: + print(reader.metadata.schema) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 8401f9d3af..a0af885484 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -101,6 +101,7 @@ def some_source(): def test_pipeline_csv_filesystem_destination() -> None: os.environ["DATA_WRITER__DISABLE_COMPRESSION"] = "True" + os.environ["RESTORE_FROM_DESTINATION"] = "False" # store locally os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "file://_storage" pipeline = dlt.pipeline( From cb6dc6f53f5e3ff6fffc28aeab688aade4da7aad Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 7 Apr 2024 00:28:34 +0200 Subject: [PATCH 17/22] generates create table statements for synapse outside of a job --- dlt/destinations/impl/synapse/synapse.py | 24 ++++++++++++++---------- dlt/destinations/sql_jobs.py | 7 ++++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 457e128ba0..bb5c045a89 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -127,7 +127,19 @@ def _create_replace_followup_jobs( self, table_chain: Sequence[TTableSchema] ) -> List[NewLoadJob]: if self.config.replace_strategy == "staging-optimized": - return [SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client)] + # we must recreate staging table after SCHEMA TRANSFER + job_params: SqlJobParams = {"table_chain_create_table_statements": {}} + create_statements = job_params["table_chain_create_table_statements"] + with self.with_staging_dataset(): + for table in table_chain: + columns = [c for c in self.schema.get_table_columns(table["name"]).values()] + # generate CREATE TABLE statement + create_statements[table["name"]] = self._get_table_update_sql( + table["name"], columns, generate_alter=False + ) + return [ + SynapseStagingCopyJob.from_table_chain(table_chain, self.sql_client, job_params) + ] return super()._create_replace_followup_jobs(table_chain) def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: @@ -194,15 +206,7 @@ def generate_sql( f" {staging_table_name};" ) # recreate staging table - job_client = current.pipeline().destination_client() # type: ignore[operator] - with job_client.with_staging_dataset(): - # get table columns from schema - columns = [c for c in job_client.schema.get_table_columns(table["name"]).values()] - # generate CREATE TABLE statement - create_table_stmt = job_client._get_table_update_sql( - table["name"], columns, generate_alter=False - ) - sql.extend(create_table_stmt) + sql.extend(params["table_chain_create_table_statements"][table["name"]]) return sql diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 91be3a60c9..2f09414ef1 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional +from typing import Any, Dict, List, Sequence, Tuple, cast, TypedDict, Optional import yaml from dlt.common.logger import pretty_format_exception @@ -17,8 +17,9 @@ from dlt.destinations.sql_client import SqlClientBase -class SqlJobParams(TypedDict): +class SqlJobParams(TypedDict, total=False): replace: Optional[bool] + table_chain_create_table_statements: Dict[str, Sequence[str]] DEFAULTS: SqlJobParams = {"replace": False} @@ -40,7 +41,7 @@ def from_table_chain( The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). """ - params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) # type: ignore + params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) top_table = table_chain[0] file_info = ParsedLoadJobFileName( top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" From 60f20c9a963037a49eecda17e364b6796150d50d Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 7 Apr 2024 00:28:56 +0200 Subject: [PATCH 18/22] fixes athena table undefinded detection --- dlt/destinations/impl/athena/athena.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index b323832418..1beb249386 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -31,6 +31,7 @@ ) from dlt.common import logger +from dlt.common.exceptions import TerminalValueError from dlt.common.utils import without_none from dlt.common.data_types import TDataType from dlt.common.schema import TColumnSchema, Schema, TSchemaTables, TTableSchema @@ -108,7 +109,11 @@ def to_db_integer_type( return "int" if table_format == "iceberg" else "smallint" elif precision <= 32: return "int" - return "bigint" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into athena integer type" + ) def from_db_type( self, db_type: str, precision: Optional[int], scale: Optional[int] @@ -236,7 +241,7 @@ def _make_database_exception(ex: Exception) -> Exception: return DatabaseUndefinedRelation(ex) elif "SCHEMA_NOT_FOUND" in str(ex): return DatabaseUndefinedRelation(ex) - elif "Table not found" in str(ex): + elif "Table" in str(ex) and " not found" in str(ex): return DatabaseUndefinedRelation(ex) elif "Database does not exist" in str(ex): return DatabaseUndefinedRelation(ex) From 04e46e9b8c4722be6f9bbf79c5ad094a207f9db9 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 7 Apr 2024 00:29:29 +0200 Subject: [PATCH 19/22] generates all timestamps with timezones in parquet. tests workarounds in duckdb --- dlt/common/data_writers/writers.py | 2 - dlt/common/libs/pyarrow.py | 16 ++-- .../impl/databricks/databricks.py | 7 +- dlt/destinations/impl/duckdb/duck.py | 9 +- dlt/destinations/impl/mssql/mssql.py | 7 +- dlt/destinations/impl/postgres/postgres.py | 7 +- dlt/destinations/impl/redshift/redshift.py | 7 +- dlt/destinations/impl/snowflake/snowflake.py | 2 +- .../dlt-ecosystem/file-formats/parquet.md | 10 ++- tests/cases.py | 23 +++++ tests/libs/conftest.py | 1 + tests/libs/test_arrow_csv_writer.py | 1 - tests/libs/test_parquet_writer.py | 54 +++++++++--- .../load/duckdb/test_duckdb_table_builder.py | 30 ++++++- .../load/filesystem/test_filesystem_common.py | 2 +- tests/load/pipeline/test_duckdb.py | 85 ++++++++++++++++++- tests/load/pipeline/test_pipelines.py | 5 ++ .../postgres/test_postgres_table_builder.py | 25 +++++- 18 files changed, 250 insertions(+), 43 deletions(-) create mode 100644 tests/libs/conftest.py diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index b3ece70c8c..038d90d22e 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -251,7 +251,6 @@ class ParquetDataWriterConfiguration(BaseConfiguration): flavor: Optional[str] = None # could be ie. "spark" version: Optional[str] = "2.4" data_page_size: Optional[int] = None - timestamp_precision: str = "us" timestamp_timezone: str = "UTC" row_group_size: Optional[int] = None coerce_timestamps: Optional[Literal["s", "ms", "us", "ns"]] = None @@ -314,7 +313,6 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: schema_item, self._caps, self.timestamp_timezone, - self.coerce_timestamps or "us", ), nullable=schema_item.get("nullable", True), ) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index c3991b9a5b..3380157600 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -42,7 +42,6 @@ def get_py_arrow_datatype( column: TColumnType, caps: DestinationCapabilitiesContext, tz: str, - default_ts_precision: str = "us", ) -> Any: column_type = column["data_type"] if column_type == "text": @@ -52,9 +51,7 @@ def get_py_arrow_datatype( elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp( - column.get("precision") or caps.timestamp_precision, tz, default_ts_precision - ) + return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) elif column_type == "bigint": return get_pyarrow_int(column.get("precision")) elif column_type == "binary": @@ -80,14 +77,15 @@ def get_py_arrow_datatype( raise ValueError(column_type) -def get_py_arrow_timestamp(precision: int, tz: str, default_ts_precision: str = "us") -> Any: +def get_py_arrow_timestamp(precision: int, tz: str) -> Any: + tz = tz if tz else None if precision == 0: - return pyarrow.timestamp("s", tz=tz if default_ts_precision == "s" else None) + return pyarrow.timestamp("s", tz=tz) if precision <= 3: - return pyarrow.timestamp("ms", tz=tz if default_ts_precision == "ms" else None) + return pyarrow.timestamp("ms", tz=tz) if precision <= 6: - return pyarrow.timestamp("us", tz=tz if default_ts_precision == "us" else None) - return pyarrow.timestamp("ns", tz=tz if default_ts_precision == "ns" else None) + return pyarrow.timestamp("us", tz=tz) + return pyarrow.timestamp("ns", tz=tz) def get_py_arrow_time(precision: int) -> Any: diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 07e827cd28..0008599349 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -16,6 +16,7 @@ AzureCredentialsWithoutDefaults, ) from dlt.common.data_types import TDataType +from dlt.common.exceptions import TerminalValueError from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat @@ -81,7 +82,11 @@ def to_db_integer_type( return "SMALLINT" if precision <= 32: return "INT" - return "BIGINT" + if precision <= 64: + return "BIGINT" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into databricks integer type" + ) def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 735a4ce7e3..7016e9bfff 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -81,7 +81,11 @@ def to_db_integer_type( return "INTEGER" elif precision <= 64: return "BIGINT" - return "HUGEINT" + elif precision <= 128: + return "HUGEINT" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into duckdb integer type" + ) def to_db_datetime_type( self, precision: Optional[int], table_format: TTableFormat = None @@ -95,7 +99,8 @@ def to_db_datetime_type( if precision == 9: return "TIMESTAMP_NS" raise TerminalValueError( - f"timestamp {precision} cannot be mapped into duckdb TIMESTAMP typ" + f"timestamp with {precision} decimals after seconds cannot be mapped into duckdb" + " TIMESTAMP type" ) def from_db_type( diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index b6af345e36..85e2b9e475 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,5 +1,6 @@ from typing import ClassVar, Dict, Optional, Sequence, List, Any, Tuple +from dlt.common.exceptions import TerminalValueError from dlt.common.wei import EVM_DECIMAL_PRECISION from dlt.common.destination.reference import NewLoadJob from dlt.common.destination import DestinationCapabilitiesContext @@ -73,7 +74,11 @@ def to_db_integer_type( return "smallint" if precision <= 32: return "int" - return "bigint" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into mssql integer type" + ) def from_db_type( self, db_type: str, precision: Optional[int], scale: Optional[int] diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index b585967196..29c0f1b7e2 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -2,6 +2,7 @@ from dlt.common.destination.reference import FollowupJob, LoadJob, NewLoadJob, TLoadJobState from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage @@ -64,7 +65,11 @@ def to_db_integer_type( return "smallint" elif precision <= 32: return "integer" - return "bigint" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into postgres integer type" + ) def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index eaa1968133..672fceb7b2 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -1,6 +1,7 @@ import platform import os +from dlt.common.exceptions import TerminalValueError from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision @@ -92,7 +93,11 @@ def to_db_integer_type( return "smallint" elif precision <= 32: return "integer" - return "bigint" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into postgres integer type" + ) def from_db_type( self, db_type: str, precision: Optional[int], scale: Optional[int] diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 7fafbf83b7..70377de709 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -166,7 +166,7 @@ def __init__( # decide on source format, stage_file_path will either be a local file or a bucket path source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" if file_name.endswith("parquet"): - source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE)" + source_format = "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" with client.begin_transaction(): # PUT and COPY in one tx if local file, otherwise only copy diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 9514f3e5c9..76600571f0 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -68,4 +68,12 @@ NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE ``` ### Timestamps and timezones -`dlt` adds timezone (UTC adjustment) to timestamps with microseconds precision. All other timestamp (second, millisecond, nanosecond) are stored without the adjustment. +`dlt` adds timezone (UTC adjustment) to all timestamps regardless of a precision (from seconds to nanoseconds). `dlt` will also create TZ aware timestamp columns in +the destinations. If the latter is impossible, there are workaround + +### Disable timezones / utc adjustment flags +You can generate parquet files without timezone adjustment information in two ways: +1. Set the **flavor** to spark. All timestamps will be generated via deprecated `int96` physical data type, without the logical one +2. Set the **timestamp_timezone** to empty string (ie. `DATA_WRITER__TIMESTAMP_TIMEZONE=""`) to generate logical type without UTC adjustment. + +To our best knowledge, arrow will convert your timezone aware DateTime(s) to UTC and store them in parquet without timezone information. diff --git a/tests/cases.py b/tests/cases.py index b598f1169e..8885df0c1b 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -8,6 +8,7 @@ from dlt.common import Decimal, pendulum, json from dlt.common.data_types import TDataType +from dlt.common.schema.utils import new_column from dlt.common.typing import StrAny, TDataItems from dlt.common.wei import Wei from dlt.common.time import ( @@ -150,6 +151,28 @@ TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col11_precision"] = pendulum.Time.fromisoformat(TABLE_ROW_ALL_DATA_TYPES_DATETIMES["col11_precision"]) # type: ignore[arg-type] +TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS = [ + new_column("col1_ts", "timestamp", precision=0), + new_column("col2_ts", "timestamp", precision=3), + new_column("col3_ts", "timestamp", precision=6), + new_column("col4_ts", "timestamp", precision=9), +] +TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS_COLUMNS: TTableSchemaColumns = { + c["name"]: c for c in TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS +} + +TABLE_UPDATE_ALL_INT_PRECISIONS = [ + new_column("col1_int", "bigint", precision=8), + new_column("col2_int", "bigint", precision=16), + new_column("col3_int", "bigint", precision=32), + new_column("col4_int", "bigint", precision=64), + new_column("col5_int", "bigint", precision=128), +] +TABLE_UPDATE_ALL_INT_PRECISIONS_COLUMNS: TTableSchemaColumns = { + c["name"]: c for c in TABLE_UPDATE_ALL_INT_PRECISIONS +} + + def table_update_and_row( exclude_types: Sequence[TDataType] = None, exclude_columns: Sequence[str] = None ) -> Tuple[TTableSchemaColumns, StrAny]: diff --git a/tests/libs/conftest.py b/tests/libs/conftest.py new file mode 100644 index 0000000000..0248c1dfaa --- /dev/null +++ b/tests/libs/conftest.py @@ -0,0 +1 @@ +from tests.utils import write_version, autouse_test_storage, preserve_environ diff --git a/tests/libs/test_arrow_csv_writer.py b/tests/libs/test_arrow_csv_writer.py index 507fe4e946..b9b0555f1d 100644 --- a/tests/libs/test_arrow_csv_writer.py +++ b/tests/libs/test_arrow_csv_writer.py @@ -14,7 +14,6 @@ TABLE_ROW_ALL_DATA_TYPES, arrow_table_all_data_types, ) -from tests.utils import write_version, autouse_test_storage, preserve_environ def test_csv_writer_all_data_fields() -> None: diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index 6c7e1383c5..158ed047d8 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -1,6 +1,7 @@ import os import pyarrow as pa import pyarrow.parquet as pq +import pytest import datetime # noqa: 251 import time @@ -14,8 +15,11 @@ from dlt.common.libs.pyarrow import from_arrow_scalar from tests.common.data_writers.utils import get_writer -from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA, TABLE_ROW_ALL_DATA_TYPES_DATETIMES -from tests.utils import write_version, autouse_test_storage, preserve_environ +from tests.cases import ( + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS_COLUMNS, + TABLE_UPDATE_COLUMNS_SCHEMA, + TABLE_ROW_ALL_DATA_TYPES_DATETIMES, +) def test_parquet_writer_schema_evolution_with_big_buffer() -> None: @@ -130,7 +134,7 @@ def test_parquet_writer_all_data_fields() -> None: assert actual == value assert table.schema.field("col1_precision").type == pa.int16() - assert table.schema.field("col4_precision").type == pa.timestamp("ms") + assert table.schema.field("col4_precision").type == pa.timestamp("ms", tz="UTC") assert table.schema.field("col5_precision").type == pa.string() assert table.schema.field("col6_precision").type == pa.decimal128(6, 2) assert table.schema.field("col7_precision").type == pa.binary(19) @@ -246,8 +250,8 @@ def test_parquet_writer_schema_from_caps() -> None: writer._flush_items() column_type = writer._writer.schema.field("col2").type - assert column_type == pa.timestamp("ns") - assert column_type.tz is None + assert column_type == pa.timestamp("ns", tz="UTC") + assert column_type.tz == "UTC" column_type = writer._writer.schema.field("col1").type assert isinstance(column_type, pa.Decimal128Type) assert column_type.precision == 18 @@ -260,30 +264,52 @@ def test_parquet_writer_schema_from_caps() -> None: with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader: col2_info = json.loads(reader.metadata.schema.column(1).logical_type.to_json()) - assert col2_info["isAdjustedToUTC"] is False + assert col2_info["isAdjustedToUTC"] is True assert col2_info["timeUnit"] == "nanoseconds" -def test_parquet_writer_timestamp_precision() -> None: +@pytest.mark.parametrize("tz", ["UTC", "Europe/Berlin", ""]) +def test_parquet_writer_timestamp_precision(tz: str) -> None: now = pendulum.now() now_ns = time.time_ns() # store nanoseconds os.environ["DATA_WRITER__VERSION"] = "2.6" + os.environ["DATA_WRITER__TIMESTAMP_TIMEZONE"] = tz + + adjusted = tz != "" with get_writer(ParquetDataWriter, file_max_bytes=2**8, buffer_max_items=2) as writer: for _ in range(0, 5): writer.write_data_item( [{"col1": now, "col2": now, "col3": now, "col4": now_ns}], - { - "col1": new_column("col1", "timestamp", precision=0), - "col2": new_column("col2", "timestamp", precision=3), - "col3": new_column("col2", "timestamp", precision=6), - "col4": new_column("col2", "timestamp", precision=9), - }, + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS_COLUMNS, ) # force the parquet writer to be created writer._flush_items() - print(writer._writer.schema) + + def _assert_arrow_field(field: int, prec: str) -> None: + column_type = writer._writer.schema.field(field).type + assert column_type == pa.timestamp(prec, tz=tz) + if adjusted: + assert column_type.tz == tz + else: + assert column_type.tz is None + + _assert_arrow_field(0, "us") + _assert_arrow_field(1, "ms") + _assert_arrow_field(2, "us") + _assert_arrow_field(3, "ns") + with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader: print(reader.metadata.schema) + + def _assert_pq_column(col: int, prec: str) -> None: + info = json.loads(reader.metadata.schema.column(col).logical_type.to_json()) + assert info["isAdjustedToUTC"] is adjusted + assert info["timeUnit"] == prec + + _assert_pq_column(0, "microseconds") + _assert_pq_column(1, "milliseconds") + _assert_pq_column(2, "microseconds") + _assert_pq_column(3, "nanoseconds") diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 542b18993c..904e4c7bcf 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -8,7 +8,12 @@ from dlt.destinations.impl.duckdb.duck import DuckDbClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration -from tests.load.utils import TABLE_UPDATE, empty_schema +from tests.cases import ( + TABLE_UPDATE_ALL_INT_PRECISIONS, + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS, + TABLE_UPDATE, +) +from tests.load.utils import empty_schema @pytest.fixture @@ -44,15 +49,34 @@ def test_create_table(client: DuckDbClient) -> None: assert '"col11_precision" TIME NOT NULL' in sql +def test_create_table_all_precisions(client: DuckDbClient) -> None: + # non existing table + sql = client._get_table_update_sql( + "event_test_table", + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS + TABLE_UPDATE_ALL_INT_PRECISIONS, + False, + )[0] + sqlfluff.parse(sql, dialect="duckdb") + assert '"col1_ts" TIMESTAMP_S ' in sql + assert '"col2_ts" TIMESTAMP_MS ' in sql + assert '"col3_ts" TIMESTAMP WITH TIME ZONE ' in sql + assert '"col4_ts" TIMESTAMP_NS ' in sql + assert '"col1_int" TINYINT ' in sql + assert '"col2_int" SMALLINT ' in sql + assert '"col3_int" INTEGER ' in sql + assert '"col4_int" BIGINT ' in sql + assert '"col5_int" HUGEINT ' in sql + + def test_alter_table(client: DuckDbClient) -> None: # existing table has no columns sqls = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True) for sql in sqls: sqlfluff.parse(sql, dialect="duckdb") - cannonical_name = client.sql_client.make_qualified_table_name("event_test_table") + canonical_name = client.sql_client.make_qualified_table_name("event_test_table") # must have several ALTER TABLE statements sql = ";\n".join(sqls) - assert sql.count(f"ALTER TABLE {cannonical_name}\nADD COLUMN") == 28 + assert sql.count(f"ALTER TABLE {canonical_name}\nADD COLUMN") == 28 assert "event_test_table" in sql assert '"col1" BIGINT NOT NULL' in sql assert '"col2" DOUBLE NOT NULL' in sql diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index 4c94766097..7b29027fe2 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -52,7 +52,7 @@ def check_file_exists(): def check_file_changed(): details = filesystem.info(file_url) assert details["size"] == 11 - assert (MTIME_DISPATCH[config.protocol](details) - now).seconds < 120 + assert (MTIME_DISPATCH[config.protocol](details) - now).seconds < 160 bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] config = get_config() diff --git a/tests/load/pipeline/test_duckdb.py b/tests/load/pipeline/test_duckdb.py index 6064392976..d5bd5c13a0 100644 --- a/tests/load/pipeline/test_duckdb.py +++ b/tests/load/pipeline/test_duckdb.py @@ -1,10 +1,11 @@ import pytest import os -import dlt +from dlt.common.time import ensure_pendulum_datetime from dlt.destinations.exceptions import DatabaseTerminalException from dlt.pipeline.exceptions import PipelineStepFailed +from tests.cases import TABLE_UPDATE_ALL_INT_PRECISIONS, TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS from tests.pipeline.utils import airtable_emojis from tests.load.pipeline.utils import ( destinations_configs, @@ -24,8 +25,17 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No os.environ["SCHEMA__NAMING"] = "duck_case" pipeline = destination_config.setup_pipeline("test_duck_case_names") # create tables and columns with emojis and other special characters - pipeline.run(airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock")) - pipeline.run([{"🐾Feet": 2, "1+1": "two", "\nhey": "value"}], table_name="🦚Peacocks🦚") + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + loader_file_format=destination_config.file_format, + ) + info.raise_on_failed_jobs() + info = pipeline.run( + [{"🐾Feet": 2, "1+1": "two", "\nhey": "value"}], + table_name="🦚Peacocks🦚", + loader_file_format=destination_config.file_format, + ) + info.raise_on_failed_jobs() table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -40,7 +50,11 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No # this will fail - duckdb preserves case but is case insensitive when comparing identifiers with pytest.raises(PipelineStepFailed) as pip_ex: - pipeline.run([{"🐾Feet": 2, "1+1": "two", "🐾feet": "value"}], table_name="🦚peacocks🦚") + pipeline.run( + [{"🐾Feet": 2, "1+1": "two", "🐾feet": "value"}], + table_name="🦚peacocks🦚", + loader_file_format=destination_config.file_format, + ) assert isinstance(pip_ex.value.__context__, DatabaseTerminalException) # show tables and columns @@ -48,3 +62,66 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No with client.execute_query("DESCRIBE 🦚peacocks🦚;") as q: tables = q.df() assert tables["column_name"].tolist() == ["🐾Feet", "1+1", "hey", "_dlt_load_id", "_dlt_id"] + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_duck_precision_types(destination_config: DestinationTestConfiguration) -> None: + import pyarrow as pa + + # store timestamps without timezone adjustments + os.environ["DATA_WRITER__TIMESTAMP_TIMEZONE"] = "" + + now_s = ensure_pendulum_datetime("2022-05-23T13:26:46+01:00") + now_ms = ensure_pendulum_datetime("2022-05-23T13:26:46.167+01:00") + now_us = ensure_pendulum_datetime("2022-05-23T13:26:46.167231+01:00") + now_ns = ensure_pendulum_datetime("2022-05-23T13:26:46.167231+01:00") # time.time_ns() + + # TODO: we can't really handle integers > 64 bit (so nanoseconds and HUGEINT) + pipeline = destination_config.setup_pipeline("test_duck_all_precision_types") + row = [ + { + "col1_ts": now_s, + "col2_ts": now_ms, + "col3_ts": now_us, + "col4_ts": now_ns, + "col1_int": -128, + "col2_int": 16383, + "col3_int": 2**32 // 2 - 1, + "col4_int": 2**64 // 2 - 1, + "col5_int": 2**64 // 2 - 1, + } + ] + info = pipeline.run( + row, + table_name="row", + loader_file_format=destination_config.file_format, + columns=TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS + TABLE_UPDATE_ALL_INT_PRECISIONS, + ) + info.raise_on_failed_jobs() + + with pipeline.sql_client() as client: + table = client.native_connection.sql("SELECT * FROM row").arrow() + + # only us has TZ aware timestamp in duckdb, also we have UTC here + assert table.schema.field(0).type == pa.timestamp("s") + assert table.schema.field(1).type == pa.timestamp("ms") + assert table.schema.field(2).type == pa.timestamp("us", tz="UTC") + assert table.schema.field(3).type == pa.timestamp("ns") + + assert table.schema.field(4).type == pa.int8() + assert table.schema.field(5).type == pa.int16() + assert table.schema.field(6).type == pa.int32() + assert table.schema.field(7).type == pa.int64() + assert table.schema.field(8).type == pa.decimal128(38, 0) + + table_row = table.to_pylist()[0] + table_row["col1_ts"] = ensure_pendulum_datetime(table_row["col1_ts"]) + table_row["col2_ts"] = ensure_pendulum_datetime(table_row["col2_ts"]) + table_row["col4_ts"] = ensure_pendulum_datetime(table_row["col4_ts"]) + table_row.pop("_dlt_id") + table_row.pop("_dlt_load_id") + assert table_row == row[0] diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 017bef2c01..de47b90cd9 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -798,6 +798,11 @@ def other_data(): # duckdb 0.9.1 does not support TIME other than 6 if destination_config.destination in ["duckdb", "motherduck"]: column_schemas["col11_precision"]["precision"] = 0 + # also we do not want to test col4_precision (datetime) because + # those timestamps are not TZ aware in duckdb and we'd need to + # disable TZ when generating parquet + # this is tested in test_duckdb.py + column_schemas["col4_precision"]["precision"] = 6 # drop TIME from databases not supporting it via parquet if destination_config.destination in ["redshift", "athena", "synapse", "databricks"]: diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index 0ab1343a3b..f8e1d07207 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -2,6 +2,7 @@ from copy import deepcopy import sqlfluff +from dlt.common.exceptions import TerminalValueError from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -11,7 +12,12 @@ PostgresCredentials, ) -from tests.load.utils import TABLE_UPDATE, empty_schema +from tests.cases import ( + TABLE_UPDATE, + TABLE_UPDATE_ALL_INT_PRECISIONS, + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS, +) +from tests.load.utils import empty_schema @pytest.fixture @@ -49,6 +55,23 @@ def test_create_table(client: PostgresClient) -> None: assert '"col11_precision" time (3) without time zone NOT NULL' in sql +def test_create_table_all_precisions(client: PostgresClient) -> None: + # 128 bit integer will fail + table_update = list(TABLE_UPDATE_ALL_INT_PRECISIONS) + with pytest.raises(TerminalValueError) as tv_ex: + sql = client._get_table_update_sql("event_test_table", table_update, False)[0] + assert "128" in str(tv_ex.value) + + # remove col5 HUGEINT which is last + table_update.pop() + sql = client._get_table_update_sql("event_test_table", table_update, False)[0] + sqlfluff.parse(sql, dialect="duckdb") + assert '"col1_int" smallint ' in sql + assert '"col2_int" smallint ' in sql + assert '"col3_int" integer ' in sql + assert '"col4_int" bigint ' in sql + + def test_alter_table(client: PostgresClient) -> None: # existing table has no columns sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True)[0] From 67f2fb50d7f758cb778db337798e3e0c3af2bd2d Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 7 Apr 2024 09:32:25 +0200 Subject: [PATCH 20/22] fixes quoting in regular csv writer and force nulls in postgres copy job --- dlt/common/data_writers/writers.py | 1 + dlt/destinations/impl/postgres/postgres.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 29cd3ffe42..b632176c5a 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -371,6 +371,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: extrasaction="ignore", dialect=csv.unix_dialect, delimiter=self.delimiter, + quoting=csv.QUOTE_NONNUMERIC, ) self.writer.writeheader() diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 29c0f1b7e2..0922ed025e 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -113,9 +113,13 @@ def __init__(self, table_name: str, file_path: str, sql_client: Psycopg2SqlClien # all headers in first line headers = f.readline().decode("utf-8").strip() qualified_table_name = sql_client.make_qualified_table_name(table_name) - copy_sql = "COPY %s (%s) FROM STDIN WITH CSV DELIMITER ',' NULL ''" % ( - qualified_table_name, - headers, + copy_sql = ( + "COPY %s (%s) FROM STDIN WITH (FORMAT CSV, DELIMITER ',', NULL '', FORCE_NULL(%s))" + % ( + qualified_table_name, + headers, + headers, + ) ) with sql_client.begin_transaction(): with sql_client.native_connection.cursor() as cursor: From 5039b2f41c5292c7eb5b561700bbaafe36d64253 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 7 Apr 2024 09:32:39 +0200 Subject: [PATCH 21/22] finalizes the docs --- .../website/docs/dlt-ecosystem/destinations/duckdb.md | 11 +++++++++++ .../docs/dlt-ecosystem/file-formats/parquet.md | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index 61c9aa203c..79e26554f6 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -65,6 +65,17 @@ You can configure the following file formats to load data to duckdb: ::: * [jsonl](../file-formats/jsonl.md) **is supported but does not work if JSON fields are optional. The missing keys fail the COPY instead of being interpreted as NULL.** +:::tip +`duckdb` has [timestamp types](https://duckdb.org/docs/sql/data_types/timestamp.html) with resolutions from milliseconds to nanoseconds. However +only microseconds resolution (the most common used) is time zone aware. `dlt` generates timestamps with timezones by default so loading parquet files +with default settings will fail (`duckdb` does not coerce tz-aware timestamps to naive timestamps). +Disable the timezones by changing `dlt` [parquet writer settings](../file-formats/parquet.md#writer-settings) as follows: +```sh +DATA_WRITER__TIMESTAMP_TIMEZONE="" +``` +to disable tz adjustments. +::: + ## Supported column hints `duckdb` may create unique indexes for all columns with `unique` hints, but this behavior **is disabled by default** because it slows the loading down significantly. diff --git a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md index 76600571f0..3fc1a180da 100644 --- a/docs/website/docs/dlt-ecosystem/file-formats/parquet.md +++ b/docs/website/docs/dlt-ecosystem/file-formats/parquet.md @@ -69,7 +69,7 @@ NORMALIZE__DATA_WRITER__TIMESTAMP_TIMEZONE ### Timestamps and timezones `dlt` adds timezone (UTC adjustment) to all timestamps regardless of a precision (from seconds to nanoseconds). `dlt` will also create TZ aware timestamp columns in -the destinations. If the latter is impossible, there are workaround +the destinations. [duckdb is an exception here](../destinations/duckdb.md#supported-file-formats) ### Disable timezones / utc adjustment flags You can generate parquet files without timezone adjustment information in two ways: From ce00e7408cfddaf05a41e304106edc6ff3e88323 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sun, 7 Apr 2024 15:45:55 +0200 Subject: [PATCH 22/22] renames jobs in tests so it is possible to select them as required --- .github/workflows/lint.yml | 2 +- .github/workflows/test_airflow.yml | 2 +- .github/workflows/test_build_images.yml | 2 +- .github/workflows/test_common.yml | 2 +- .github/workflows/test_dbt_cloud.yml | 2 +- .github/workflows/test_dbt_runner.yml | 2 +- .github/workflows/test_destination_athena.yml | 2 +- .github/workflows/test_destination_athena_iceberg.yml | 2 +- .github/workflows/test_destination_bigquery.yml | 4 ++-- .github/workflows/test_destination_databricks.yml | 2 +- .github/workflows/test_destination_mssql.yml | 2 +- .github/workflows/test_destination_qdrant.yml | 2 +- .github/workflows/test_destination_snowflake.yml | 2 +- .github/workflows/test_destination_synapse.yml | 2 +- .github/workflows/test_destinations.yml | 2 +- .github/workflows/test_doc_snippets.yml | 2 +- .github/workflows/test_local_destinations.yml | 4 ++-- 17 files changed, 19 insertions(+), 19 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 0c6ddcee73..d6b1639685 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -67,7 +67,7 @@ jobs: make lint matrix_job_required_check: - name: Lint results + name: lint | code & tests needs: run_lint runs-on: ubuntu-latest if: always() diff --git a/.github/workflows/test_airflow.yml b/.github/workflows/test_airflow.yml index 8e3a9cf3d8..2a96c4475e 100644 --- a/.github/workflows/test_airflow.yml +++ b/.github/workflows/test_airflow.yml @@ -17,7 +17,7 @@ jobs: uses: ./.github/workflows/get_docs_changes.yml run_airflow: - name: test + name: tools | airflow tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' runs-on: ubuntu-latest diff --git a/.github/workflows/test_build_images.yml b/.github/workflows/test_build_images.yml index c9a99eda2d..665f8b2509 100644 --- a/.github/workflows/test_build_images.yml +++ b/.github/workflows/test_build_images.yml @@ -17,7 +17,7 @@ jobs: uses: ./.github/workflows/get_docs_changes.yml run_airflow: - name: build + name: tools | docker images build needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' runs-on: ubuntu-latest diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index 35ef9d0696..68c4768af6 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -148,7 +148,7 @@ jobs: # shell: cmd matrix_job_required_check: - name: Common tests + name: common | common tests needs: run_common runs-on: ubuntu-latest if: always() diff --git a/.github/workflows/test_dbt_cloud.yml b/.github/workflows/test_dbt_cloud.yml index 0f5c169e6e..98fa44d304 100644 --- a/.github/workflows/test_dbt_cloud.yml +++ b/.github/workflows/test_dbt_cloud.yml @@ -27,7 +27,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_dbt_cloud: - name: test + name: tools | dbt cloud tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_dbt_runner.yml b/.github/workflows/test_dbt_runner.yml index 0ca784a1ae..cb26a97b96 100644 --- a/.github/workflows/test_dbt_runner.yml +++ b/.github/workflows/test_dbt_runner.yml @@ -24,7 +24,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_dbt: - name: test + name: tools | dbt runner tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_athena.yml b/.github/workflows/test_destination_athena.yml index 959fffcfd4..81ef86f713 100644 --- a/.github/workflows/test_destination_athena.yml +++ b/.github/workflows/test_destination_athena.yml @@ -32,7 +32,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | athena tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_athena_iceberg.yml b/.github/workflows/test_destination_athena_iceberg.yml index ea3cb3c06b..c1041be26c 100644 --- a/.github/workflows/test_destination_athena_iceberg.yml +++ b/.github/workflows/test_destination_athena_iceberg.yml @@ -32,7 +32,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | athena iceberg tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_bigquery.yml b/.github/workflows/test_destination_bigquery.yml index 95f7edfb4d..cc55d5a5b2 100644 --- a/.github/workflows/test_destination_bigquery.yml +++ b/.github/workflows/test_destination_bigquery.yml @@ -30,7 +30,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | bigquery tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: @@ -69,7 +69,7 @@ jobs: - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - + - run: | poetry run pytest tests/load -m "essential" name: Run essential tests Linux diff --git a/.github/workflows/test_destination_databricks.yml b/.github/workflows/test_destination_databricks.yml index 95ce20fb90..eb98f24fd5 100644 --- a/.github/workflows/test_destination_databricks.yml +++ b/.github/workflows/test_destination_databricks.yml @@ -30,7 +30,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | databricks tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_mssql.yml b/.github/workflows/test_destination_mssql.yml index adf7437f1b..57f83694dc 100644 --- a/.github/workflows/test_destination_mssql.yml +++ b/.github/workflows/test_destination_mssql.yml @@ -31,7 +31,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | mssql tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_qdrant.yml b/.github/workflows/test_destination_qdrant.yml index 9131e9c62d..1bc45ff643 100644 --- a/.github/workflows/test_destination_qdrant.yml +++ b/.github/workflows/test_destination_qdrant.yml @@ -29,7 +29,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | qdrant tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_snowflake.yml b/.github/workflows/test_destination_snowflake.yml index c46ca95a6b..ab55f2f18f 100644 --- a/.github/workflows/test_destination_snowflake.yml +++ b/.github/workflows/test_destination_snowflake.yml @@ -30,7 +30,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | snowflake tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destination_synapse.yml b/.github/workflows/test_destination_synapse.yml index 997b0a2903..9ee48edc46 100644 --- a/.github/workflows/test_destination_synapse.yml +++ b/.github/workflows/test_destination_synapse.yml @@ -29,7 +29,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | synapse tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' defaults: diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index a23f5f3f4d..3b445eea82 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -38,7 +38,7 @@ jobs: if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} run_loader: - name: test + name: dest | redshift, postgres and fs tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' strategy: diff --git a/.github/workflows/test_doc_snippets.yml b/.github/workflows/test_doc_snippets.yml index 877ec0e530..70ecad3325 100644 --- a/.github/workflows/test_doc_snippets.yml +++ b/.github/workflows/test_doc_snippets.yml @@ -27,7 +27,7 @@ env: jobs: run_lint: - name: lint and test + name: docs | snippets & examples lint and test runs-on: ubuntu-latest # Do not run on forks, unless allowed, secrets are used here if: ${{ !github.event.pull_request.head.repo.fork || contains(github.event.pull_request.labels.*.name, 'ci from fork')}} diff --git a/.github/workflows/test_local_destinations.yml b/.github/workflows/test_local_destinations.yml index 653b4dbd75..0138e087ff 100644 --- a/.github/workflows/test_local_destinations.yml +++ b/.github/workflows/test_local_destinations.yml @@ -29,7 +29,7 @@ jobs: uses: ./.github/workflows/get_docs_changes.yml run_loader: - name: test + name: dest | postgres, duckdb and fs local tests needs: get_docs_changes if: needs.get_docs_changes.outputs.changes_outside_docs == 'true' strategy: @@ -87,7 +87,7 @@ jobs: - name: Install dependencies run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E weaviate --with sentry-sdk --with pipeline - + - name: create secrets.toml run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml