From d3db284e83db13e5b4e14ae86a4922ad99f964e0 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Mon, 16 Oct 2023 17:59:05 -0400 Subject: [PATCH] Pyarrow direct loading (#679) * poc: direct pyarrow load * arrow to schema types with precision, test * Fix naming copy_atomic -> move_atomic * jsonl/parquet file normalizer classes * pathlib extension checks * indent fix * Write parquet with original schema * extract refactoring * Init testing, bugfix * Fix import filename * Dep * Mockup incremental implementation for arrow tables * Create loadstorage per filetype, import with hardlink * Fallback for extract item format, detect type of lists * Error message, load tests with arrow * Some incremental optimizations some * Incremental fixes and run incremental tests on arrow & pandas * Add/update normalize tests * Fix load test * Lint * Add docs page for arrow loading * Handle none capability * Fix extract lists * Exclude TIME in redshift test * Fix type errors * Typo * Create col from numpy array for >200x speedup, index after filter * in -> not in * Format binary as hex for redshift * enables bool and duckdb test on pyarrow loading --------- Co-authored-by: Marcin Rudolf --- dlt/common/data_writers/writers.py | 48 +- dlt/common/destination/__init__.py | 4 +- dlt/common/destination/capabilities.py | 5 +- dlt/common/libs/pyarrow.py | 93 ++- dlt/common/storages/data_item_storage.py | 2 +- dlt/common/storages/file_storage.py | 54 +- dlt/common/storages/load_storage.py | 10 +- dlt/common/storages/normalize_storage.py | 22 +- dlt/common/typing.py | 4 +- dlt/extract/extract.py | 269 ++++++-- .../__init__.py} | 162 ++--- dlt/extract/incremental/exceptions.py | 18 + dlt/extract/incremental/transform.py | 276 ++++++++ dlt/extract/incremental/typing.py | 10 + dlt/normalize/items_normalizers.py | 134 ++++ dlt/normalize/normalize.py | 116 ++-- .../verified-sources/arrow-pandas.md | 107 +++ tests/cases.py | 41 +- .../common/storages/test_normalize_storage.py | 4 +- tests/common/test_pyarrow.py | 51 ++ tests/extract/test_incremental.py | 608 ++++++++++++------ tests/extract/utils.py | 40 +- tests/load/pipeline/test_arrow_loading.py | 56 ++ tests/load/pipeline/test_pipelines.py | 2 +- tests/load/utils.py | 4 + tests/normalize/test_normalize.py | 4 +- tests/pipeline/test_arrow_sources.py | 113 ++++ 27 files changed, 1774 insertions(+), 483 deletions(-) rename dlt/extract/{incremental.py => incremental/__init__.py} (81%) create mode 100644 dlt/extract/incremental/exceptions.py create mode 100644 dlt/extract/incremental/transform.py create mode 100644 dlt/extract/incremental/typing.py create mode 100644 dlt/normalize/items_normalizers.py create mode 100644 docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md create mode 100644 tests/common/test_pyarrow.py create mode 100644 tests/load/pipeline/test_arrow_loading.py create mode 100644 tests/pipeline/test_arrow_sources.py diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 00542535ca..73e13ec46f 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -1,14 +1,14 @@ import abc - from dataclasses import dataclass -from typing import Any, Dict, Sequence, IO, Type, Optional, List, cast +from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union from dlt.common import json -from dlt.common.typing import StrAny -from dlt.common.schema.typing import TTableSchemaColumns -from dlt.common.destination import TLoaderFileFormat, DestinationCapabilitiesContext 
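The commit list above summarizes the feature; as a hedged end-to-end sketch of what direct loading looks like from the user side (the pipeline name, dataset name and duckdb destination are illustrative, not taken from the patch), a pyarrow table or pandas frame handed to a pipeline is written as an "arrow" job and loaded from parquet instead of being normalized row by row:

import dlt
import pandas as pd
import pyarrow as pa

# a pyarrow Table passed straight to the pipeline; a pandas DataFrame or RecordBatch works the same way
table = pa.Table.from_pandas(pd.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]}))

pipeline = dlt.pipeline("arrow_demo", destination="duckdb", dataset_name="arrow_demo_data")
info = pipeline.run(table, table_name="items")
print(info)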
-from dlt.common.configuration import with_config, known_sections, configspec +from dlt.common.configuration import configspec, known_sections, with_config from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.typing import StrAny + @dataclass class TFileFormatSpec: @@ -70,6 +70,8 @@ def class_factory(file_format: TLoaderFileFormat) -> Type["DataWriter"]: return InsertValuesWriter elif file_format == "parquet": return ParquetDataWriter # type: ignore + elif file_format == "arrow": + return ArrowWriter # type: ignore else: raise ValueError(file_format) @@ -249,3 +251,37 @@ def write_footer(self) -> None: @classmethod def data_format(cls) -> TFileFormatSpec: return TFileFormatSpec("parquet", "parquet", True, False, requires_destination_capabilities=True, supports_compression=False) + + +class ArrowWriter(ParquetDataWriter): + def write_header(self, columns_schema: TTableSchemaColumns) -> None: + # Schema will be written as-is from the arrow table + pass + + def write_data(self, rows: Sequence[Any]) -> None: + from dlt.common.libs.pyarrow import pyarrow + rows = list(rows) + if not rows: + return + first = rows[0] + self.writer = self.writer or pyarrow.parquet.ParquetWriter( + self._f, first.schema, flavor=self.parquet_flavor, version=self.parquet_version, data_page_size=self.parquet_data_page_size + ) + for row in rows: + if isinstance(row, pyarrow.Table): + self.writer.write_table(row) + elif isinstance(row, pyarrow.RecordBatch): + self.writer.write_batch(row) + else: + raise ValueError(f"Unsupported type {type(row)}") + + @classmethod + def data_format(cls) -> TFileFormatSpec: + return TFileFormatSpec( + "arrow", + file_extension="parquet", + is_binary_format=True, + supports_schema_changes=False, + requires_destination_capabilities=False, + supports_compression=False, + ) diff --git a/dlt/common/destination/__init__.py b/dlt/common/destination/__init__.py index d4e91acdad..79940cfbc0 100644 --- a/dlt/common/destination/__init__.py +++ b/dlt/common/destination/__init__.py @@ -1,2 +1,2 @@ -from dlt.common.destination.capabilities import DestinationCapabilitiesContext, TLoaderFileFormat -from dlt.common.destination.reference import DestinationReference, TDestinationReferenceArg \ No newline at end of file +from dlt.common.destination.capabilities import DestinationCapabilitiesContext, TLoaderFileFormat, ALL_SUPPORTED_FILE_FORMATS +from dlt.common.destination.reference import DestinationReference, TDestinationReferenceArg diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 245e72725c..06504ee590 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -14,9 +14,10 @@ # puae-jsonl - internal extract -> normalize format bases on jsonl # insert_values - insert SQL statements # sql - any sql statement -TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values", "sql", "parquet", "reference"] +TLoaderFileFormat = Literal["jsonl", "puae-jsonl", "insert_values", "sql", "parquet", "reference", "arrow"] +ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) # file formats used internally by dlt -INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"puae-jsonl", "sql", "reference"} +INTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = {"puae-jsonl", "sql", "reference", "arrow"} # file formats that may be 
chosen by the user EXTERNAL_LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) - INTERNAL_LOADER_FILE_FORMATS diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index cb34856406..f50226ea4f 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -1,9 +1,12 @@ -from typing import Any, Tuple, Optional +from typing import Any, Tuple, Optional, Union from dlt import version from dlt.common.exceptions import MissingDependencyException +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema.typing import TColumnType +from dlt.common.data_types import TDataType +from dlt.common.typing import TFileOrPath try: import pyarrow @@ -12,6 +15,9 @@ raise MissingDependencyException("DLT parquet Helpers", [f"{version.DLT_PKG_NAME}[parquet]"], "DLT Helpers for for parquet.") +TAnyArrowItem = Union[pyarrow.Table, pyarrow.RecordBatch] + + def get_py_arrow_datatype(column: TColumnType, caps: DestinationCapabilitiesContext, tz: str) -> Any: column_type = column["data_type"] if column_type == "text": @@ -82,3 +88,88 @@ def get_pyarrow_int(precision: Optional[int]) -> Any: elif precision <= 32: return pyarrow.int32() return pyarrow.int64() + + +def _get_column_type_from_py_arrow(dtype: pyarrow.DataType) -> TColumnType: + """Returns (data_type, precision, scale) tuple from pyarrow.DataType + """ + if pyarrow.types.is_string(dtype) or pyarrow.types.is_large_string(dtype): + return dict(data_type="text") + elif pyarrow.types.is_floating(dtype): + return dict(data_type="double") + elif pyarrow.types.is_boolean(dtype): + return dict(data_type="bool") + elif pyarrow.types.is_timestamp(dtype): + if dtype.unit == "s": + precision = 0 + elif dtype.unit == "ms": + precision = 3 + elif dtype.unit == "us": + precision = 6 + else: + precision = 9 + return dict(data_type="timestamp", precision=precision) + elif pyarrow.types.is_date(dtype): + return dict(data_type="date") + elif pyarrow.types.is_time(dtype): + # Time fields in schema are `DataType` instead of `Time64Type` or `Time32Type` + if dtype == pyarrow.time32("s"): + precision = 0 + elif dtype == pyarrow.time32("ms"): + precision = 3 + elif dtype == pyarrow.time64("us"): + precision = 6 + else: + precision = 9 + return dict(data_type="time", precision=precision) + elif pyarrow.types.is_integer(dtype): + result: TColumnType = dict(data_type="bigint") + if dtype.bit_width != 64: # 64bit is a default bigint + result["precision"] = dtype.bit_width + return result + elif pyarrow.types.is_fixed_size_binary(dtype): + return dict(data_type="binary", precision=dtype.byte_width) + elif pyarrow.types.is_binary(dtype) or pyarrow.types.is_large_binary(dtype): + return dict(data_type="binary") + elif pyarrow.types.is_decimal(dtype): + return dict(data_type="decimal", precision=dtype.precision, scale=dtype.scale) + elif pyarrow.types.is_nested(dtype): + return dict(data_type="complex") + else: + raise ValueError(dtype) + + +def py_arrow_to_table_schema_columns(schema: pyarrow.Schema) -> TTableSchemaColumns: + """Convert a PyArrow schema to a table schema columns dict. 
+ + Args: + schema (pyarrow.Schema): pyarrow schema + + Returns: + TTableSchemaColumns: table schema columns + """ + result: TTableSchemaColumns = {} + for field in schema: + result[field.name] = { + "name": field.name, + "nullable": field.nullable, + **_get_column_type_from_py_arrow(field.type), + } + return result + + +def get_row_count(parquet_file: TFileOrPath) -> int: + """Get the number of rows in a parquet file. + + Args: + parquet_file (str): path to parquet file + + Returns: + int: number of rows + """ + with pyarrow.parquet.ParquetFile(parquet_file) as reader: + return reader.metadata.num_rows # type: ignore[no-any-return] + + +def is_arrow_item(item: Any) -> bool: + return isinstance(item, (pyarrow.Table, pyarrow.RecordBatch)) diff --git a/dlt/common/storages/data_item_storage.py b/dlt/common/storages/data_item_storage.py index c0ddcf725f..8de95a6f60 100644 --- a/dlt/common/storages/data_item_storage.py +++ b/dlt/common/storages/data_item_storage.py @@ -18,7 +18,7 @@ def get_writer(self, load_id: str, schema_name: str, table_name: str) -> Buffere writer_id = f"{load_id}.{schema_name}.{table_name}" writer = self.buffered_writers.get(writer_id, None) if not writer: - # assign a jsonl writer for each table + # assign a writer for each table path = self._get_data_item_path_template(load_id, schema_name, table_name) writer = BufferedDataWriter(self.loader_file_format, path) self.buffered_writers[writer_id] = writer diff --git a/dlt/common/storages/file_storage.py b/dlt/common/storages/file_storage.py index cf495100ba..006ff4843d 100644 --- a/dlt/common/storages/file_storage.py +++ b/dlt/common/storages/file_storage.py @@ -45,14 +45,19 @@ def save_atomic(storage_path: str, relative_path: str, data: Any, file_type: str raise @staticmethod - def copy_atomic(source_file_path: str, dest_folder_path: str) -> str: + def move_atomic_to_folder(source_file_path: str, dest_folder_path: str) -> str: file_name = os.path.basename(source_file_path) dest_file_path = os.path.join(dest_folder_path, file_name) + return FileStorage.move_atomic_to_file(source_file_path, dest_file_path) + + @staticmethod + def move_atomic_to_file(source_file_path: str, dest_file_path: str) -> str: try: os.rename(source_file_path, dest_file_path) except OSError: # copy to local temp file - dest_temp_file = os.path.join(dest_folder_path, uniq_id()) + folder_name = os.path.dirname(dest_file_path) + dest_temp_file = os.path.join(folder_name, uniq_id()) try: shutil.copyfile(source_file_path, dest_temp_file) os.rename(dest_temp_file, dest_file_path) @@ -63,6 +68,19 @@ def copy_atomic(source_file_path: str, dest_folder_path: str) -> str: raise return dest_file_path + @staticmethod + def copy_atomic_to_file(source_file_path: str, dest_file_path: str) -> str: + folder_name = os.path.dirname(dest_file_path) + dest_temp_file = os.path.join(folder_name, uniq_id()) + try: + shutil.copyfile(source_file_path, dest_temp_file) + os.rename(dest_temp_file, dest_file_path) + except Exception: + if os.path.isfile(dest_temp_file): + os.remove(dest_temp_file) + raise + return dest_file_path + def load(self, relative_path: str) -> Any: # raises on file not existing with self.open_file(relative_path) as text_file: @@ -144,6 +162,19 @@ def link_hard(self, from_relative_path: str, to_relative_path: str) -> None: self.make_full_path(to_relative_path) ) + @staticmethod + def link_hard_with_fallback(external_file_path: str, to_file_path: str) -> None: + """Try to create a hardlink and fallback to copying when filesystem doesn't support links + """ + 
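As a quick illustration of the mapping implemented by _get_column_type_from_py_arrow and py_arrow_to_table_schema_columns above, a small sketch with the expected results noted in comments (per the code in this patch):

import pyarrow as pa
from dlt.common.libs.pyarrow import py_arrow_to_table_schema_columns

schema = pa.schema([
    pa.field("id", pa.int32()),                 # maps to bigint with precision 32
    pa.field("created_at", pa.timestamp("ms")), # maps to timestamp with precision 3
    pa.field("amount", pa.decimal128(38, 9)),   # maps to decimal with precision 38, scale 9
])
for name, column in py_arrow_to_table_schema_columns(schema).items():
    print(name, column)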
try: + os.link(external_file_path, to_file_path) + except OSError as ex: + # Fallback to copy when fs doesn't support links or attempting to make a cross-device link + if ex.errno in (errno.EPERM, errno.EXDEV, errno.EMLINK): + FileStorage.copy_atomic_to_file(external_file_path, to_file_path) + else: + raise + def atomic_rename(self, from_relative_path: str, to_relative_path: str) -> None: """Renames a path using os.rename which is atomic on POSIX, Windows and NFS v4. @@ -195,11 +226,20 @@ def rename_tree_files(self, from_relative_path: str, to_relative_path: str) -> N if not os.listdir(root): os.rmdir(root) - def atomic_import(self, external_file_path: str, to_folder: str) -> str: - """Moves a file at `external_file_path` into the `to_folder` effectively importing file into storage""" - return self.to_relative_path(FileStorage.copy_atomic(external_file_path, self.make_full_path(to_folder))) - # file_name = FileStorage.get_file_name_from_file_path(external_path) - # os.rename(external_path, os.path.join(self.make_full_path(to_folder), file_name)) + def atomic_import(self, external_file_path: str, to_folder: str, new_file_name: Optional[str] = None) -> str: + """Moves a file at `external_file_path` into the `to_folder` effectively importing file into storage + + Args: + external_file_path: Path to file to be imported + to_folder: Path to folder where file should be imported + new_file_name: Optional new file name for the imported file, otherwise the original file name is used + + Returns: + Path to imported file relative to storage root + """ + new_file_name = new_file_name or os.path.basename(external_file_path) + dest_file_path = os.path.join(self.make_full_path(to_folder), new_file_name) + return self.to_relative_path(FileStorage.move_atomic_to_file(external_file_path, dest_file_path)) def in_storage(self, path: str) -> bool: assert path is not None diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 2f52365787..f4a0d88017 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -13,7 +13,8 @@ from dlt.common.configuration.inject import with_config from dlt.common.typing import DictStrAny, StrAny from dlt.common.storages.file_storage import FileStorage -from dlt.common.data_writers import TLoaderFileFormat, DataWriter +from dlt.common.data_writers import DataWriter +from dlt.common.destination import ALL_SUPPORTED_FILE_FORMATS, TLoaderFileFormat from dlt.common.configuration.accessors import config from dlt.common.exceptions import TerminalValueError from dlt.common.schema import Schema, TSchemaTables, TTableSchemaColumns @@ -138,7 +139,7 @@ class LoadStorage(DataItemStorage, VersionedStorage): SCHEMA_FILE_NAME = "schema.json" # package schema PACKAGE_COMPLETED_FILE_NAME = "package_completed.json" # completed package marker file, currently only to store data with os.stat - ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) + ALL_SUPPORTED_FILE_FORMATS = ALL_SUPPORTED_FILE_FORMATS @with_config(spec=LoadStorageConfiguration, sections=(known_sections.LOAD,)) def __init__( @@ -314,6 +315,11 @@ def add_new_job(self, load_id: str, job_file_path: str, job_state: TJobState = " """Adds new job by moving the `job_file_path` into `new_jobs` of package `load_id`""" self.storage.atomic_import(job_file_path, self._get_job_folder_path(load_id, job_state)) + def atomic_import(self, external_file_path: str, to_folder: str) -> str: + """Copies or links a file at `external_file_path` into the 
`to_folder` effectively importing file into storage""" + # LoadStorage.parse_job_file_name + return self.storage.to_relative_path(FileStorage.move_atomic_to_folder(external_file_path, self.storage.make_full_path(to_folder))) + def start_job(self, load_id: str, file_name: str) -> str: return self._move_job(load_id, LoadStorage.NEW_JOBS_FOLDER, LoadStorage.STARTED_JOBS_FOLDER, file_name) diff --git a/dlt/common/storages/normalize_storage.py b/dlt/common/storages/normalize_storage.py index 45f541f5ec..44e6fe2f1c 100644 --- a/dlt/common/storages/normalize_storage.py +++ b/dlt/common/storages/normalize_storage.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Sequence, NamedTuple +from typing import ClassVar, Sequence, NamedTuple, Union from itertools import groupby from pathlib import Path @@ -7,11 +7,14 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.configuration import NormalizeStorageConfiguration from dlt.common.storages.versioned_storage import VersionedStorage +from dlt.common.destination import TLoaderFileFormat, ALL_SUPPORTED_FILE_FORMATS +from dlt.common.exceptions import TerminalValueError class TParsedNormalizeFileName(NamedTuple): schema_name: str table_name: str file_id: str + file_format: TLoaderFileFormat class NormalizeStorage(VersionedStorage): @@ -47,10 +50,13 @@ def build_extracted_file_stem(schema_name: str, table_name: str, file_id: str) - @staticmethod def parse_normalize_file_name(file_name: str) -> TParsedNormalizeFileName: # parse extracted file name and returns (events found, load id, schema_name) - if not file_name.endswith("jsonl"): - raise ValueError(file_name) - - parts = Path(file_name).stem.split(".") - if len(parts) != 3: - raise ValueError(file_name) - return TParsedNormalizeFileName(*parts) + file_name_p: Path = Path(file_name) + parts = file_name_p.name.split(".") + ext = parts[-1] + if ext not in ALL_SUPPORTED_FILE_FORMATS: + raise TerminalValueError(f"File format {ext} not supported. 
Filename: {file_name}") + return TParsedNormalizeFileName(*parts) # type: ignore[arg-type] + + def delete_extracted_files(self, files: Sequence[str]) -> None: + for file_name in files: + self.storage.delete(file_name) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index 402e116ed8..b2bd03f7e6 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -1,8 +1,9 @@ from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence from datetime import datetime, date # noqa: I251 import inspect +import os from re import Pattern as _REPattern -from typing import Callable, Dict, Any, Final, Literal, List, Mapping, NewType, Optional, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin +from typing import Callable, Dict, Any, Final, Literal, List, Mapping, NewType, Optional, Tuple, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, Union, runtime_checkable, get_args, get_origin, IO from typing_extensions import TypeAlias, ParamSpec, Concatenate from dlt.common.pendulum import timedelta, pendulum @@ -44,6 +45,7 @@ TVariantBase = TypeVar("TVariantBase", covariant=True) TVariantRV = Tuple[str, Any] VARIANT_FIELD_FORMAT = "v_%s" +TFileOrPath = Union[str, os.PathLike, IO[Any]] @runtime_checkable class SupportsVariant(Protocol, Generic[TVariantBase]): diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 655f0dec8a..7f8142e807 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -1,18 +1,22 @@ import contextlib import os -from typing import ClassVar, List, Set +from typing import ClassVar, List, Set, Dict, Type, Any, Sequence, Optional +from collections import defaultdict from dlt.common.configuration.container import Container from dlt.common.configuration.resolve import inject_section from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.pipeline import reset_resource_state +from dlt.common.data_writers import TLoaderFileFormat +from dlt.common.exceptions import MissingDependencyException from dlt.common.runtime import signals from dlt.common.runtime.collector import Collector, NULL_COLLECTOR from dlt.common.utils import uniq_id from dlt.common.typing import TDataItems, TDataItem from dlt.common.schema import Schema, utils, TSchemaUpdate -from dlt.common.storages import NormalizeStorageConfiguration, NormalizeStorage, DataItemStorage +from dlt.common.schema.typing import TColumnSchema, TTableSchemaColumns +from dlt.common.storages import NormalizeStorageConfiguration, NormalizeStorage, DataItemStorage, FileStorage from dlt.common.configuration.specs import known_sections from dlt.extract.decorators import SourceSchemaInjectableContext @@ -20,21 +24,68 @@ from dlt.extract.pipe import PipeIterator from dlt.extract.source import DltResource, DltSource from dlt.extract.typing import TableNameMeta +try: + from dlt.common.libs import pyarrow +except MissingDependencyException: + pyarrow = None +try: + import pandas as pd +except ModuleNotFoundError: + pd = None -class ExtractorStorage(DataItemStorage, NormalizeStorage): +class ExtractorItemStorage(DataItemStorage): + load_file_type: TLoaderFileFormat + + def __init__(self, storage: FileStorage, extract_folder: str="extract") -> None: + # data item storage with jsonl with pua encoding + super().__init__(self.load_file_type) + self.extract_folder = extract_folder + self.storage = storage + + + def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: + template = 
NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") + return self.storage.make_full_path(os.path.join(self._get_extract_path(load_id), template)) + + def _get_extract_path(self, extract_id: str) -> str: + return os.path.join(self.extract_folder, extract_id) + + +class JsonLExtractorStorage(ExtractorItemStorage): + load_file_type: TLoaderFileFormat = "puae-jsonl" + + +class ArrowExtractorStorage(ExtractorItemStorage): + load_file_type: TLoaderFileFormat = "arrow" + + +class ExtractorStorage(NormalizeStorage): EXTRACT_FOLDER: ClassVar[str] = "extract" + """Wrapper around multiple extractor storages with different file formats""" def __init__(self, C: NormalizeStorageConfiguration) -> None: - # data item storage with jsonl with pua encoding - super().__init__("puae-jsonl", True, C) - self.storage.create_folder(ExtractorStorage.EXTRACT_FOLDER, exists_ok=True) + super().__init__(True, C) + self._item_storages: Dict[TLoaderFileFormat, ExtractorItemStorage] = { + "puae-jsonl": JsonLExtractorStorage(self.storage, extract_folder=self.EXTRACT_FOLDER), + "arrow": ArrowExtractorStorage(self.storage, extract_folder=self.EXTRACT_FOLDER) + } + + def _get_extract_path(self, extract_id: str) -> str: + return os.path.join(self.EXTRACT_FOLDER, extract_id) def create_extract_id(self) -> str: extract_id = uniq_id() self.storage.create_folder(self._get_extract_path(extract_id)) return extract_id + def get_storage(self, loader_file_format: TLoaderFileFormat) -> ExtractorItemStorage: + return self._item_storages[loader_file_format] + + def close_writers(self, extract_id: str) -> None: + for storage in self._item_storages.values(): + storage.close_writers(extract_id) + def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> None: extract_path = self._get_extract_path(extract_id) for file in self.storage.list_folder_files(extract_path, to_root=False): @@ -48,12 +99,135 @@ def commit_extract_files(self, extract_id: str, with_delete: bool = True) -> Non if with_delete: self.storage.delete_folder(extract_path, recursively=True) - def _get_data_item_path_template(self, load_id: str, schema_name: str, table_name: str) -> str: - template = NormalizeStorage.build_extracted_file_stem(schema_name, table_name, "%s") - return self.storage.make_full_path(os.path.join(self._get_extract_path(load_id), template)) + def write_data_item(self, file_format: TLoaderFileFormat, load_id: str, schema_name: str, table_name: str, item: TDataItems, columns: TTableSchemaColumns) -> None: + self.get_storage(file_format).write_data_item(load_id, schema_name, table_name, item, columns) + - def _get_extract_path(self, extract_id: str) -> str: - return os.path.join(ExtractorStorage.EXTRACT_FOLDER, extract_id) + +class Extractor: + file_format: TLoaderFileFormat + dynamic_tables: TSchemaUpdate + def __init__( + self, + extract_id: str, + storage: ExtractorStorage, + schema: Schema, + resources_with_items: Set[str], + dynamic_tables: TSchemaUpdate, + collector: Collector = NULL_COLLECTOR + ) -> None: + self._storage = storage + self.schema = schema + self.dynamic_tables = dynamic_tables + self.collector = collector + self.resources_with_items = resources_with_items + self.extract_id = extract_id + + @property + def storage(self) -> ExtractorItemStorage: + return self._storage.get_storage(self.file_format) + + @staticmethod + def item_format(items: TDataItems) -> Optional[TLoaderFileFormat]: + """Detect the loader file format of the data items based on type. 
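Tying the path template above to the normalize-storage change earlier in the patch, a hedged sketch of the file-name round trip (the example file name is made up):

from dlt.common.storages.normalize_storage import NormalizeStorage

# extracted job files are named "<schema>.<table>.<file_id>.<format>"; the reworked
# parse_normalize_file_name splits the name and validates the format suffix
parsed = NormalizeStorage.parse_normalize_file_name("my_schema.events.abc123.parquet")
print(parsed.schema_name, parsed.table_name, parsed.file_id, parsed.file_format)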
+ Currently this is either 'arrow' or 'puae-jsonl' + + Returns: + The loader file format or `None` if if can't be detected. + """ + for item in items if isinstance(items, list) else [items]: + # Assume all items in list are the same type + if (pyarrow and pyarrow.is_arrow_item(item)) or (pd and isinstance(item, pd.DataFrame)): + return "arrow" + return "puae-jsonl" + return None # Empty list is unknown format + + def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> None: + if isinstance(meta, TableNameMeta): + table_name = meta.table_name + self._write_static_table(resource, table_name, items) + self._write_item(table_name, resource.name, items) + else: + if resource._table_name_hint_fun: + if isinstance(items, list): + for item in items: + self._write_dynamic_table(resource, item) + else: + self._write_dynamic_table(resource, items) + else: + # write item belonging to table with static name + table_name = resource.table_name # type: ignore[assignment] + self._write_static_table(resource, table_name, items) + self._write_item(table_name, resource.name, items) + + def write_empty_file(self, table_name: str) -> None: + table_name = self.schema.naming.normalize_table_identifier(table_name) + self.storage.write_empty_file(self.extract_id, self.schema.name, table_name, None) + + def _write_item(self, table_name: str, resource_name: str, items: TDataItems) -> None: + # normalize table name before writing so the name match the name in schema + # note: normalize function should be cached so there's almost no penalty on frequent calling + # note: column schema is not required for jsonl writer used here + table_name = self.schema.naming.normalize_identifier(table_name) + self.collector.update(table_name) + self.resources_with_items.add(resource_name) + self.storage.write_data_item(self.extract_id, self.schema.name, table_name, items, None) + + def _write_dynamic_table(self, resource: DltResource, item: TDataItem) -> None: + table_name = resource._table_name_hint_fun(item) + existing_table = self.dynamic_tables.get(table_name) + if existing_table is None: + self.dynamic_tables[table_name] = [resource.compute_table_schema(item)] + else: + # quick check if deep table merge is required + if resource._table_has_other_dynamic_hints: + new_table = resource.compute_table_schema(item) + # this merges into existing table in place + utils.merge_tables(existing_table[0], new_table) + else: + # if there are no other dynamic hints besides name then we just leave the existing partial table + pass + # write to storage with inferred table name + self._write_item(table_name, resource.name, item) + + def _write_static_table(self, resource: DltResource, table_name: str, items: TDataItems) -> None: + existing_table = self.dynamic_tables.get(table_name) + if existing_table is None: + static_table = resource.compute_table_schema() + static_table["name"] = table_name + self.dynamic_tables[table_name] = [static_table] + + +class JsonLExtractor(Extractor): + file_format = "puae-jsonl" + + +class ArrowExtractor(Extractor): + file_format = "arrow" + + def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> None: + items = [ + pyarrow.pyarrow.Table.from_pandas(item) if (pd and isinstance(item, pd.DataFrame)) else item + for item in (items if isinstance(items, list) else [items]) + ] + super().write_table(resource, items, meta) + + def _write_static_table(self, resource: DltResource, table_name: str, items: TDataItems) -> None: + existing_table = self.dynamic_tables.get(table_name) 
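The format detection and the pandas conversion above can be exercised in isolation; a hedged sketch, assuming pyarrow and pandas are installed:

import pandas as pd
import pyarrow as pa
from dlt.extract.extract import Extractor

# format detection as implemented in Extractor.item_format
assert Extractor.item_format(pa.table({"a": [1]})) == "arrow"
assert Extractor.item_format(pd.DataFrame({"a": [1]})) == "arrow"
assert Extractor.item_format([{"a": 1}]) == "puae-jsonl"
assert Extractor.item_format([]) is None  # empty list: format unknown

# ArrowExtractor.write_table converts pandas frames up front, so downstream code
# only ever deals with pyarrow Tables / RecordBatches
items = [pd.DataFrame({"x": [1, 2]}), pa.table({"x": [3, 4]})]
items = [pa.Table.from_pandas(i) if isinstance(i, pd.DataFrame) else i for i in items]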
+ if existing_table is not None: + return + static_table = resource.compute_table_schema() + if isinstance(items, list): + item = items[0] + else: + item = items + # Merge the columns to include primary_key and other hints that may be set on the resource + arrow_columns = pyarrow.py_arrow_to_table_schema_columns(item.schema) + for key, value in static_table["columns"].items(): + arrow_columns[key] = utils.merge_columns(value, arrow_columns.get(key, {})) + static_table["columns"] = arrow_columns + static_table["name"] = table_name + self.dynamic_tables[table_name] = [static_table] def extract( @@ -66,50 +240,20 @@ def extract( workers: int = None, futures_poll_interval: float = None ) -> TSchemaUpdate: - dynamic_tables: TSchemaUpdate = {} schema = source.schema resources_with_items: Set[str] = set() + extractors: Dict[TLoaderFileFormat, Extractor] = { + "puae-jsonl": JsonLExtractor( + extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector + ), + "arrow": ArrowExtractor( + extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector + ) + } + last_item_format: Optional[TLoaderFileFormat] = None with collector(f"Extract {source.name}"): - - def _write_empty_file(table_name: str) -> None: - table_name = schema.naming.normalize_table_identifier(table_name) - storage.write_empty_file(extract_id, schema.name, table_name, None) - - def _write_item(table_name: str, resource_name: str, item: TDataItems) -> None: - # normalize table name before writing so the name match the name in schema - # note: normalize function should be cached so there's almost no penalty on frequent calling - # note: column schema is not required for jsonl writer used here - table_name = schema.naming.normalize_table_identifier(table_name) - collector.update(table_name) - resources_with_items.add(resource_name) - storage.write_data_item(extract_id, schema.name, table_name, item, None) - - def _write_dynamic_table(resource: DltResource, item: TDataItem) -> None: - table_name = resource._table_name_hint_fun(item) - existing_table = dynamic_tables.get(table_name) - if existing_table is None: - dynamic_tables[table_name] = [resource.compute_table_schema(item)] - else: - # quick check if deep table merge is required - if resource._table_has_other_dynamic_hints: - new_table = resource.compute_table_schema(item) - # this merges into existing table in place - utils.merge_tables(existing_table[0], new_table) - else: - # if there are no other dynamic hints besides name then we just leave the existing partial table - pass - # write to storage with inferred table name - _write_item(table_name, resource.name, item) - - def _write_static_table(resource: DltResource, table_name: str) -> None: - existing_table = dynamic_tables.get(table_name) - if existing_table is None: - static_table = resource.compute_table_schema() - static_table["name"] = table_name - dynamic_tables[table_name] = [static_table] - # yield from all selected pipes with PipeIterator.from_pipes(source.resources.selected_pipes, max_parallel_items=max_parallel_items, workers=workers, futures_poll_interval=futures_poll_interval) as pipes: left_gens = total_gens = len(pipes._sources) @@ -125,24 +269,10 @@ def _write_static_table(resource: DltResource, table_name: str) -> None: signals.raise_if_signalled() resource = source.resources[pipe_item.pipe.name] - table_name: str = None - if isinstance(pipe_item.meta, TableNameMeta): - table_name = pipe_item.meta.table_name - _write_static_table(resource, table_name) - 
_write_item(table_name, resource.name, pipe_item.item) - else: - # get partial table from table template - if resource._table_name_hint_fun: - if isinstance(pipe_item.item, List): - for item in pipe_item.item: - _write_dynamic_table(resource, item) - else: - _write_dynamic_table(resource, pipe_item.item) - else: - # write item belonging to table with static name - table_name = resource.table_name # type: ignore - _write_static_table(resource, table_name) - _write_item(table_name, resource.name, pipe_item.item) + # Fallback to last item's format or default (puae-jsonl) if the current item is an empty list + item_format = Extractor.item_format(pipe_item.item) or last_item_format or "puae-jsonl" + extractors[item_format].write_table(resource, pipe_item.item, pipe_item.meta) + last_item_format = item_format # find defined resources that did not yield any pipeitems and create empty jobs for them data_tables = {t["name"]: t for t in schema.data_tables()} @@ -155,7 +285,7 @@ def _write_static_table(resource: DltResource, table_name: str) -> None: for table in tables_by_resources[resource.name]: # we only need to write empty files for the top tables if not table.get("parent", None): - _write_empty_file(table["name"]) + extractors[last_item_format or "puae-jsonl"].write_empty_file(table["name"]) if left_gens > 0: # go to 100% @@ -194,4 +324,3 @@ def extract_with_schema( schema.update_table(schema.normalize_table_identifiers(partial)) return extract_id - diff --git a/dlt/extract/incremental.py b/dlt/extract/incremental/__init__.py similarity index 81% rename from dlt/extract/incremental.py rename to dlt/extract/incremental/__init__.py index ebc54530cc..57593769f7 100644 --- a/dlt/extract/incremental.py +++ b/dlt/extract/incremental/__init__.py @@ -1,10 +1,16 @@ import os -from typing import Generic, TypeVar, Any, Optional, Callable, List, TypedDict, get_args, get_origin, Sequence, Type +from typing import Generic, TypeVar, Any, Optional, Callable, List, TypedDict, get_args, get_origin, Sequence, Type, Dict import inspect -from functools import wraps +from functools import wraps, partial from datetime import datetime # noqa: I251 +try: + import pandas as pd +except ModuleNotFoundError: + pd = None + import dlt +from dlt.common.exceptions import MissingDependencyException from dlt.common import pendulum, logger from dlt.common.json import json from dlt.common.jsonpath import compile_path, find_values, JSONPath @@ -17,39 +23,20 @@ from dlt.common.data_types.type_helpers import coerce_from_date_types, coerce_value, py_type_to_sc_type from dlt.extract.exceptions import IncrementalUnboundError, PipeException +from dlt.extract.incremental.exceptions import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing +from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc from dlt.extract.pipe import Pipe from dlt.extract.utils import resolve_column_value -from dlt.extract.typing import FilterItem, SupportsPipe, TTableHintTemplate - - -TCursorValue = TypeVar("TCursorValue", bound=Any) -LastValueFunc = Callable[[Sequence[TCursorValue]], Any] - - -class IncrementalColumnState(TypedDict): - initial_value: Optional[Any] - last_value: Optional[Any] - unique_hashes: List[str] - - -class IncrementalCursorPathMissing(PipeException): - def __init__(self, pipe_name: str, json_path: str, item: TDataItem) -> None: - self.json_path = json_path - self.item = item - msg = f"Cursor element with JSON path {json_path} was not found in extracted data item. 
All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database." - super().__init__(pipe_name, msg) - - -class IncrementalPrimaryKeyMissing(PipeException): - def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None: - self.primary_key_column = primary_key_column - self.item = item - msg = f"Primary key column {primary_key_column} was not found in extracted data item. All data items must contain this column. Use the same names of fields as in your JSON document." - super().__init__(pipe_name, msg) +from dlt.extract.typing import SupportsPipe, TTableHintTemplate, MapItem, YieldMapItem, FilterItem, ItemTransform +from dlt.extract.incremental.transform import JsonIncremental, ArrowIncremental, IncrementalTransformer +try: + from dlt.common.libs.pyarrow import is_arrow_item, pyarrow as pa, TAnyArrowItem +except MissingDependencyException: + is_arrow_item = lambda item: False @configspec -class Incremental(FilterItem, BaseConfiguration, Generic[TCursorValue]): +class Incremental(ItemTransform[TDataItem], BaseConfiguration, Generic[TCursorValue]): """Adds incremental extraction for a resource by storing a cursor value in persistent state. The cursor could for example be a timestamp for when the record was created and you can use this to load only @@ -115,13 +102,29 @@ def __init__( self._cached_state: IncrementalColumnState = None """State dictionary cached on first access""" - super().__init__(self.transform) + super().__init__(lambda x: x) # TODO: self.end_out_of_range: bool = False """Becomes true on the first item that is out of range of `end_value`. I.e. when using `max` function this means a value that is equal or higher""" self.start_out_of_range: bool = False """Becomes true on the first item that is out of range of `start_value`. I.e. 
when using `max` this is a value that is lower than `start_value`""" + self._transformers: Dict[str, IncrementalTransformer] = {} + + def _make_transformers(self) -> None: + types = [("arrow", ArrowIncremental), ("json", JsonIncremental)] + for dt, kls in types: + self._transformers[dt] = kls( + self.resource_name, + self.cursor_path_p, + self.start_value, + self.end_value, + self._cached_state, + self.last_value_func, + self.primary_key + ) + + @classmethod def from_existing_state(cls, resource_name: str, cursor_path: str) -> "Incremental[TCursorValue]": """Create Incremental instance from existing state.""" @@ -163,6 +166,7 @@ def merge(self, other: "Incremental[TCursorValue]") -> "Incremental[TCursorValue constructor = self.__orig_class__ else: constructor = other.__orig_class__ if hasattr(other, "__orig_class__") else other.__class__ + constructor = extract_inner_type(constructor) return constructor(**kwargs) # type: ignore def on_resolved(self) -> None: @@ -233,73 +237,11 @@ def last_value(self) -> Optional[TCursorValue]: s = self.get_state() return s['last_value'] # type: ignore - def unique_value(self, row: TDataItem) -> str: - try: - if self.primary_key: - return digest128(json.dumps(resolve_column_value(self.primary_key, row), sort_keys=True)) - elif self.primary_key is None: - return digest128(json.dumps(row, sort_keys=True)) - else: - return None - except KeyError as k_err: - raise IncrementalPrimaryKeyMissing(self.resource_name, k_err.args[0], row) - - def transform(self, row: TDataItem) -> bool: - if row is None: - return True - - row_values = find_values(self.cursor_path_p, row) - if not row_values: - raise IncrementalCursorPathMissing(self.resource_name, self.cursor_path, row) - row_value = row_values[0] - - # For datetime cursor, ensure the value is a timezone aware datetime. 
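The resource-facing behaviour of Incremental is unchanged by this refactor; for reference, a typical (illustrative) usage sketch of an incremental resource:

import dlt

@dlt.resource(primary_key="id")
def events(created_at=dlt.sources.incremental("created_at", initial_value="2023-01-01T00:00:00Z")):
    # rows below the stored last_value are filtered out by the transformer picked
    # per item type (JsonIncremental for dicts, ArrowIncremental for arrow/pandas)
    yield [
        {"id": 1, "created_at": "2023-05-01T00:00:00Z"},
        {"id": 2, "created_at": "2023-06-01T00:00:00Z"},
    ]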
- # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable - if isinstance(row_value, datetime): - row_value = pendulum.instance(row_value) - - incremental_state = self._cached_state - last_value = incremental_state['last_value'] - last_value_func = self.last_value_func - - # Check whether end_value has been reached - # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value - if self.end_value is not None and ( - last_value_func((row_value, self.end_value)) != self.end_value or last_value_func((row_value, )) == self.end_value - ): - self.end_out_of_range = True - return False - - check_values = (row_value,) + ((last_value, ) if last_value is not None else ()) - new_value = last_value_func(check_values) - if last_value == new_value: - processed_row_value = last_value_func((row_value, )) - # we store row id for all records with the current "last_value" in state and use it to deduplicate - if processed_row_value == last_value: - unique_value = self.unique_value(row) - # if unique value exists then use it to deduplicate - if unique_value: - if unique_value in incremental_state['unique_hashes']: - return False - # add new hash only if the record row id is same as current last value - incremental_state['unique_hashes'].append(unique_value) - return True - # skip the record that is not a last_value or new_value: that record was already processed - check_values = (row_value,) + ((self.start_value,) if self.start_value is not None else ()) - new_value = last_value_func(check_values) - # Include rows == start_value but exclude "lower" - if new_value == self.start_value and processed_row_value != self.start_value: - self.start_out_of_range = True - return False - else: - return True - else: - incremental_state["last_value"] = new_value - unique_value = self.unique_value(row) - if unique_value: - incremental_state["unique_hashes"] = [unique_value] - - return True + def _transform_item(self, transformer: IncrementalTransformer, row: TDataItem) -> Optional[TDataItem]: + row, start_out_of_range, end_out_of_range = transformer(row) + self.start_out_of_range = start_out_of_range + self.end_out_of_range = end_out_of_range + return row def get_incremental_value_type(self) -> Type[Any]: """Infers the type of incremental value from a class of an instance if those preserve the Generic arguments information.""" @@ -371,13 +313,35 @@ def bind(self, pipe: SupportsPipe) -> "Incremental[TCursorValue]": logger.info(f"Bind incremental on {self.resource_name} with initial_value: {self.initial_value}, start_value: {self.start_value}, end_value: {self.end_value}") # cache state self._cached_state = self.get_state() + self._make_transformers() return self def __str__(self) -> str: return f"Incremental at {id(self)} for resource {self.resource_name} with cursor path: {self.cursor_path} initial {self.initial_value} lv_func {self.last_value_func}" + def _get_transformer(self, items: TDataItems) -> IncrementalTransformer: + # Assume list is all of the same type + for item in items if isinstance(items, list) else [items]: + if is_arrow_item(item): + return self._transformers['arrow'] + elif pd is not None and isinstance(item, pd.DataFrame): + return self._transformers['arrow'] + return self._transformers['json'] + return self._transformers['json'] + + def __call__(self, rows: TDataItems, meta: Any = None) -> Optional[TDataItems]: + if rows is None: + return rows + + transformer = self._get_transformer(rows) + transformer.primary_key = 
self.primary_key + + if isinstance(rows, list): + return [item for item in (self._transform_item(transformer, row) for row in rows) if item is not None] + return self._transform_item(transformer, rows) + -class IncrementalResourceWrapper(FilterItem): +class IncrementalResourceWrapper(ItemTransform[TDataItem]): _incremental: Optional[Incremental[Any]] = None """Keeps the injectable incremental""" _resource_name: str = None diff --git a/dlt/extract/incremental/exceptions.py b/dlt/extract/incremental/exceptions.py new file mode 100644 index 0000000000..8de5623c78 --- /dev/null +++ b/dlt/extract/incremental/exceptions.py @@ -0,0 +1,18 @@ +from dlt.extract.exceptions import PipeException +from dlt.common.typing import TDataItem + + +class IncrementalCursorPathMissing(PipeException): + def __init__(self, pipe_name: str, json_path: str, item: TDataItem, msg: str=None) -> None: + self.json_path = json_path + self.item = item + msg = msg or f"Cursor element with JSON path {json_path} was not found in extracted data item. All data items must contain this path. Use the same names of fields as in your JSON document - if those are different from the names you see in database." + super().__init__(pipe_name, msg) + + +class IncrementalPrimaryKeyMissing(PipeException): + def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None: + self.primary_key_column = primary_key_column + self.item = item + msg = f"Primary key column {primary_key_column} was not found in extracted data item. All data items must contain this column. Use the same names of fields as in your JSON document." + super().__init__(pipe_name, msg) diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py new file mode 100644 index 0000000000..c77fb97b9f --- /dev/null +++ b/dlt/extract/incremental/transform.py @@ -0,0 +1,276 @@ +from datetime import datetime # noqa: I251 +from typing import Optional, Tuple, Protocol, Mapping, Union, List + +try: + import pandas as pd +except ModuleNotFoundError: + pd = None + +try: + import numpy as np +except ModuleNotFoundError: + np = None + +from dlt.common.exceptions import MissingDependencyException +from dlt.common.utils import digest128 +from dlt.common.json import json +from dlt.common import pendulum +from dlt.common.typing import TDataItem, TDataItems +from dlt.common.jsonpath import TJsonPath, find_values +from dlt.extract.incremental.exceptions import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing +from dlt.extract.incremental.typing import IncrementalColumnState, TCursorValue, LastValueFunc +from dlt.extract.utils import resolve_column_value +from dlt.extract.typing import TTableHintTemplate +from dlt.common.schema.typing import TColumnNames +try: + from dlt.common.libs.pyarrow import pyarrow as pa, TAnyArrowItem +except MissingDependencyException: + pa = None + + + +class IncrementalTransformer: + def __init__( + self, + resource_name: str, + cursor_path: TJsonPath, + start_value: Optional[TCursorValue], + end_value: Optional[TCursorValue], + incremental_state: IncrementalColumnState, + last_value_func: LastValueFunc[TCursorValue], + primary_key: Optional[TTableHintTemplate[TColumnNames]], + ) -> None: + self.resource_name = resource_name + self.cursor_path = cursor_path + self.start_value = start_value + self.end_value = end_value + self.incremental_state = incremental_state + self.last_value_func = last_value_func + self.primary_key = primary_key + + def __call__( + self, + row: TDataItem, + ) -> Tuple[bool, bool, bool]: + 
... + + +class JsonIncremental(IncrementalTransformer): + def unique_value( + self, + row: TDataItem, + primary_key: Optional[TTableHintTemplate[TColumnNames]], + resource_name: str + ) -> str: + try: + if primary_key: + return digest128(json.dumps(resolve_column_value(primary_key, row), sort_keys=True)) + elif primary_key is None: + return digest128(json.dumps(row, sort_keys=True)) + else: + return None + except KeyError as k_err: + raise IncrementalPrimaryKeyMissing(resource_name, k_err.args[0], row) + + def __call__( + self, + row: TDataItem, + ) -> Tuple[Optional[TDataItem], bool, bool]: + """ + Returns: + Tuple (row, start_out_of_range, end_out_of_range) where row is either the data item or `None` if it is completely filtered out + """ + start_out_of_range = end_out_of_range = False + if row is None: + return row, start_out_of_range, end_out_of_range + + row_values = find_values(self.cursor_path, row) + if not row_values: + raise IncrementalCursorPathMissing(self.resource_name, str(self.cursor_path), row) + row_value = row_values[0] + + # For datetime cursor, ensure the value is a timezone aware datetime. + # The object saved in state will always be a tz aware pendulum datetime so this ensures values are comparable + if isinstance(row_value, datetime): + row_value = pendulum.instance(row_value) + + last_value = self.incremental_state['last_value'] + + # Check whether end_value has been reached + # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value + if self.end_value is not None and ( + self.last_value_func((row_value, self.end_value)) != self.end_value or self.last_value_func((row_value, )) == self.end_value + ): + end_out_of_range = True + return None, start_out_of_range, end_out_of_range + + check_values = (row_value,) + ((last_value, ) if last_value is not None else ()) + new_value = self.last_value_func(check_values) + if last_value == new_value: + processed_row_value = self.last_value_func((row_value, )) + # we store row id for all records with the current "last_value" in state and use it to deduplicate + + if processed_row_value == last_value: + unique_value = self.unique_value(row, self.primary_key, self.resource_name) + # if unique value exists then use it to deduplicate + if unique_value: + if unique_value in self.incremental_state['unique_hashes']: + return None, start_out_of_range, end_out_of_range + # add new hash only if the record row id is same as current last value + self.incremental_state['unique_hashes'].append(unique_value) + return row, start_out_of_range, end_out_of_range + # skip the record that is not a last_value or new_value: that record was already processed + check_values = (row_value,) + ((self.start_value,) if self.start_value is not None else ()) + new_value = self.last_value_func(check_values) + # Include rows == start_value but exclude "lower" + if new_value == self.start_value and processed_row_value != self.start_value: + start_out_of_range = True + return None, start_out_of_range, end_out_of_range + else: + return row, start_out_of_range, end_out_of_range + else: + self.incremental_state["last_value"] = new_value + unique_value = self.unique_value(row, self.primary_key, self.resource_name) + if unique_value: + self.incremental_state["unique_hashes"] = [unique_value] + + return row, start_out_of_range, end_out_of_range + + + +class ArrowIncremental(IncrementalTransformer): + def unique_values( + self, + item: "TAnyArrowItem", + unique_columns: List[str], + resource_name: str + ) -> List[Tuple[int, str]]: + 
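To make the deduplication above concrete: rows that share the current last_value are fingerprinted with digest128 over their primary-key values (or over the whole row when no primary key is set), and ArrowIncremental below applies the same hash to the selected unique columns. A small sketch:

from dlt.common.json import json
from dlt.common.utils import digest128

row = {"id": 1, "created_at": "2023-05-01T00:00:00Z"}
# with primary_key="id" only the key value is hashed; with no primary key the whole row is
print(digest128(json.dumps(1, sort_keys=True)))
print(digest128(json.dumps(row, sort_keys=True)))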
if not unique_columns: + return [] + item = item + indices = item["_dlt_index"].to_pylist() + rows = item.select(unique_columns).to_pylist() + return [ + (index, digest128(json.dumps(row, sort_keys=True))) for index, row in zip(indices, rows) + ] + + def _deduplicate(self, tbl: "pa.Table", unique_columns: Optional[List[str]], aggregate: str, cursor_path: str) -> "pa.Table": + if unique_columns is None: + return tbl + group_cols = unique_columns + [cursor_path] + tbl = tbl.append_column("_dlt_index", pa.array(np.arange(tbl.num_rows))) + try: + tbl = tbl.filter( + pa.compute.is_in( + tbl['_dlt_index'], + tbl.group_by(group_cols).aggregate( + [("_dlt_index", "one"), (cursor_path, aggregate)] + )['_dlt_index_one'] + ) + ) + except KeyError as e: + raise IncrementalPrimaryKeyMissing(self.resource_name, unique_columns[0], tbl) from e + return tbl + + def __call__( + self, + tbl: "TAnyArrowItem", + ) -> Tuple[TDataItem, bool, bool]: + is_pandas = pd is not None and isinstance(tbl, pd.DataFrame) + if is_pandas: + tbl = pa.Table.from_pandas(tbl) + + start_out_of_range = end_out_of_range = False + if not tbl: # row is None or empty arrow table + return tbl, start_out_of_range, end_out_of_range + + last_value = self.incremental_state['last_value'] + + if self.last_value_func is max: + compute = pa.compute.max + aggregate = "max" + end_compare = pa.compute.less + last_value_compare = pa.compute.greater_equal + new_value_compare = pa.compute.greater + elif self.last_value_func is min: + compute = pa.compute.min + aggregate = "min" + end_compare = pa.compute.greater + last_value_compare = pa.compute.less_equal + new_value_compare = pa.compute.less + else: + raise NotImplementedError("Only min or max last_value_func is supported for arrow tables") + + + # TODO: Json path support. For now assume the cursor_path is a column name + cursor_path = str(self.cursor_path) + # The new max/min value + try: + row_value = compute(tbl[cursor_path]).as_py() + except KeyError as e: + raise IncrementalCursorPathMissing( + self.resource_name, cursor_path, tbl, + f"Column name {str(cursor_path)} was not found in the arrow table. Note nested JSON paths are not supported for arrow tables and dataframes, the incremental cursor_path must be a column name." + ) from e + + primary_key = self.primary_key(tbl) if callable(self.primary_key) else self.primary_key + if primary_key: + if isinstance(primary_key, str): + unique_columns = [primary_key] + else: + unique_columns = list(primary_key) + elif primary_key is None: + unique_columns = tbl.column_names + else: # deduplicating is disabled + unique_columns = None + + # If end_value is provided, filter to include table rows that are "less" than end_value + if self.end_value is not None: + tbl = tbl.filter(end_compare(tbl[cursor_path], self.end_value)) + # Is max row value higher than end value? + # NOTE: pyarrow bool *always* evaluates to python True. 
`as_py()` is necessary + end_out_of_range = not end_compare(row_value, self.end_value).as_py() + + if last_value is not None: + if self.start_value is not None: + # Remove rows lower than the last start value + keep_filter = last_value_compare(tbl[cursor_path], self.start_value) + start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) + tbl = tbl.filter(keep_filter) + + # Deduplicate after filtering old values + tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path) + # Remove already processed rows where the cursor is equal to the last value + eq_rows = tbl.filter(pa.compute.equal(tbl[cursor_path], last_value)) + # compute index, unique hash mapping + unique_values = self.unique_values(eq_rows, unique_columns, self.resource_name) + unique_values = [(i, uq_val) for i, uq_val in unique_values if uq_val in self.incremental_state['unique_hashes']] + remove_idx = pa.array(i for i, _ in unique_values) + # Filter the table + tbl = tbl.filter(pa.compute.invert(pa.compute.is_in(tbl["_dlt_index"], remove_idx))) + + if new_value_compare(row_value, last_value).as_py() and row_value != last_value: # Last value has changed + self.incremental_state['last_value'] = row_value + # Compute unique hashes for all rows equal to row value + self.incremental_state['unique_hashes'] = [uq_val for _, uq_val in self.unique_values( + tbl.filter(pa.compute.equal(tbl[cursor_path], row_value)), unique_columns, self.resource_name + )] + else: + # last value is unchanged, add the hashes + self.incremental_state['unique_hashes'] = list(set(self.incremental_state['unique_hashes'] + [uq_val for _, uq_val in unique_values])) + else: + tbl = self._deduplicate(tbl, unique_columns, aggregate, cursor_path) + self.incremental_state['last_value'] = row_value + self.incremental_state['unique_hashes'] = [uq_val for _, uq_val in self.unique_values( + tbl.filter(pa.compute.equal(tbl[cursor_path], row_value)), unique_columns, self.resource_name + )] + + if len(tbl) == 0: + return None, start_out_of_range, end_out_of_range + try: + tbl = tbl.drop(["_dlt_index"]) + except KeyError: + pass + if is_pandas: + return tbl.to_pandas(), start_out_of_range, end_out_of_range + return tbl, start_out_of_range, end_out_of_range diff --git a/dlt/extract/incremental/typing.py b/dlt/extract/incremental/typing.py new file mode 100644 index 0000000000..03f36121be --- /dev/null +++ b/dlt/extract/incremental/typing.py @@ -0,0 +1,10 @@ +from typing import TypedDict, Optional, Any, List, TypeVar, Callable, Sequence + + +TCursorValue = TypeVar("TCursorValue", bound=Any) +LastValueFunc = Callable[[Sequence[TCursorValue]], Any] + +class IncrementalColumnState(TypedDict): + initial_value: Optional[Any] + last_value: Optional[Any] + unique_hashes: List[str] diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py new file mode 100644 index 0000000000..2f613f4b40 --- /dev/null +++ b/dlt/normalize/items_normalizers.py @@ -0,0 +1,134 @@ +import os +from typing import List, Dict, Tuple, Protocol +from pathlib import Path + +from dlt.common import json, logger +from dlt.common.json import custom_pua_decode +from dlt.common.runtime import signals +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.storages import NormalizeStorage, LoadStorage, NormalizeStorageConfiguration, FileStorage +from dlt.common.typing import TDataItem +from dlt.common.schema import TSchemaUpdate, Schema +from dlt.common.utils import TRowCount, merge_row_count, increase_row_count + + +class 
ItemsNormalizer(Protocol): + def __call__( + self, + extracted_items_file: str, + load_storage: LoadStorage, + normalize_storage: NormalizeStorage, + schema: Schema, + load_id: str, + root_table_name: str, + ) -> Tuple[List[TSchemaUpdate], int, TRowCount]: + ... + + +class JsonLItemsNormalizer(ItemsNormalizer): + def _normalize_chunk( + self, + load_storage: LoadStorage, + schema: Schema, + load_id: str, + root_table_name: str, + items: List[TDataItem], + ) -> Tuple[TSchemaUpdate, int, TRowCount]: + column_schemas: Dict[ + str, TTableSchemaColumns + ] = {} # quick access to column schema for writers below + schema_update: TSchemaUpdate = {} + schema_name = schema.name + items_count = 0 + row_counts: TRowCount = {} + + for item in items: + for (table_name, parent_table), row in schema.normalize_data_item( + item, load_id, root_table_name + ): + # filter row, may eliminate some or all fields + row = schema.filter_row(table_name, row) + # do not process empty rows + if row: + # decode pua types + for k, v in row.items(): + row[k] = custom_pua_decode(v) # type: ignore + # coerce row of values into schema table, generating partial table with new columns if any + row, partial_table = schema.coerce_row( + table_name, parent_table, row + ) + # theres a new table or new columns in existing table + if partial_table: + # update schema and save the change + schema.update_table(partial_table) + table_updates = schema_update.setdefault(table_name, []) + table_updates.append(partial_table) + # update our columns + column_schemas[table_name] = schema.get_table_columns( + table_name + ) + # get current columns schema + columns = column_schemas.get(table_name) + if not columns: + columns = schema.get_table_columns(table_name) + column_schemas[table_name] = columns + # store row + # TODO: it is possible to write to single file from many processes using this: https://gitlab.com/warsaw/flufl.lock + load_storage.write_data_item( + load_id, schema_name, table_name, row, columns + ) + # count total items + items_count += 1 + increase_row_count(row_counts, table_name, 1) + signals.raise_if_signalled() + return schema_update, items_count, row_counts + + def __call__( + self, + extracted_items_file: str, + load_storage: LoadStorage, + normalize_storage: NormalizeStorage, + schema: Schema, + load_id: str, + root_table_name: str, + ) -> Tuple[List[TSchemaUpdate], int, TRowCount]: + schema_updates: List[TSchemaUpdate] = [] + row_counts: TRowCount = {} + with normalize_storage.storage.open_file(extracted_items_file) as f: + # enumerate jsonl file line by line + items_count = 0 + for line_no, line in enumerate(f): + items: List[TDataItem] = json.loads(line) + partial_update, items_count, r_counts = self._normalize_chunk( + load_storage, schema, load_id, root_table_name, items + ) + schema_updates.append(partial_update) + merge_row_count(row_counts, r_counts) + logger.debug( + f"Processed {line_no} items from file {extracted_items_file}, items {items_count}" + ) + + return schema_updates, items_count, row_counts + + +class ParquetItemsNormalizer(ItemsNormalizer): + def __call__( + self, + extracted_items_file: str, + load_storage: LoadStorage, + normalize_storage: NormalizeStorage, + schema: Schema, + load_id: str, + root_table_name: str, + ) -> Tuple[List[TSchemaUpdate], int, TRowCount]: + from dlt.common.libs import pyarrow + with normalize_storage.storage.open_file(extracted_items_file, "rb") as f: + items_count = pyarrow.get_row_count(f) + target_folder = load_storage.storage.make_full_path(os.path.join(load_id, 
LoadStorage.NEW_JOBS_FOLDER)) + parts = NormalizeStorage.parse_normalize_file_name(extracted_items_file) + new_file_name = load_storage.build_job_file_name(parts.table_name, parts.file_id, with_extension=True) + FileStorage.link_hard_with_fallback( + normalize_storage.storage.make_full_path(extracted_items_file), + os.path.join(target_folder, new_file_name) + ) + return [], items_count, {root_table_name: items_count} diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 3a23c6b371..68ab50012f 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -18,10 +18,12 @@ from dlt.common.typing import TDataItem from dlt.common.schema import TSchemaUpdate, Schema from dlt.common.schema.exceptions import CannotCoerceColumnException +from dlt.common.exceptions import TerminalValueError from dlt.common.pipeline import NormalizeInfo from dlt.common.utils import chunks, TRowCount, merge_row_count, increase_row_count from dlt.normalize.configuration import NormalizeConfiguration +from dlt.normalize.items_normalizers import ParquetItemsNormalizer, JsonLItemsNormalizer, ItemsNormalizer # normalize worker wrapping function (map_parallel, map_single) return type TMapFuncRV = Tuple[Sequence[TSchemaUpdate], TRowCount] @@ -67,22 +69,40 @@ def load_or_create_schema(schema_storage: SchemaStorage, schema_name: str) -> Sc @staticmethod def w_normalize_files( - normalize_storage_config: NormalizeStorageConfiguration, - loader_storage_config: LoadStorageConfiguration, - destination_caps: DestinationCapabilitiesContext, - stored_schema: TStoredSchema, - load_id: str, - extracted_items_files: Sequence[str], - ) -> TWorkerRV: + normalize_storage_config: NormalizeStorageConfiguration, + loader_storage_config: LoadStorageConfiguration, + destination_caps: DestinationCapabilitiesContext, + stored_schema: TStoredSchema, + load_id: str, + extracted_items_files: Sequence[str], + ) -> TWorkerRV: schema_updates: List[TSchemaUpdate] = [] total_items = 0 row_counts: TRowCount = {} + load_storages: Dict[TLoaderFileFormat, LoadStorage] = {} + + def _get_load_storage(file_format: TLoaderFileFormat) -> LoadStorage: + if file_format != "parquet": + file_format = destination_caps.preferred_loader_file_format or destination_caps.preferred_staging_file_format + if storage := load_storages.get(file_format): + return storage + # TODO: capabilities.supporteed_*_formats can be None, it should have defaults + supported_formats = list(set(destination_caps.supported_loader_file_formats or []) | set(destination_caps.supported_staging_file_formats or [])) + if file_format not in supported_formats: + if file_format == "parquet": # Give users a helpful error message for parquet + raise TerminalValueError(( + "The destination doesn't support direct loading of arrow tables. " + "Either use a different destination with parquet support or yield dicts instead of pyarrow tables/pandas dataframes from your sources." 
+ )) + # Load storage throws a generic error for other unsupported formats, normally that shouldn't happen + storage = load_storages[file_format] = LoadStorage(False, file_format, supported_formats, loader_storage_config) + return storage # process all files with data items and write to buffered item storage with Container().injectable_context(destination_caps): schema = Schema.from_stored_schema(stored_schema) - load_storage = LoadStorage(False, destination_caps.preferred_loader_file_format, LoadStorage.ALL_SUPPORTED_FILE_FORMATS, loader_storage_config) + load_storage = _get_load_storage(destination_caps.preferred_loader_file_format) # Default load storage, used for empty tables when no data normalize_storage = NormalizeStorage(False, normalize_storage_config) try: @@ -90,25 +110,27 @@ def w_normalize_files( populated_root_tables: Set[str] = set() for extracted_items_file in extracted_items_files: line_no: int = 0 - root_table_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file).table_name + parsed_file_name = NormalizeStorage.parse_normalize_file_name(extracted_items_file) + root_table_name = parsed_file_name.table_name root_tables.add(root_table_name) logger.debug(f"Processing extracted items in {extracted_items_file} in load_id {load_id} with table name {root_table_name} and schema {schema.name}") - with normalize_storage.storage.open_file(extracted_items_file) as f: - # enumerate jsonl file line by line - items_count = 0 - for line_no, line in enumerate(f): - items: List[TDataItem] = json.loads(line) - partial_update, items_count, r_counts = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items) - schema_updates.append(partial_update) - total_items += items_count - merge_row_count(row_counts, r_counts) - logger.debug(f"Processed {line_no} items from file {extracted_items_file}, items {items_count} of total {total_items}") - # if any item found in the file - if items_count > 0: - populated_root_tables.add(root_table_name) - logger.debug(f"Processed total {line_no + 1} lines from file {extracted_items_file}, total items {total_items}") - # make sure base tables are all covered - increase_row_count(row_counts, root_table_name, 0) + + file_format = parsed_file_name.file_format + load_storage = _get_load_storage(file_format) + normalizer: ItemsNormalizer + if file_format == "parquet": + normalizer = ParquetItemsNormalizer() + else: + normalizer = JsonLItemsNormalizer() + partial_updates, items_count, r_counts = normalizer(extracted_items_file, load_storage, normalize_storage, schema, load_id, root_table_name) + schema_updates.extend(partial_updates) + total_items += items_count + merge_row_count(row_counts, r_counts) + if items_count > 0: + populated_root_tables.add(root_table_name) + logger.debug(f"Processed total {line_no + 1} lines from file {extracted_items_file}, total items {total_items}") + # make sure base tables are all covered + increase_row_count(row_counts, root_table_name, 0) # write empty jobs for tables without items if table exists in schema for table_name in root_tables - populated_root_tables: if table_name not in schema.tables: @@ -126,47 +148,6 @@ def w_normalize_files( return schema_updates, total_items, load_storage.closed_files(), row_counts - @staticmethod - def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem]) -> Tuple[TSchemaUpdate, int, TRowCount]: - column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers 
below
-        schema_update: TSchemaUpdate = {}
-        schema_name = schema.name
-        items_count = 0
-        row_counts: TRowCount = {}
-
-        for item in items:
-            for (table_name, parent_table), row in schema.normalize_data_item(item, load_id, root_table_name):
-                # filter row, may eliminate some or all fields
-                row = schema.filter_row(table_name, row)
-                # do not process empty rows
-                if row:
-                    # decode pua types
-                    for k, v in row.items():
-                        row[k] = custom_pua_decode(v)  # type: ignore
-                    # coerce row of values into schema table, generating partial table with new columns if any
-                    row, partial_table = schema.coerce_row(table_name, parent_table, row)
-                    # theres a new table or new columns in existing table
-                    if partial_table:
-                        # update schema and save the change
-                        schema.update_table(partial_table)
-                        table_updates = schema_update.setdefault(table_name, [])
-                        table_updates.append(partial_table)
-                        # update our columns
-                        column_schemas[table_name] = schema.get_table_columns(table_name)
-                    # get current columns schema
-                    columns = column_schemas.get(table_name)
-                    if not columns:
-                        columns = schema.get_table_columns(table_name)
-                        column_schemas[table_name] = columns
-                    # store row
-                    # TODO: it is possible to write to single file from many processes using this: https://gitlab.com/warsaw/flufl.lock
-                    load_storage.write_data_item(load_id, schema_name, table_name, row, columns)
-                    # count total items
-                    items_count += 1
-                    increase_row_count(row_counts, table_name, 1)
-            signals.raise_if_signalled()
-        return schema_update, items_count, row_counts
-
     def update_table(self, schema: Schema, schema_updates: List[TSchemaUpdate]) -> None:
         for schema_update in schema_updates:
             for table_name, table_updates in schema_update.items():
@@ -281,8 +262,7 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files
         # rename temp folder to processing
         self.load_storage.commit_temp_load_package(load_id)
         # delete item files to complete commit
-        for file in files:
-            self.normalize_storage.storage.delete(file)
+        self.normalize_storage.delete_extracted_files(files)
         # log and update metrics
         logger.info(f"Chunk {load_id} processed")
         self._row_counts = row_counts
diff --git a/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md
new file mode 100644
index 0000000000..5d993fa9b9
--- /dev/null
+++ b/docs/website/docs/dlt-ecosystem/verified-sources/arrow-pandas.md
@@ -0,0 +1,107 @@
+---
+title: Arrow Table / Pandas
+description: dlt source for Arrow tables and Pandas dataframes
+keywords: [arrow, pandas, parquet, source]
+---
+
+# Arrow Table / Pandas
+
+You can load data directly from an Arrow table or Pandas dataframe.
+This is supported by all destinations that support the parquet file format (e.g. [Snowflake](../destinations/snowflake.md) and [Filesystem](../destinations/filesystem.md)).
+
+This is a more performant way to load structured data, since dlt bypasses many processing steps normally involved in passing JSON objects through the pipeline.
+dlt automatically translates the Arrow table's schema to the destination table's schema and writes the table to a parquet file, which is uploaded to the destination without any further processing.
+
+## Usage
+
+To write an Arrow source, pass any `pyarrow.Table` or `pandas.DataFrame` object to the pipeline's `run` or `extract` method, or yield table(s)/dataframe(s) from a `@dlt.resource` decorated function.
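+
+For example, a resource-decorated function can simply yield a table it has read or built elsewhere. Below is a minimal sketch; `orders.parquet` is only a placeholder for wherever your arrow data comes from:
+
+```python
+import dlt
+import pyarrow.parquet as pq
+
+
+@dlt.resource
+def orders():
+    # read an existing parquet file into an arrow table and yield it unchanged
+    yield pq.read_table("orders.parquet")
+
+pipeline = dlt.pipeline("orders_pipeline", destination="duckdb")
+pipeline.run(orders())
+```
+
+Yielding several tables (or one large table in chunks) from the same resource loads all of them into the same destination table.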
+
+This example loads a Pandas dataframe to a Snowflake table:
+
+```python
+import dlt
+from dlt.common import pendulum
+import pandas as pd
+
+
+df = pd.DataFrame({
+    "order_id": [1, 2, 3],
+    "customer_id": [1, 2, 3],
+    "ordered_at": [pendulum.DateTime(2021, 1, 1, 4, 5, 6), pendulum.DateTime(2021, 1, 3, 4, 5, 6), pendulum.DateTime(2021, 1, 6, 4, 5, 6)],
+    "order_amount": [100.0, 200.0, 300.0],
+})
+
+pipeline = dlt.pipeline("orders_pipeline", destination="snowflake")
+
+pipeline.run([df], table_name="orders")
+```
+
+A `pyarrow` table can be loaded in the same way:
+
+```python
+import pyarrow as pa
+
+# Create dataframe and pipeline same as above
+...
+
+table = pa.Table.from_pandas(df)
+pipeline.run([table], table_name="orders")
+```
+
+Note: The data in the table must be compatible with the destination database, as no data conversion is performed. Refer to the documentation of the destination for information about supported data types.
+
+## Incremental loading with Arrow tables
+
+You can use incremental loading with Arrow tables as well.
+Usage is the same as with other dlt resources. Refer to the [incremental loading](/general-usage/incremental-loading.md) guide for more information.
+
+Example:
+
+```python
+import dlt
+from dlt.common import pendulum
+import pandas as pd
+
+# Create a resource that yields a dataframe, using the `ordered_at` field as an incremental cursor
+@dlt.resource(primary_key="order_id")
+def orders(ordered_at = dlt.sources.incremental('ordered_at')):
+    # Get dataframe/arrow table from somewhere
+    # If your database supports it, you can use the last_value to filter data at the source.
+    # Otherwise it will be filtered automatically after loading the data.
+    df = get_orders(since=ordered_at.last_value)
+    yield df
+
+pipeline = dlt.pipeline("orders_pipeline", destination="snowflake")
+pipeline.run(orders)
+# Run again to load only new data
+pipeline.run(orders)
+```
+
+## Supported Arrow data types
+
+The Arrow data types are translated to dlt data types as follows:
+
+| Arrow type        | dlt type    | Notes                                                        |
+|-------------------|-------------|--------------------------------------------------------------|
+| `string`          | `text`      |                                                              |
+| `float`/`double`  | `double`    |                                                              |
+| `boolean`         | `bool`      |                                                              |
+| `timestamp`       | `timestamp` | Precision is determined by the unit of the timestamp.        |
+| `date`            | `date`      |                                                              |
+| `time`            | `time`      | Precision is determined by the unit of the time.             |
+| `int`             | `bigint`    | Precision is determined by the bit width.                    |
+| `binary`          | `binary`    |                                                              |
+| `decimal`         | `decimal`   | Precision and scale are determined by the type properties. 
| +| `struct` | `complex` | | +| | | | + + + + + + + + + + + diff --git a/tests/cases.py b/tests/cases.py index 88285c9241..6680f563ad 100644 --- a/tests/cases.py +++ b/tests/cases.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Sequence, Tuple +from typing import Dict, List, Any, Sequence, Tuple, Literal import base64 from hexbytes import HexBytes from copy import deepcopy @@ -11,6 +11,9 @@ from dlt.common.schema import TColumnSchema, TTableSchemaColumns +TArrowFormat = Literal["pandas", "table", "record_batch"] + + # _UUID = "c8209ee7-ee95-4b90-8c9f-f7a0f8b51014" JSON_TYPED_DICT: StrAny = { "str": "string", @@ -326,3 +329,39 @@ def assert_all_data_types_row( for expected, actual in zip(expected_rows.values(), db_mapping.values()): assert expected == actual assert db_mapping == expected_rows + + +def arrow_table_all_data_types(object_format: TArrowFormat, include_json: bool = True, include_time: bool = True) -> Tuple[Any, List[Dict[str, Any]]]: + """Create an arrow object or pandas dataframe with all supported data types. + + Returns the table and its records in python format + """ + import pandas as pd + from dlt.common.libs.pyarrow import pyarrow as pa + + data = { + "string": ["a", "b", "c"], + "float": [1.0, 2.0, 3.0], + "int": [1, 2, 3], + "datetime": pd.date_range("2021-01-01", periods=3, tz="UTC"), + "date": pd.date_range("2021-01-01", periods=3, tz="UTC").date, + "binary": [b"a", b"b", b"c"], + "decimal": [Decimal("1.0"), Decimal("2.0"), Decimal("3.0")], + "bool": [True, True, False] + } + + if include_json: + data["json"] = [{"a": 1}, {"b": 2}, {"c": 3}] + + if include_time: + data["time"] = pd.date_range("2021-01-01", periods=3, tz="UTC").time + + df = pd.DataFrame(data) + rows = df.to_dict("records") + if object_format == "pandas": + return df, rows + elif object_format == "table": + return pa.Table.from_pandas(df), rows + elif object_format == "record_batch": + return pa.RecordBatch.from_pandas(df), rows + raise ValueError("Unknown item type: " + object_format) diff --git a/tests/common/storages/test_normalize_storage.py b/tests/common/storages/test_normalize_storage.py index 678e1e49fe..7199405c12 100644 --- a/tests/common/storages/test_normalize_storage.py +++ b/tests/common/storages/test_normalize_storage.py @@ -18,11 +18,11 @@ def test_build_extracted_file_name() -> None: load_id = uniq_id() name = NormalizeStorage.build_extracted_file_stem("event", "table_with_parts__many", load_id) + ".jsonl" assert NormalizeStorage.get_schema_name(name) == "event" - assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("event", "table_with_parts__many", load_id) + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("event", "table_with_parts__many", load_id, "jsonl") # empty schema should be supported name = NormalizeStorage.build_extracted_file_stem("", "table", load_id) + ".jsonl" - assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("", "table", load_id) + assert NormalizeStorage.parse_normalize_file_name(name) == TParsedNormalizeFileName("", "table", load_id, "jsonl") def test_full_migration_path() -> None: diff --git a/tests/common/test_pyarrow.py b/tests/common/test_pyarrow.py new file mode 100644 index 0000000000..6dbdae00cb --- /dev/null +++ b/tests/common/test_pyarrow.py @@ -0,0 +1,51 @@ +from copy import deepcopy + +import pyarrow as pa + +from dlt.common.libs.pyarrow import py_arrow_to_table_schema_columns, get_py_arrow_datatype +from dlt.common.destination import 
DestinationCapabilitiesContext +from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA + + + +def test_py_arrow_to_table_schema_columns(): + dlt_schema = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) + + caps = DestinationCapabilitiesContext.generic_capabilities() + # The arrow schema will add precision + dlt_schema['col4']['precision'] = caps.timestamp_precision + dlt_schema['col6']['precision'], dlt_schema['col6']['scale'] = caps.decimal_precision + dlt_schema['col11']['precision'] = caps.timestamp_precision + dlt_schema['col4_null']['precision'] = caps.timestamp_precision + dlt_schema['col6_null']['precision'], dlt_schema['col6_null']['scale'] = caps.decimal_precision + dlt_schema['col11_null']['precision'] = caps.timestamp_precision + + # Ignoring wei as we can't distinguish from decimal + dlt_schema['col8']['precision'], dlt_schema['col8']['scale'] = (76, 0) + dlt_schema['col8']['data_type'] = 'decimal' + dlt_schema['col8_null']['precision'], dlt_schema['col8_null']['scale'] = (76, 0) + dlt_schema['col8_null']['data_type'] = 'decimal' + # No json type + dlt_schema['col9']['data_type'] = 'text' + del dlt_schema['col9']['variant'] + dlt_schema['col9_null']['data_type'] = 'text' + del dlt_schema['col9_null']['variant'] + + # arrow string fields don't have precision + del dlt_schema['col5_precision']['precision'] + + + # Convert to arrow schema + arrow_schema = pa.schema( + [ + pa.field( + column["name"], get_py_arrow_datatype(column, caps, "UTC"), nullable=column["nullable"] + ) + for column in dlt_schema.values() + ] + ) + + result = py_arrow_to_table_schema_columns(arrow_schema) + + # Resulting schema should match the original + assert result == dlt_schema diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index cf644aa08d..5f1ab9279f 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -20,15 +20,21 @@ from dlt.extract.source import DltSource from dlt.sources.helpers.transform import take_first from dlt.extract.incremental import IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing +from dlt.pipeline.exceptions import PipelineStepFailed -from tests.extract.utils import AssertItems +from tests.extract.utils import AssertItems, data_to_item_format, TItemFormat, ALL_ITEM_FORMATS, data_item_to_list -def test_single_items_last_value_state_is_updated() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_single_items_last_value_state_is_updated(item_type: TItemFormat) -> None: + data = [ + {'created_at': 425}, + {'created_at': 426}, + ] + source_items = data_to_item_format(item_type, data) @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at')): - yield {'created_at': 425} - yield {'created_at': 426} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -36,11 +42,17 @@ def some_data(created_at=dlt.sources.incremental('created_at')): assert s['last_value'] == 426 -def test_single_items_last_value_state_is_updated_transformer() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_single_items_last_value_state_is_updated_transformer(item_type: TItemFormat) -> None: + data = [ + {'created_at': 425}, + {'created_at': 426}, + ] + source_items = data_to_item_format(item_type, data) + @dlt.transformer def some_data(item, created_at=dlt.sources.incremental('created_at')): - yield {'created_at': 425} - yield {'created_at': 426} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(dlt.resource([1,2,3], 
name="table") | some_data()) @@ -49,11 +61,18 @@ def some_data(item, created_at=dlt.sources.incremental('created_at')): assert s['last_value'] == 426 -def test_batch_items_last_value_state_is_updated() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_batch_items_last_value_state_is_updated(item_type: TItemFormat) -> None: + data1 = [{'created_at': i} for i in range(5)] + data2 = [{'created_at': i} for i in range(5, 10)] + + source_items1 = data_to_item_format(item_type, data1) + source_items2 = data_to_item_format(item_type, data2) + @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at')): - yield [{'created_at': i} for i in range(5)] - yield [{'created_at': i} for i in range(5, 10)] + yield source_items1 + yield source_items2 p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -62,13 +81,17 @@ def some_data(created_at=dlt.sources.incremental('created_at')): assert s['last_value'] == 9 -def test_last_value_access_in_resource() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_last_value_access_in_resource(item_type: TItemFormat) -> None: values = [] + data = [{'created_at': i} for i in range(6)] + source_items = data_to_item_format(item_type, data) + @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at')): values.append(created_at.last_value) - yield [{'created_at': i} for i in range(6)] + yield source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -77,21 +100,31 @@ def some_data(created_at=dlt.sources.incremental('created_at')): assert values == [None, 5] -def test_unique_keys_are_deduplicated() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_unique_keys_are_deduplicated(item_type: TItemFormat) -> None: + data1 = [ + {'created_at': 1, 'id': 'a'}, + {'created_at': 2, 'id': 'b'}, + {'created_at': 3, 'id': 'c'}, + {'created_at': 3, 'id': 'd'}, + {'created_at': 3, 'id': 'e'}, + ] + data2 = [ + {'created_at': 3, 'id': 'c'}, + {'created_at': 3, 'id': 'd'}, + {'created_at': 3, 'id': 'e'}, + {'created_at': 3, 'id': 'f'}, + {'created_at': 4, 'id': 'g'}, + ] + + source_items1 = data_to_item_format(item_type, data1) + source_items2 = data_to_item_format(item_type, data2) @dlt.resource(primary_key='id') def some_data(created_at=dlt.sources.incremental('created_at')): if created_at.last_value is None: - yield {'created_at': 1, 'id': 'a'} - yield {'created_at': 2, 'id': 'b'} - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} + yield from source_items1 else: - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} - yield {'created_at': 3, 'id': 'f'} - yield {'created_at': 4, 'id': 'g'} + yield from source_items2 p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:')) @@ -105,21 +138,32 @@ def some_data(created_at=dlt.sources.incremental('created_at')): assert rows == [(1, 'a'), (2, 'b'), (3, 'c'), (3, 'd'), (3, 'e'), (3, 'f'), (4, 'g')] -def test_unique_rows_by_hash_are_deduplicated() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_unique_rows_by_hash_are_deduplicated(item_type: TItemFormat) -> None: + data1 = [ + {'created_at': 1, 'id': 'a'}, + {'created_at': 2, 'id': 'b'}, + {'created_at': 3, 'id': 'c'}, + {'created_at': 3, 'id': 'd'}, + {'created_at': 3, 'id': 'e'}, + ] + data2 = [ + {'created_at': 3, 'id': 'c'}, + {'created_at': 3, 'id': 'd'}, + 
{'created_at': 3, 'id': 'e'}, + {'created_at': 3, 'id': 'f'}, + {'created_at': 4, 'id': 'g'}, + ] + + source_items1 = data_to_item_format(item_type, data1) + source_items2 = data_to_item_format(item_type, data2) + @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at')): if created_at.last_value is None: - yield {'created_at': 1, 'id': 'a'} - yield {'created_at': 2, 'id': 'b'} - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} + yield from source_items1 else: - yield {'created_at': 3, 'id': 'c'} - yield {'created_at': 3, 'id': 'd'} - yield {'created_at': 3, 'id': 'e'} - yield {'created_at': 3, 'id': 'f'} - yield {'created_at': 4, 'id': 'g'} + yield from source_items2 p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:')) p.run(some_data()) @@ -144,10 +188,32 @@ def some_data(created_at=dlt.sources.incremental('data.items[0].created_at')): assert s['last_value'] == 2 -def test_explicit_initial_value() -> None: +@pytest.mark.parametrize("item_type", ["arrow", "pandas"]) +def test_nested_cursor_path_arrow_fails(item_type: TItemFormat) -> None: + data = [ + {'data': {'items': [{'created_at': 2}]}} + ] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data(created_at=dlt.sources.incremental('data.items[0].created_at')): + yield from source_items + + p = dlt.pipeline(pipeline_name=uniq_id()) + with pytest.raises(PipelineStepFailed) as py_ex: + p.extract(some_data()) + + ex: PipelineStepFailed = py_ex.value + assert isinstance(ex.exception, IncrementalCursorPathMissing) + assert "Column name data.items.[0].created_at was not found in the arrow table" in str(ex) + + +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_explicit_initial_value(item_type: TItemFormat) -> None: @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at')): - yield {'created_at': created_at.last_value} + data = [{"created_at": created_at.last_value}] + yield from data_to_item_format(item_type, data) p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data(created_at=4242)) @@ -156,19 +222,23 @@ def some_data(created_at=dlt.sources.incremental('created_at')): assert s['last_value'] == 4242 -def test_explicit_incremental_instance() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_explicit_incremental_instance(item_type: TItemFormat) -> None: + data = [{'inserted_at': 242, 'some_uq': 444}] + source_items = data_to_item_format(item_type, data) + @dlt.resource(primary_key='some_uq') def some_data(incremental=dlt.sources.incremental('created_at', initial_value=0)): assert incremental.cursor_path == 'inserted_at' assert incremental.initial_value == 241 - yield {'inserted_at': 242, 'some_uq': 444} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data(incremental=dlt.sources.incremental('inserted_at', initial_value=241))) @dlt.resource -def some_data_from_config(call_no: int, created_at: Optional[dlt.sources.incremental[str]] = dlt.secrets.value): +def some_data_from_config(call_no: int, item_type: TItemFormat, created_at: Optional[dlt.sources.incremental[str]] = dlt.secrets.value): assert created_at.cursor_path == 'created_at' # start value will update to the last_value on next call if call_no == 1: @@ -177,40 +247,34 @@ def some_data_from_config(call_no: int, created_at: Optional[dlt.sources.increme if call_no == 2: assert created_at.initial_value == '2022-02-03T00:00:00Z' 
assert created_at.start_value == '2022-02-03T00:00:01Z' - yield {'created_at': '2022-02-03T00:00:01Z'} + data = [{'created_at': '2022-02-03T00:00:01Z'}] + source_items = data_to_item_format(item_type, data) + yield from source_items -def test_optional_incremental_from_config() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_optional_incremental_from_config(item_type: TItemFormat) -> None: os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__CURSOR_PATH'] = 'created_at' os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_FROM_CONFIG__CREATED_AT__INITIAL_VALUE'] = '2022-02-03T00:00:00Z' p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(some_data_from_config(1)) - p.extract(some_data_from_config(2)) - - -@configspec -class SomeDataOverrideConfiguration(BaseConfiguration): - created_at: dlt.sources.incremental = dlt.sources.incremental('created_at', initial_value='2022-02-03T00:00:00Z') # type: ignore[type-arg] + p.extract(some_data_from_config(1, item_type)) + p.extract(some_data_from_config(2, item_type)) -# provide what to inject via spec. the spec contain the default -@dlt.resource(spec=SomeDataOverrideConfiguration) -def some_data_override_config(created_at: dlt.sources.incremental[str] = dlt.config.value): - assert created_at.cursor_path == 'created_at' - assert created_at.initial_value == '2000-02-03T00:00:00Z' - yield {'created_at': '2023-03-03T00:00:00Z'} - - -def test_optional_incremental_not_passed() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_optional_incremental_not_passed(item_type: TItemFormat) -> None: """Resource still runs when no incremental is passed""" + data = [1,2,3] + source_items = data_to_item_format(item_type, data) @dlt.resource def some_data(created_at: Optional[dlt.sources.incremental[str]] = None): - yield [1,2,3] + yield source_items - assert list(some_data()) == [1, 2, 3] + result = list(some_data()) + assert result == source_items @configspec @@ -219,51 +283,81 @@ class OptionalIncrementalConfig(BaseConfiguration): @dlt.resource(spec=OptionalIncrementalConfig) -def optional_incremental_arg_resource(incremental: Optional[dlt.sources.incremental[Any]] = None) -> Any: +def optional_incremental_arg_resource(item_type: TItemFormat, incremental: Optional[dlt.sources.incremental[Any]] = None) -> Any: + data = [1,2,3] + source_items = data_to_item_format(item_type, data) assert incremental is None - yield [1, 2, 3] + yield source_items -def test_optional_arg_from_spec_not_passed() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_optional_arg_from_spec_not_passed(item_type: TItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(optional_incremental_arg_resource()) + p.extract(optional_incremental_arg_resource(item_type)) + + +@configspec +class SomeDataOverrideConfiguration(BaseConfiguration): + created_at: dlt.sources.incremental = dlt.sources.incremental('created_at', initial_value='2022-02-03T00:00:00Z') # type: ignore[type-arg] -def test_override_initial_value_from_config() -> None: +# provide what to inject via spec. 
the spec contain the default +@dlt.resource(spec=SomeDataOverrideConfiguration) +def some_data_override_config(item_type: TItemFormat, created_at: dlt.sources.incremental[str] = dlt.config.value): + assert created_at.cursor_path == 'created_at' + assert created_at.initial_value == '2000-02-03T00:00:00Z' + data = [{'created_at': '2023-03-03T00:00:00Z'}] + source_items = data_to_item_format(item_type, data) + yield from source_items + + +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_override_initial_value_from_config(item_type: TItemFormat) -> None: # use the shortest possible config version # os.environ['SOURCES__TEST_INCREMENTAL__SOME_DATA_OVERRIDE_CONFIG__CREATED_AT__INITIAL_VALUE'] = '2000-02-03T00:00:00Z' os.environ['CREATED_AT__INITIAL_VALUE'] = '2000-02-03T00:00:00Z' p = dlt.pipeline(pipeline_name=uniq_id()) - p.extract(some_data_override_config()) - # p.extract(some_data_override_config()) + p.extract(some_data_override_config(item_type)) -def test_override_primary_key_in_pipeline() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_override_primary_key_in_pipeline(item_type: TItemFormat) -> None: """Primary key hint passed to pipeline is propagated through apply_hints """ + data = [ + {'created_at': 22, 'id': 2, 'other_id': 5}, + {'created_at': 22, 'id': 2, 'other_id': 6} + ] + source_items = data_to_item_format(item_type, data) + @dlt.resource(primary_key='id') def some_data(created_at=dlt.sources.incremental('created_at')): # TODO: this only works because incremental instance is shared across many copies of the resource assert some_data.incremental.primary_key == ['id', 'other_id'] - yield {'created_at': 22, 'id': 2, 'other_id': 5} - yield {'created_at': 22, 'id': 2, 'other_id': 6} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data, primary_key=['id', 'other_id']) -def test_composite_primary_key() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_composite_primary_key(item_type: TItemFormat) -> None: + data = [ + {'created_at': 1, 'isrc': 'AAA', 'market': 'DE'}, + {'created_at': 2, 'isrc': 'BBB', 'market': 'DE'}, + {'created_at': 2, 'isrc': 'CCC', 'market': 'US'}, + {'created_at': 2, 'isrc': 'AAA', 'market': 'DE'}, + {'created_at': 2, 'isrc': 'CCC', 'market': 'DE'}, + {'created_at': 2, 'isrc': 'DDD', 'market': 'DE'}, + {'created_at': 2, 'isrc': 'CCC', 'market': 'DE'}, + ] + source_items = data_to_item_format(item_type, data) + @dlt.resource(primary_key=['isrc', 'market']) def some_data(created_at=dlt.sources.incremental('created_at')): - yield {'created_at': 1, 'isrc': 'AAA', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'BBB', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'CCC', 'market': 'US'} - yield {'created_at': 2, 'isrc': 'AAA', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'CCC', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'DDD', 'market': 'DE'} - yield {'created_at': 2, 'isrc': 'CCC', 'market': 'DE'} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:')) p.run(some_data()) @@ -272,18 +366,25 @@ def some_data(created_at=dlt.sources.incremental('created_at')): with c.execute_query("SELECT created_at, isrc, market FROM some_data order by created_at, isrc, market") as cur: rows = cur.fetchall() - assert rows == [(1, 'AAA', 'DE'), (2, 'AAA', 'DE'), (2, 'BBB', 'DE'), (2, 'CCC', 'DE'), (2, 'CCC', 'US'), (2, 'DDD', 'DE')] + expected = [(1, 'AAA', 'DE'), (2, 'AAA', 'DE'), (2, 'BBB', 
'DE'), (2, 'CCC', 'DE'), (2, 'CCC', 'US'), (2, 'DDD', 'DE')] + assert rows == expected -def test_last_value_func_min() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_last_value_func_min(item_type: TItemFormat) -> None: + data = [ + {'created_at': 10}, + {'created_at': 11}, + {'created_at': 9}, + {'created_at': 10}, + {'created_at': 8}, + {'created_at': 22}, + ] + source_items = data_to_item_format(item_type, data) + @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at', last_value_func=min)): - yield {'created_at': 10} - yield {'created_at': 11} - yield {'created_at': 9} - yield {'created_at': 10} - yield {'created_at': 8} - yield {'created_at': 22} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -309,16 +410,22 @@ def some_data(created_at=dlt.sources.incremental('created_at', last_value_func=l assert s['last_value'] == 11 -def test_cursor_datetime_type() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_cursor_datetime_type(item_type: TItemFormat) -> None: initial_value = pendulum.now() + data = [ + {'created_at': initial_value + timedelta(minutes=1)}, + {'created_at': initial_value + timedelta(minutes=3)}, + {'created_at': initial_value + timedelta(minutes=2)}, + {'created_at': initial_value + timedelta(minutes=4)}, + {'created_at': initial_value + timedelta(minutes=2)}, + ] + + source_items = data_to_item_format(item_type, data) @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at', initial_value)): - yield {'created_at': initial_value + timedelta(minutes=1)} - yield {'created_at': initial_value + timedelta(minutes=3)} - yield {'created_at': initial_value + timedelta(minutes=2)} - yield {'created_at': initial_value + timedelta(minutes=4)} - yield {'created_at': initial_value + timedelta(minutes=2)} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -327,14 +434,17 @@ def some_data(created_at=dlt.sources.incremental('created_at', initial_value)): assert s['last_value'] == initial_value + timedelta(minutes=4) -def test_descending_order_unique_hashes() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_descending_order_unique_hashes(item_type: TItemFormat) -> None: """Resource returns items in descending order but using `max` last value function. Only hash matching last_value are stored. 
""" + data = [{'created_at': i} for i in reversed(range(15, 25))] + source_items = data_to_item_format(item_type, data) + @dlt.resource def some_data(created_at=dlt.sources.incremental('created_at', 20)): - for i in reversed(range(15, 25)): - yield {'created_at': i} + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.extract(some_data()) @@ -349,12 +459,15 @@ def some_data(created_at=dlt.sources.incremental('created_at', 20)): assert list(some_data()) == [] -def test_unique_keys_json_identifiers() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_unique_keys_json_identifiers(item_type: TItemFormat) -> None: """Uses primary key name that is matching the name of the JSON element in the original namespace but gets converted into destination namespace""" + @dlt.resource(primary_key="DelTa") - def some_data(last_timestamp=dlt.sources.incremental("item.ts")): - for i in range(-10, 10): - yield {"DelTa": i, "item": {"ts": pendulum.now().add(days=i).timestamp()}} + def some_data(last_timestamp=dlt.sources.incremental("ts")): + data = [{"DelTa": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items p = dlt.pipeline(pipeline_name=uniq_id()) p.run(some_data, destination="duckdb") @@ -371,103 +484,153 @@ def some_data(last_timestamp=dlt.sources.incremental("item.ts")): # something got loaded = wee create 20 elements starting from now. so one element will be in the future comparing to previous 20 elements assert len(load_info.loads_ids) == 1 with p.sql_client() as c: - with c.execute_query("SELECT del_ta FROM some_data WHERE _dlt_load_id = %s", load_info.loads_ids[0]) as cur: - rows = cur.fetchall() - assert len(rows) == 1 - assert rows[0][0] == 9 + # with c.execute_query("SELECT del_ta FROM some_data WHERE _dlt_load_id = %s", load_info.loads_ids[0]) as cur: + # rows = cur.fetchall() + with c.execute_query("SELECT del_ta FROM some_data") as cur: + rows2 = cur.fetchall() + assert len(rows2) == 21 + assert rows2[-1][0] == 9 -def test_missing_primary_key() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_missing_primary_key(item_type: TItemFormat) -> None: @dlt.resource(primary_key="DELTA") - def some_data(last_timestamp=dlt.sources.incremental("item.ts")): - for i in range(-10, 10): - yield {"delta": i, "item": {"ts": pendulum.now().add(days=i).timestamp()}} + def some_data(last_timestamp=dlt.sources.incremental("ts")): + data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items with pytest.raises(IncrementalPrimaryKeyMissing) as py_ex: list(some_data()) assert py_ex.value.primary_key_column == "DELTA" -def test_missing_cursor_field() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_missing_cursor_field(item_type: TItemFormat) -> None: @dlt.resource def some_data(last_timestamp=dlt.sources.incremental("item.timestamp")): - for i in range(-10, 10): - yield {"delta": i, "item": {"ts": pendulum.now().add(days=i).timestamp()}} + data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items with pytest.raises(IncrementalCursorPathMissing) as py_ex: list(some_data) assert py_ex.value.json_path == "item.timestamp" -@dlt.resource -def standalone_some_data(now=None, 
last_timestamp=dlt.sources.incremental("item.timestamp")): - for i in range(-10, 10): - yield {"delta": i, "item": {"timestamp": (now or pendulum.now()).add(days=i).timestamp()}} -def test_filter_processed_items() -> None: + +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_filter_processed_items(item_type: TItemFormat) -> None: """Checks if already processed items are filtered out""" + @dlt.resource + def standalone_some_data(item_type: TItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp")): + data = [{"delta": i, "timestamp": (now or pendulum.now()).add(days=i).timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items + # we get all items (no initial - nothing filtered) - assert len(list(standalone_some_data)) == 20 + values = list(standalone_some_data(item_type)) + values = data_item_to_list(item_type, values) + assert len(values) == 20 # provide initial value using max function - values = list(standalone_some_data(last_timestamp=pendulum.now().timestamp())) + values = list(standalone_some_data(item_type, last_timestamp=pendulum.now().timestamp())) + values = data_item_to_list(item_type, values) assert len(values) == 10 # only the future timestamps assert all(v["delta"] >= 0 for v in values) # provide the initial value, use min function - values = list(standalone_some_data(last_timestamp=dlt.sources.incremental("item.timestamp", pendulum.now().timestamp(), min))) + values = list(standalone_some_data( + item_type, last_timestamp=dlt.sources.incremental("timestamp", pendulum.now().timestamp(), min) + )) + values = data_item_to_list(item_type, values) assert len(values) == 10 # the minimum element assert values[0]["delta"] == -10 def test_start_value_set_to_last_value() -> None: - os.environ["COMPLETED_PROB"] = "1.0" - p = dlt.pipeline(pipeline_name=uniq_id()) now = pendulum.now() + @dlt.resource - def some_data(step, last_timestamp=dlt.sources.incremental("item.ts")): + def some_data(step, last_timestamp=dlt.sources.incremental("ts")): + expected_last = now.add(days=step-1) + if step == -10: assert last_timestamp.start_value is None else: # print(last_timestamp.initial_value) # print(now.add(days=step-1).timestamp()) - assert last_timestamp.start_value == last_timestamp.last_value == now.add(days=step-1).timestamp() - for i in range(-10, 10): - yield {"delta": i, "item": {"ts": now.add(days=i).timestamp()}} + assert last_timestamp.start_value == last_timestamp.last_value == expected_last + data = [{"delta": i, "ts": now.add(days=i)} for i in range(-10, 10)] + yield from data # after all yielded if step == -10: assert last_timestamp.start_value is None else: - assert last_timestamp.start_value == now.add(days=step-1).timestamp() != last_timestamp.last_value + assert last_timestamp.start_value == expected_last != last_timestamp.last_value for i in range(-10, 10): r = some_data(i) assert len(r._pipe) == 2 r.add_filter(take_first(i + 11), 1) - p.run(r, destination="dummy") + p.run(r, destination="duckdb") + + +@pytest.mark.parametrize("item_type", set(ALL_ITEM_FORMATS) - {'json'}) +def test_start_value_set_to_last_value_arrow(item_type: TItemFormat) -> None: + p = dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb') + now = pendulum.now() + + data = [{"delta": i, "ts": now.add(days=i)} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + + @dlt.resource + def some_data(first: bool, last_timestamp=dlt.sources.incremental("ts")): + if first: + assert 
last_timestamp.start_value is None + else: + # print(last_timestamp.initial_value) + # print(now.add(days=step-1).timestamp()) + assert last_timestamp.start_value == last_timestamp.last_value == data[-1]['ts'] + yield from source_items + # after all yielded + if first: + assert last_timestamp.start_value is None + else: + assert last_timestamp.start_value == data[-1]['ts'] == last_timestamp.last_value + p.run(some_data(True)) + p.run(some_data(False)) -def test_replace_resets_state() -> None: - os.environ["COMPLETED_PROB"] = "1.0" - p = dlt.pipeline(pipeline_name=uniq_id(), destination="dummy") + +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_replace_resets_state(item_type: TItemFormat) -> None: + p = dlt.pipeline(pipeline_name=uniq_id(), destination="duckdb") now = pendulum.now() - info = p.run(standalone_some_data(now)) + @dlt.resource + def standalone_some_data(item_type: TItemFormat, now=None, last_timestamp=dlt.sources.incremental("timestamp")): + data = [{"delta": i, "timestamp": (now or pendulum.now()).add(days=i).timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items + + info = p.run(standalone_some_data(item_type, now)) assert len(info.loads_ids) == 1 - info = p.run(standalone_some_data(now)) + info = p.run(standalone_some_data(item_type, now)) assert len(info.loads_ids) == 0 - info = p.run(standalone_some_data(now), write_disposition="replace") + info = p.run(standalone_some_data(item_type, now), write_disposition="replace") assert len(info.loads_ids) == 1 - parent_r = standalone_some_data(now) + parent_r = standalone_some_data(item_type, now) @dlt.transformer(data_from=parent_r, write_disposition="append") def child(item): state = resource_state("child") @@ -520,98 +683,94 @@ def child(item): assert extracted[child._pipe.parent.name].write_disposition == "append" -def test_incremental_as_transform() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_incremental_as_transform(item_type: TItemFormat) -> None: now = pendulum.now().timestamp() @dlt.resource def some_data(): - last_value: dlt.sources.incremental[float] = dlt.sources.incremental.from_existing_state("some_data", "item.ts") + last_value: dlt.sources.incremental[float] = dlt.sources.incremental.from_existing_state("some_data", "ts") assert last_value.initial_value == now assert last_value.start_value == now - assert last_value.cursor_path == "item.ts" + assert last_value.cursor_path == "ts" assert last_value.last_value == now - for i in range(-10, 10): - yield {"delta": i, "item": {"ts": pendulum.now().add(days=i).timestamp()}} + data = [{"delta": i, "ts": pendulum.now().add(days=i).timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items - r = some_data().add_step(dlt.sources.incremental("item.ts", initial_value=now, primary_key="delta")) + r = some_data().add_step(dlt.sources.incremental("ts", initial_value=now, primary_key="delta")) p = dlt.pipeline(pipeline_name=uniq_id()) info = p.run(r, destination="duckdb") assert len(info.loads_ids) == 1 -def test_incremental_explicit_primary_key() -> None: - @dlt.resource(primary_key="delta") - def some_data(last_timestamp=dlt.sources.incremental("item.ts", primary_key="DELTA")): - for i in range(-10, 10): - yield {"delta": i, "item": {"ts": pendulum.now().add(days=i).timestamp()}} - - with pytest.raises(IncrementalPrimaryKeyMissing) as py_ex: - list(some_data()) - assert py_ex.value.primary_key_column == "DELTA" 
- - -def test_incremental_explicit_disable_unique_check() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_incremental_explicit_disable_unique_check(item_type: TItemFormat) -> None: @dlt.resource(primary_key="delta") - def some_data(last_timestamp=dlt.sources.incremental("item.ts", primary_key=())): - for i in range(-10, 10): - yield {"delta": i, "item": {"ts": pendulum.now().timestamp()}} + def some_data(last_timestamp=dlt.sources.incremental("ts", primary_key=())): + data = [{"delta": i, "ts": pendulum.now().timestamp()} for i in range(-10, 10)] + source_items = data_to_item_format(item_type, data) + yield from source_items with Container().injectable_context(StateInjectableContext(state={})): s = some_data() list(s) # no unique hashes at all - assert s.state["incremental"]["item.ts"]["unique_hashes"] == [] + assert s.state["incremental"]["ts"]["unique_hashes"] == [] -def test_apply_hints_incremental() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_apply_hints_incremental(item_type: TItemFormat) -> None: p = dlt.pipeline(pipeline_name=uniq_id()) + data = [{"created_at": 1}, {"created_at": 2}, {"created_at": 3}] + source_items = data_to_item_format(item_type, data) @dlt.resource def some_data(created_at: Optional[dlt.sources.incremental[int]] = None): - yield [1,2,3] + yield source_items # the incremental wrapper is created for a resource and the incremental value is provided via apply hints r = some_data() - assert list(r) == [1, 2, 3] - r.apply_hints(incremental=dlt.sources.incremental("$")) + assert list(r) == source_items + r.apply_hints(incremental=dlt.sources.incremental("created_at")) p.extract(r) assert "incremental" in r.state assert list(r) == [] # as above but we provide explicit incremental when creating resource p = p.drop() - r = some_data(created_at=dlt.sources.incremental("$", last_value_func=min)) + r = some_data(created_at=dlt.sources.incremental("created_at", last_value_func=min)) # explicit has precedence here - r.apply_hints(incremental=dlt.sources.incremental("$", last_value_func=max)) + r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) p.extract(r) assert "incremental" in r.state # min value - assert r.state["incremental"]["$"]["last_value"] == 1 + assert r.state["incremental"]["created_at"]["last_value"] == 1 @dlt.resource - def some_data_w_default(created_at = dlt.sources.incremental("$", last_value_func=min)): - yield [1,2,3] + def some_data_w_default(created_at = dlt.sources.incremental("created_at", last_value_func=min)): + yield source_items # default is overridden by apply hints p = p.drop() r = some_data_w_default() - r.apply_hints(incremental=dlt.sources.incremental("$", last_value_func=max)) + r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) p.extract(r) assert "incremental" in r.state # min value - assert r.state["incremental"]["$"]["last_value"] == 3 + assert r.state["incremental"]["created_at"]["last_value"] == 3 @dlt.resource def some_data_no_incremental(): - yield [1, 2, 3] + yield source_items # we add incremental as a step p = p.drop() r = some_data_no_incremental() - r.apply_hints(incremental=dlt.sources.incremental("$", last_value_func=max)) + r.apply_hints(incremental=dlt.sources.incremental("created_at", last_value_func=max)) assert r.incremental is not None p.extract(r) assert "incremental" in r.state @@ -655,6 +814,7 @@ def _get_shuffled_events(last_created_at = dlt.sources.incremental("$", last_val def 
test_timezone_naive_datetime() -> None: + # TODO: arrow doesn't work with this """Resource has timezone naive datetime objects, but incremental stored state is converted to tz aware pendulum dates. Can happen when loading e.g. from sql database""" start_dt = datetime.now() @@ -662,24 +822,32 @@ def test_timezone_naive_datetime() -> None: @dlt.resource def some_data(updated_at: dlt.sources.incremental[pendulum.DateTime] = dlt.sources.incremental('updated_at', pendulum_start_dt)): - yield [{'updated_at': start_dt + timedelta(hours=1)}, {'updated_at': start_dt + timedelta(hours=2)}] + data = [{'updated_at': start_dt + timedelta(hours=1)}, {'updated_at': start_dt + timedelta(hours=2)}] + yield data pipeline = dlt.pipeline(pipeline_name=uniq_id()) - pipeline.extract(some_data()) + resource = some_data() + pipeline.extract(resource) + # last value has timezone added + last_value = resource.state['incremental']['updated_at']['last_value'] + assert isinstance(last_value, pendulum.DateTime) + assert last_value.tzname() == "UTC" @dlt.resource def endless_sequence( + item_type: TItemFormat, updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=1) ) -> Any: max_values = 20 start = updated_at.last_value + data = [{'updated_at': i} for i in range(start, start + max_values)] + source_items = data_to_item_format(item_type, data) + yield from source_items - for i in range(start, start + max_values): - yield {'updated_at': i} - -def test_chunked_ranges() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_chunked_ranges(item_type: TItemFormat) -> None: """Load chunked ranges with end value along with incremental""" pipeline = dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') @@ -701,7 +869,7 @@ def test_chunked_ranges() -> None: for start, end in chunks: pipeline.run( - endless_sequence(updated_at=dlt.sources.incremental(initial_value=start, end_value=end)), + endless_sequence(item_type, updated_at=dlt.sources.incremental(initial_value=start, end_value=end)), write_disposition='append' ) @@ -722,15 +890,18 @@ def test_chunked_ranges() -> None: assert items == expected_range -def test_end_value_with_batches() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_end_value_with_batches(item_type: TItemFormat) -> None: """Ensure incremental with end_value works correctly when resource yields lists instead of single items""" @dlt.resource def batched_sequence( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=1) ) -> Any: start = updated_at.last_value - yield [{'updated_at': i} for i in range(start, start + 12)] - yield [{'updated_at': i} for i in range(start+12, start + 20)] + data = [{'updated_at': i} for i in range(start, start + 12)] + yield data_to_item_format(item_type, data) + data = [{'updated_at': i} for i in range(start+12, start + 20)] + yield data_to_item_format(item_type, data) pipeline = dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') @@ -755,17 +926,19 @@ def batched_sequence( assert items == list(range(1, 14)) -def test_load_with_end_value_does_not_write_state() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_load_with_end_value_does_not_write_state(item_type: TItemFormat) -> None: """When loading chunk with initial/end value range. The resource state is untouched. 
""" pipeline = dlt.pipeline(pipeline_name='incremental_' + uniq_id(), destination='duckdb') - pipeline.extract(endless_sequence(updated_at=dlt.sources.incremental(initial_value=20, end_value=30))) + pipeline.extract(endless_sequence(item_type, updated_at=dlt.sources.incremental(initial_value=20, end_value=30))) assert pipeline.state.get('sources') is None -def test_end_value_initial_value_errors() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_end_value_initial_value_errors(item_type: TItemFormat) -> None: @dlt.resource def some_data( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at') @@ -800,14 +973,16 @@ def custom_last_value(items): assert "The result of 'custom_last_value([end_value, initial_value])' must equal 'end_value'" in str(ex.value) -def test_out_of_range_flags() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_out_of_range_flags(item_type: TItemFormat) -> None: """Test incremental.start_out_of_range / end_out_of_range flags are set when items are filtered out""" @dlt.resource def descending( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=10) ) -> Any: for chunk in chunks(list(reversed(range(48))), 10): - yield [{'updated_at': i} for i in chunk] + data = [{'updated_at': i} for i in chunk] + yield data_to_item_format(item_type, data) # Assert flag is set only on the first item < initial_value if all(item > 9 for item in chunk): assert updated_at.start_out_of_range is False @@ -820,7 +995,8 @@ def ascending( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=22, end_value=45) ) -> Any: for chunk in chunks(list(range(22, 500)), 10): - yield [{'updated_at': i} for i in chunk] + data = [{'updated_at': i} for i in chunk] + yield data_to_item_format(item_type, data) # Flag is set only when end_value is reached if all(item < 45 for item in chunk): assert updated_at.end_out_of_range is False @@ -834,6 +1010,8 @@ def descending_single_item( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=10) ) -> Any: for i in reversed(range(14)): + data = [{'updated_at': i}] + yield from data_to_item_format(item_type, data) yield {'updated_at': i} if i >= 10: assert updated_at.start_out_of_range is False @@ -846,7 +1024,8 @@ def ascending_single_item( updated_at: dlt.sources.incremental[int] = dlt.sources.incremental('updated_at', initial_value=10, end_value=22) ) -> Any: for i in range(10, 500): - yield {'updated_at': i} + data = [{'updated_at': i}] + yield from data_to_item_format(item_type, data) if i < 22: assert updated_at.end_out_of_range is False else: @@ -863,8 +1042,8 @@ def ascending_single_item( pipeline.extract(ascending_single_item()) - -def test_get_incremental_value_type() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_get_incremental_value_type(item_type: TItemFormat) -> None: assert dlt.sources.incremental("id").get_incremental_value_type() is Any assert dlt.sources.incremental("id", initial_value=0).get_incremental_value_type() is int assert dlt.sources.incremental("id", initial_value=None).get_incremental_value_type() is Any @@ -876,7 +1055,8 @@ def test_get_incremental_value_type() -> None: # pass default value @dlt.resource def test_type(updated_at = dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)): # noqa: B008 - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + 
yield data_to_item_format(item_type, data) r = test_type() list(r) @@ -885,7 +1065,8 @@ def test_type(updated_at = dlt.sources.incremental[str]("updated_at", allow_exte # use annotation @dlt.resource def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) r = test_type_2() list(r) @@ -894,7 +1075,8 @@ def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.increment # pass in explicit value @dlt.resource def test_type_3(updated_at: dlt.sources.incremental[int]): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) r = test_type_3(dlt.sources.incremental[float]("updated_at", allow_external_schedulers=True)) list(r) @@ -903,7 +1085,8 @@ def test_type_3(updated_at: dlt.sources.incremental[int]): # pass explicit value overriding default that is typed @dlt.resource def test_type_4(updated_at = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) r = test_type_4(dlt.sources.incremental[str]("updated_at", allow_external_schedulers=True)) list(r) @@ -912,36 +1095,44 @@ def test_type_4(updated_at = dlt.sources.incremental("updated_at", allow_externa # no generic type information @dlt.resource def test_type_5(updated_at = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) r = test_type_5(dlt.sources.incremental("updated_at")) list(r) assert r.incremental._incremental.get_incremental_value_type() is Any -def test_join_env_scheduler() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_join_env_scheduler(item_type: TItemFormat) -> None: @dlt.resource def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) - assert list(test_type_2()) == [{'updated_at': 1}, {'updated_at': 2}, {'updated_at': 3}] + result = list(test_type_2()) + assert data_item_to_list(item_type, result) == [{'updated_at': 1}, {'updated_at': 2}, {'updated_at': 3}] # set start and end values os.environ["DLT_START_VALUE"] = "2" - assert list(test_type_2()) == [{'updated_at': 2}, {'updated_at': 3}] + result = list(test_type_2()) + assert data_item_to_list(item_type, result) == [{'updated_at': 2}, {'updated_at': 3}] os.environ["DLT_END_VALUE"] = "3" - assert list(test_type_2()) == [{'updated_at': 2}] + result = list(test_type_2()) + assert data_item_to_list(item_type, result) == [{'updated_at': 2}] -def test_join_env_scheduler_pipeline() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_join_env_scheduler_pipeline(item_type: TItemFormat) -> None: @dlt.resource def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at", allow_external_schedulers=True)): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) pip_1_name = 'incremental_' + 
uniq_id() pipeline = dlt.pipeline(pipeline_name=pip_1_name, destination='duckdb') r = test_type_2() - r.add_step(AssertItems([{'updated_at': 2}, {'updated_at': 3}])) + r.add_step(AssertItems([{'updated_at': 2}, {'updated_at': 3}], item_type)) os.environ["DLT_START_VALUE"] = "2" pipeline.extract(r) # state is saved next extract has no items @@ -952,22 +1143,25 @@ def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.increment # setting end value will stop using state os.environ["DLT_END_VALUE"] = "3" r = test_type_2() - r.add_step(AssertItems([{'updated_at': 2}])) + r.add_step(AssertItems([{'updated_at': 2}], item_type)) pipeline.extract(r) r = test_type_2() os.environ["DLT_START_VALUE"] = "1" - r.add_step(AssertItems([{'updated_at': 1}, {'updated_at': 2}])) + r.add_step(AssertItems([{'updated_at': 1}, {'updated_at': 2}], item_type)) pipeline.extract(r) -def test_allow_external_schedulers() -> None: +@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS) +def test_allow_external_schedulers(item_type: TItemFormat) -> None: @dlt.resource() def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.incremental("updated_at")): - yield [{"updated_at": d} for d in [1, 2, 3]] + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) # does not participate os.environ["DLT_START_VALUE"] = "2" - assert len(list(test_type_2())) == 3 + result = data_item_to_list(item_type, list(test_type_2())) + assert len(result) == 3 assert test_type_2.incremental.allow_external_schedulers is False assert test_type_2().incremental.allow_external_schedulers is False @@ -975,7 +1169,8 @@ def test_type_2(updated_at: dlt.sources.incremental[int] = dlt.sources.increment # allow scheduler in wrapper r = test_type_2() r.incremental.allow_external_schedulers = True - assert len(list(test_type_2())) == 2 + result = data_item_to_list(item_type, list(test_type_2())) + assert len(result) == 2 assert r.incremental.allow_external_schedulers is True assert r.incremental._incremental.allow_external_schedulers is True @@ -987,4 +1182,5 @@ def test_type_3(): r = test_type_3() r.add_step(dlt.sources.incremental("updated_at")) r.incremental.allow_external_schedulers = True - assert len(list(test_type_2())) == 2 + result = data_item_to_list(item_type, list(test_type_2())) + assert len(result) == 2 diff --git a/tests/extract/utils.py b/tests/extract/utils.py index 3cf7e6373c..2465a1b1e2 100644 --- a/tests/extract/utils.py +++ b/tests/extract/utils.py @@ -1,11 +1,20 @@ -from typing import Any, Optional +from typing import Any, Optional, List, Union, Literal, get_args import pytest -from itertools import zip_longest +from itertools import zip_longest, chain from dlt.common.typing import TDataItem, TDataItems, TAny from dlt.extract.extract import ExtractorStorage from dlt.extract.typing import ItemTransform, ItemTransformFunc +from tests.cases import TArrowFormat + +import pandas as pd +from dlt.common.libs.pyarrow import pyarrow as pa + + +TItemFormat = Literal["json", "pandas", "arrow"] + +ALL_ITEM_FORMATS = get_args(TItemFormat) def expect_extracted_file(storage: ExtractorStorage, schema_name: str, table_name: str, content: str) -> None: @@ -27,9 +36,32 @@ def expect_extracted_file(storage: ExtractorStorage, schema_name: str, table_nam class AssertItems(ItemTransform[TDataItem]): - def __init__(self, expected_items: Any) -> None: + def __init__(self, expected_items: Any, item_type: TItemFormat = "json") -> None: self.expected_items = expected_items + self.item_type = 
item_type def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]: - assert item == self.expected_items + assert data_item_to_list(self.item_type, item) == self.expected_items return item + + +def data_to_item_format(item_format: TItemFormat, data: List[TDataItem]): + """Return the given data in the form of pandas, arrow table or json items""" + if item_format == "json": + return data + # Make dataframe from the data + df = pd.DataFrame(data) + if item_format == "pandas": + return [df] + elif item_format == "arrow": + return [pa.Table.from_pandas(df)] + else: + raise ValueError(f"Unknown item format: {item_format}") + + +def data_item_to_list(from_type: TItemFormat, values: List[TDataItem]): + if from_type == "arrow": + return values[0].to_pylist() + elif from_type == "pandas": + return values[0].to_dict("records") + return values diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py new file mode 100644 index 0000000000..d240af0c81 --- /dev/null +++ b/tests/load/pipeline/test_arrow_loading.py @@ -0,0 +1,56 @@ +import pytest +from datetime import datetime # noqa: I251 + +from typing import Any, Union, List, Dict, Tuple, Literal + +import dlt +from dlt.common import Decimal +from dlt.common import pendulum +from dlt.common.utils import uniq_id +from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.load.pipeline.utils import assert_table, assert_query_data, select_data +from tests.utils import preserve_environ +from tests.cases import arrow_table_all_data_types + + +@pytest.mark.parametrize("destination_config", destinations_configs(file_format="parquet", default_sql_configs=True, default_staging_configs=True, all_buckets_filesystem_configs=True, all_staging_configs=True), ids=lambda x: x.name) +@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +def test_load_item(item_type: Literal["pandas", "table", "record_batch"], destination_config: DestinationTestConfiguration) -> None: + include_time = destination_config.destination not in ("athena", "redshift") # athena/redshift can't load TIME columns from parquet + item, records = arrow_table_all_data_types(item_type, include_json=False, include_time=include_time) + + pipeline = destination_config.setup_pipeline("arrow_" + uniq_id()) + + @dlt.resource + def some_data(): + yield item + + pipeline.run(some_data()) + # assert the table types + some_table_columns = pipeline.default_schema.get_table("some_data")["columns"] + assert some_table_columns["string"]["data_type"] == "text" + assert some_table_columns["float"]["data_type"] == "double" + assert some_table_columns["int"]["data_type"] == "bigint" + assert some_table_columns["datetime"]["data_type"] == "timestamp" + assert some_table_columns["binary"]["data_type"] == "binary" + assert some_table_columns["decimal"]["data_type"] == "decimal" + assert some_table_columns["bool"]["data_type"] == "bool" + if include_time: + assert some_table_columns["time"]["data_type"] == "time" + + rows = [list(row) for row in select_data(pipeline, "SELECT * FROM some_data ORDER BY 1")] + + if destination_config.destination == "redshift": + # Binary columns are hex formatted in results + for record in records: + if "binary" in record: + record["binary"] = record["binary"].hex() + + for row in rows: + for i in range(len(row)): + if isinstance(row[i], datetime): + row[i] = pendulum.instance(row[i]) + + expected = [list(r.values()) for r in records] + + assert rows == expected diff --git 
a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 60399ab3ee..9f4834abc9 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -619,7 +619,7 @@ def test_snowflake_delete_file_after_copy(destination_config: DestinationTestCon # do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, subset=["bigquery", "snowflake", "duckdb"]), ids=lambda x: x.name) +@pytest.mark.parametrize("destination_config", destinations_configs(default_sql_configs=True, file_format="parquet"), ids=lambda x: x.name) def test_parquet_loading(destination_config: DestinationTestConfiguration) -> None: """Run pipeline twice with merge write disposition Resource with primary key falls back to append. Resource without keys falls back to replace. diff --git a/tests/load/utils.py b/tests/load/utils.py index b2c87e56ae..d25f4169a8 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -97,6 +97,7 @@ def destinations_configs( all_buckets_filesystem_configs: bool = False, subset: Sequence[str] = (), exclude: Sequence[str] = (), + file_format: Optional[TLoaderFileFormat] = None, ) -> List[DestinationTestConfiguration]: # sanity check @@ -109,6 +110,7 @@ def destinations_configs( # default non staging sql based configs, one per destination if default_sql_configs: destination_configs += [DestinationTestConfiguration(destination=destination) for destination in SQL_DESTINATIONS if destination != "athena"] + destination_configs += [DestinationTestConfiguration(destination="duckdb", file_format="parquet")] # athena needs filesystem staging, which will be automatically set, we have to supply a bucket url though destination_configs += [DestinationTestConfiguration(destination="athena", supports_merge=False, bucket_url=AWS_BUCKET)] destination_configs += [DestinationTestConfiguration(destination="athena", staging="filesystem", file_format="parquet", bucket_url=AWS_BUCKET, force_iceberg=True, supports_merge=False, supports_dbt=False, extra_info="iceberg")] @@ -154,6 +156,8 @@ def destinations_configs( destination_configs = [conf for conf in destination_configs if conf.destination in subset] if exclude: destination_configs = [conf for conf in destination_configs if conf.destination not in exclude] + if file_format: + destination_configs = [conf for conf in destination_configs if conf.file_format == file_format] # filter out excluded configs destination_configs = [conf for conf in destination_configs if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS] diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 9f7354bf30..09c9a017b3 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -363,7 +363,7 @@ def test_group_worker_files() -> None: def extract_items(normalize_storage: NormalizeStorage, items: Sequence[StrAny], schema_name: str, table_name: str) -> None: extractor = ExtractorStorage(normalize_storage.config) extract_id = extractor.create_extract_id() - extractor.write_data_item(extract_id, schema_name, table_name, items, None) + extractor.write_data_item("puae-jsonl", extract_id, schema_name, table_name, items, None) extractor.close_writers(extract_id) extractor.commit_extract_files(extract_id) @@ -392,7 +392,7 @@ def normalize_pending(normalize: Normalize, schema_name: str = "event") -> str: def extract_cases(normalize_storage: NormalizeStorage, cases: Sequence[str]) -> None: for 
case in cases: - schema_name, table_name, _ = NormalizeStorage.parse_normalize_file_name(case + ".jsonl") + schema_name, table_name, _, _ = NormalizeStorage.parse_normalize_file_name(case + ".jsonl") with open(json_case_path(case), "rb") as f: items = json.load(f) extract_items(normalize_storage, items, schema_name, table_name) diff --git a/tests/pipeline/test_arrow_sources.py b/tests/pipeline/test_arrow_sources.py new file mode 100644 index 0000000000..759d5735c9 --- /dev/null +++ b/tests/pipeline/test_arrow_sources.py @@ -0,0 +1,113 @@ +import pytest + +import pandas as pd +from typing import Any, Union +import pyarrow as pa + +import dlt +from dlt.common import Decimal +from dlt.common.utils import uniq_id +from dlt.common.exceptions import TerminalValueError +from dlt.pipeline.exceptions import PipelineStepFailed +from tests.cases import arrow_table_all_data_types, TArrowFormat +from dlt.common.storages import LoadStorage + + + +@pytest.mark.parametrize( + ("item_type", "is_list"), [("pandas", False), ("table", False), ("record_batch", False), ("pandas", True), ("table", True), ("record_batch", True)] +) +def test_extract_and_normalize(item_type: TArrowFormat, is_list: bool): + item, records = arrow_table_all_data_types(item_type) + + pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="filesystem") + + @dlt.resource + def some_data(): + if is_list: + yield [item] + else: + yield item + + + pipeline.extract(some_data()) + norm_storage = pipeline._get_normalize_storage() + extract_files = [fn for fn in norm_storage.list_files_to_normalize_sorted() if fn.endswith(".parquet")] + + assert len(extract_files) == 1 + + with norm_storage.storage.open_file(extract_files[0], 'rb') as f: + extracted_bytes = f.read() + + info = pipeline.normalize() + + assert info.row_counts['some_data'] == len(records) + + load_id = pipeline.list_normalized_load_packages()[0] + storage = pipeline._get_load_storage() + jobs = storage.list_new_jobs(load_id) + with storage.storage.open_file(jobs[0], 'rb') as f: + normalized_bytes = f.read() + + # Normalized is linked/copied exactly and should be the same as the extracted file + assert normalized_bytes == extracted_bytes + + f.seek(0) + pq = pa.parquet.ParquetFile(f) + tbl = pq.read() + + # Make the dataframes comparable exactly + df_tbl = pa.Table.from_pandas(pd.DataFrame(records)).to_pandas() + # Data is identical to the original dataframe + assert (tbl.to_pandas() == df_tbl).all().all() + + schema = pipeline.default_schema + + # Check schema detection + schema_columns = schema.tables['some_data']['columns'] + assert set(df_tbl.columns) == set(schema_columns) + assert schema_columns['date']['data_type'] == 'date' + assert schema_columns['int']['data_type'] == 'bigint' + assert schema_columns['float']['data_type'] == 'double' + assert schema_columns['decimal']['data_type'] == 'decimal' + assert schema_columns['time']['data_type'] == 'time' + assert schema_columns['binary']['data_type'] == 'binary' + assert schema_columns['string']['data_type'] == 'text' + assert schema_columns['json']['data_type'] == 'complex' + + +@pytest.mark.parametrize("item_type", ["pandas", "table", "record_batch"]) +def test_normalize_unsupported_loader_format(item_type: TArrowFormat): + item, _ = arrow_table_all_data_types(item_type) + + pipeline = dlt.pipeline("arrow_" + uniq_id(), destination="dummy") + + @dlt.resource + def some_data(): + yield item + + pipeline.extract(some_data()) + with pytest.raises(PipelineStepFailed) as py_ex: + pipeline.normalize() + + assert "The destination 
doesn't support direct loading of arrow tables" in str(py_ex.value) + + +@pytest.mark.parametrize("item_type", ["table", "record_batch"]) +def test_add_map(item_type: TArrowFormat): + item, _ = arrow_table_all_data_types(item_type) + + @dlt.resource + def some_data(): + yield item + + def map_func(item): + return item.filter(pa.compute.equal(item['int'], 1)) + + # Add map that filters the table + some_data.add_map(map_func) + + result = list(some_data()) + + assert len(result) == 1 + assert result[0]['int'][0].as_py() == 1
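
A self-contained sketch of the arrow-level transform pattern exercised by test_add_map above: a map step added to a resource receives the pyarrow Table itself and can filter it with pyarrow.compute kernels. The tiny inline table and the resource name arrow_rows are illustrative only and stand in for arrow_table_all_data_types.

import dlt
import pyarrow as pa
import pyarrow.compute as pc

@dlt.resource
def arrow_rows():
    # the resource yields an arrow table directly, no row-by-row conversion
    yield pa.table({"int": [1, 2, 3], "string": ["a", "b", "c"]})

# the map step sees the whole table, so it can use arrow compute kernels
arrow_rows.add_map(lambda item: item.filter(pc.equal(item["int"], 1)))

result = list(arrow_rows())
assert len(result) == 1
assert result[0]["int"][0].as_py() == 1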
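The data_to_item_format / data_item_to_list helpers added in tests/extract/utils.py are the hinge of the parametrized tests: the same records are rendered as json, pandas or arrow on the way in and flattened back to plain dicts for assertions. A minimal round-trip, assuming pandas and pyarrow are installed:

import pandas as pd
import pyarrow as pa

data = [{"updated_at": 1}, {"updated_at": 2}, {"updated_at": 3}]

df = pd.DataFrame(data)
as_pandas = [df]                        # what data_to_item_format returns for "pandas"
as_arrow = [pa.Table.from_pandas(df)]   # what it returns for "arrow"; "json" passes data through

# data_item_to_list reverses the conversion so assertions compare plain dicts
assert as_arrow[0].to_pylist() == data
assert as_pandas[0].to_dict("records") == data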
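The same helpers drive the parametrization applied across tests/extract/test_incremental.py: one resource body is fed json rows, a DataFrame or an arrow Table, and the incremental cursor state comes out the same. A condensed, hypothetical example of the pattern (the test and resource names are made up, and it assumes the suite's usual fixtures are active):

import dlt
import pytest
from dlt.common.utils import uniq_id
from tests.extract.utils import ALL_ITEM_FORMATS, TItemFormat, data_to_item_format

@pytest.mark.parametrize("item_type", ALL_ITEM_FORMATS)
def test_incremental_any_item_format(item_type: TItemFormat) -> None:
    @dlt.resource
    def some_data(updated_at=dlt.sources.incremental("updated_at", initial_value=1)):
        data = [{"updated_at": i} for i in range(1, 4)]
        yield from data_to_item_format(item_type, data)

    p = dlt.pipeline(pipeline_name="incremental_" + uniq_id())
    r = some_data()
    p.extract(r)
    # the cursor state is keyed and tracked the same way for every input format
    assert r.state["incremental"]["updated_at"]["last_value"] == 3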
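The new file_format argument to destinations_configs (tests/load/utils.py) is what lets test_parquet_loading and test_load_item run against every configuration that declares parquet, including the duckdb + parquet configuration added above. A hypothetical test showing just the filter (the test name and assertion are illustrative):

import pytest
from tests.load.utils import destinations_configs, DestinationTestConfiguration

@pytest.mark.parametrize(
    "destination_config",
    destinations_configs(default_sql_configs=True, file_format="parquet"),
    ids=lambda x: x.name,
)
def test_selected_configs_use_parquet(destination_config: DestinationTestConfiguration) -> None:
    # only configurations whose file_format matches the filter are generated
    assert destination_config.file_format == "parquet"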