Skip to content

Commit

Permalink
picks file format matching item format (#1222)
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix authored Apr 16, 2024
1 parent 2ce906c commit 0abad12
Show file tree
Hide file tree
Showing 21 changed files with 481 additions and 186 deletions.
6 changes: 6 additions & 0 deletions dlt/common/data_writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
DataWriterMetrics,
TDataItemFormat,
FileWriterSpec,
resolve_best_writer_spec,
get_best_writer_spec,
is_native_writer,
)
from dlt.common.data_writers.buffered import BufferedDataWriter, new_file_id
from dlt.common.data_writers.escape import (
Expand All @@ -14,6 +17,9 @@
__all__ = [
"DataWriter",
"FileWriterSpec",
"resolve_best_writer_spec",
"get_best_writer_spec",
"is_native_writer",
"DataWriterMetrics",
"TDataItemFormat",
"BufferedDataWriter",
Expand Down
4 changes: 1 addition & 3 deletions dlt/common/data_writers/buffered.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def __init__(
self.writer_spec = writer_spec
if self.writer_spec.requires_destination_capabilities and not _caps:
raise DestinationCapabilitiesRequired(self.writer_spec.file_format)
self.writer_cls = DataWriter.class_factory(
writer_spec.file_format, writer_spec.data_item_format
)
self.writer_cls = DataWriter.writer_class_from_spec(writer_spec)
self._supports_schema_changes = self.writer_spec.supports_schema_changes
self._caps = _caps
# validate if template has correct placeholders
Expand Down
31 changes: 31 additions & 0 deletions dlt/common/data_writers/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import NamedTuple, Sequence
from dlt.common.destination import TLoaderFileFormat
from dlt.common.exceptions import DltException

Expand Down Expand Up @@ -30,6 +31,10 @@ def __init__(self, file_format: TLoaderFileFormat):


class DataWriterNotFound(DataWriterException):
    """Base class for all writer lookup failures; subclasses carry the details of the failed lookup."""

    pass


class FileFormatForItemFormatNotFound(DataWriterNotFound):
def __init__(self, file_format: TLoaderFileFormat, data_item_format: str):
self.file_format = file_format
self.data_item_format = data_item_format
Expand All @@ -39,6 +44,32 @@ def __init__(self, file_format: TLoaderFileFormat, data_item_format: str):
)


class FileSpecNotFound(KeyError, DataWriterNotFound):
    """Raised when a full writer spec has no registered writer class.

    Also derives from KeyError so mapping-style lookup handlers can catch it.
    """

    def __init__(self, file_format: TLoaderFileFormat, data_item_format: str, spec: NamedTuple):
        self.data_item_format = data_item_format
        self.file_format = file_format
        msg = (
            f"Can't find a file writer for spec with file format {file_format} and item format"
            f" {data_item_format} where the full spec is {spec}"
        )
        super().__init__(msg)


class SpecLookupFailed(DataWriterNotFound):
    """Raised when no file writer exists for an item format among the candidate file formats."""

    def __init__(
        self,
        data_item_format: str,
        possible_file_formats: Sequence[TLoaderFileFormat],
        file_format: TLoaderFileFormat,
    ):
        self.data_item_format = data_item_format
        self.possible_file_formats = possible_file_formats
        self.file_format = file_format
        msg = (
            f"Lookup for best file writer for item format {data_item_format} among file formats"
            f" {possible_file_formats} failed. The preferred file format was {file_format}."
        )
        super().__init__(msg)


class InvalidDataItem(DataWriterException):
def __init__(self, file_format: TLoaderFileFormat, data_item_format: str, details: str):
self.file_format = file_format
Expand Down
127 changes: 115 additions & 12 deletions dlt/common/data_writers/writers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
import csv
from dataclasses import dataclass
from typing import (
IO,
TYPE_CHECKING,
Expand All @@ -20,7 +19,13 @@
from dlt.common.json import json
from dlt.common.configuration import configspec, known_sections, with_config
from dlt.common.configuration.specs import BaseConfiguration
from dlt.common.data_writers.exceptions import DataWriterNotFound, InvalidDataItem
from dlt.common.data_writers.exceptions import (
SpecLookupFailed,
DataWriterNotFound,
FileFormatForItemFormatNotFound,
FileSpecNotFound,
InvalidDataItem,
)
from dlt.common.destination import DestinationCapabilitiesContext, TLoaderFileFormat
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.typing import StrAny
Expand All @@ -33,8 +38,7 @@
TWriter = TypeVar("TWriter", bound="DataWriter")


@dataclass
class FileWriterSpec:
class FileWriterSpec(NamedTuple):
file_format: TLoaderFileFormat
"""format of the output file"""
data_item_format: TDataItemFormat
Expand Down Expand Up @@ -105,13 +109,13 @@ def from_file_format(
f: IO[Any],
caps: DestinationCapabilitiesContext = None,
) -> "DataWriter":
return cls.class_factory(file_format, data_item_format)(f, caps)
return cls.class_factory(file_format, data_item_format, ALL_WRITERS)(f, caps)

@classmethod
def writer_spec_from_file_format(
cls, file_format: TLoaderFileFormat, data_item_format: TDataItemFormat
) -> FileWriterSpec:
return cls.class_factory(file_format, data_item_format).writer_spec()
return cls.class_factory(file_format, data_item_format, ALL_WRITERS).writer_spec()

@classmethod
def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat:
Expand All @@ -123,15 +127,24 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat:
else:
raise ValueError(f"Cannot figure out data item format for extension {extension}")

@staticmethod
def writer_class_from_spec(spec: FileWriterSpec) -> Type["DataWriter"]:
try:
return WRITER_SPECS[spec]
except KeyError:
raise FileSpecNotFound(spec.file_format, spec.data_item_format, spec)

@staticmethod
def class_factory(
file_format: TLoaderFileFormat, data_item_format: TDataItemFormat
file_format: TLoaderFileFormat,
data_item_format: TDataItemFormat,
writers: Sequence[Type["DataWriter"]],
) -> Type["DataWriter"]:
for writer in ALL_WRITERS:
for writer in writers:
spec = writer.writer_spec()
if spec.file_format == file_format and spec.data_item_format == data_item_format:
return writer
raise DataWriterNotFound(file_format, data_item_format)
raise FileFormatForItemFormatNotFound(file_format, data_item_format)


class JsonlWriter(DataWriter):
Expand Down Expand Up @@ -601,8 +614,7 @@ def write_data(self, rows: Sequence[Any]) -> None:
@staticmethod
def convert_spec(base: Type[DataWriter]) -> FileWriterSpec:
spec = base.writer_spec()
spec.data_item_format = "arrow"
return spec
return spec._replace(data_item_format="arrow")


class ArrowToInsertValuesWriter(ArrowToObjectAdapter, InsertValuesWriter):
Expand All @@ -623,7 +635,14 @@ def writer_spec(cls) -> FileWriterSpec:
return cls.convert_spec(TypedJsonlListWriter)


# ArrowToCsvWriter
def is_native_writer(writer_type: Type[DataWriter]) -> bool:
    """Tells whether `writer_type` writes its item format directly.

    Writers built on an adapter mixin convert items first and are therefore not
    native; they typically decrease the performance.
    """
    # the only adapter mixin that exists today is the arrow one
    uses_adapter = issubclass(writer_type, ArrowToObjectAdapter)
    return not uses_adapter


ALL_WRITERS: List[Type[DataWriter]] = [
JsonlWriter,
TypedJsonlListWriter,
Expand All @@ -636,3 +655,87 @@ def writer_spec(cls) -> FileWriterSpec:
ArrowToTypedJsonlListWriter,
ArrowToCsvWriter,
]

# spec -> writer class registry used by DataWriter.writer_class_from_spec
WRITER_SPECS: Dict[FileWriterSpec, Type[DataWriter]] = {
    writer.writer_spec(): writer for writer in ALL_WRITERS
}

# per item format: writers that accept that format directly (see is_native_writer)
NATIVE_FORMAT_WRITERS: Dict[TDataItemFormat, Tuple[Type[DataWriter], ...]] = {
    # all "object" writers are native object writers (no adapters yet)
    "object": tuple(
        writer
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "object" and is_native_writer(writer)
    ),
    # exclude arrow adapters
    "arrow": tuple(
        writer
        for writer in ALL_WRITERS
        if writer.writer_spec().data_item_format == "arrow" and is_native_writer(writer)
    ),
}


def resolve_best_writer_spec(
    item_format: TDataItemFormat,
    possible_file_formats: Sequence[TLoaderFileFormat],
    preferred_format: TLoaderFileFormat = None,
) -> FileWriterSpec:
    """Finds best writer for `item_format` out of `possible_file_formats`. Tries `preferred_format` first.
    Best possible writer is a native writer for `item_format` writing files in `preferred_format`.
    If not found, any native writer for `possible_file_formats` is picked.
    Native writer supports `item_format` directly without a need to convert to other item formats.

    Raises:
        ValueError: when `preferred_format` is not within `possible_file_formats`
        SpecLookupFailed: when no writer, native or adapter-based, matches
    """
    native_writers = NATIVE_FORMAT_WRITERS[item_format]
    if preferred_format and preferred_format not in possible_file_formats:
        raise ValueError(
            f"Preferred format {preferred_format} not possible in {possible_file_formats}"
        )
    # candidate file formats in priority order: preferred format first, then the rest
    candidate_formats: List[TLoaderFileFormat] = []
    if preferred_format:
        candidate_formats.append(preferred_format)
    candidate_formats.extend(f for f in possible_file_formats if f != preferred_format)
    # two passes replace the four duplicated lookup loops of the original:
    # native writers are tried for every candidate format before any adapter-based
    # writer is considered, preserving the original search order exactly
    for writers in (native_writers, ALL_WRITERS):
        for file_format in candidate_formats:
            try:
                return DataWriter.class_factory(file_format, item_format, writers).writer_spec()
            except DataWriterNotFound:
                pass
    raise SpecLookupFailed(item_format, possible_file_formats, preferred_format)


def get_best_writer_spec(
    item_format: TDataItemFormat, file_format: TLoaderFileFormat
) -> FileWriterSpec:
    """Gets writer for `item_format` writing files in `file_format`. Looks for native writer first."""
    native_writers = NATIVE_FORMAT_WRITERS[item_format]
    try:
        # a native writer handles `item_format` directly, without conversion
        writer = DataWriter.class_factory(file_format, item_format, native_writers)
    except DataWriterNotFound:
        # fall back to any registered writer, including adapter-based ones
        writer = DataWriter.class_factory(file_format, item_format, ALL_WRITERS)
    return writer.writer_spec()
2 changes: 2 additions & 0 deletions dlt/common/destination/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dlt.common.destination.capabilities import (
DestinationCapabilitiesContext,
merge_caps_file_formats,
TLoaderFileFormat,
ALL_SUPPORTED_FILE_FORMATS,
)
from dlt.common.destination.reference import TDestinationReferenceArg, Destination, TDestination

__all__ = [
"DestinationCapabilitiesContext",
"merge_caps_file_formats",
"TLoaderFileFormat",
"ALL_SUPPORTED_FILE_FORMATS",
"TDestinationReferenceArg",
Expand Down
46 changes: 42 additions & 4 deletions dlt/common/destination/capabilities.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Any, Callable, ClassVar, List, Literal, Optional, Tuple, Set, get_args
from typing import Any, Callable, ClassVar, List, Literal, Optional, Sequence, Tuple, Set, get_args

from dlt.common.configuration.utils import serialize_value
from dlt.common.configuration import configspec
from dlt.common.configuration.specs import ContainerInjectableContext
from dlt.common.destination.exceptions import (
DestinationIncompatibleLoaderFileFormatException,
DestinationLoadingViaStagingNotSupported,
DestinationLoadingWithoutStagingNotSupported,
)
from dlt.common.utils import identity

from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
Expand All @@ -22,9 +27,9 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
"""Injectable destination capabilities required for many Pipeline stages ie. normalize"""

preferred_loader_file_format: TLoaderFileFormat = None
supported_loader_file_formats: List[TLoaderFileFormat] = None
supported_loader_file_formats: Sequence[TLoaderFileFormat] = None
preferred_staging_file_format: Optional[TLoaderFileFormat] = None
supported_staging_file_formats: List[TLoaderFileFormat] = None
supported_staging_file_formats: Sequence[TLoaderFileFormat] = None
escape_identifier: Callable[[str], str] = None
escape_literal: Callable[[Any], Any] = None
decimal_precision: Tuple[int, int] = None
Expand All @@ -46,8 +51,8 @@ class DestinationCapabilitiesContext(ContainerInjectableContext):
insert_values_writer_type: str = "default"
supports_multiple_statements: bool = True
supports_clone_table: bool = False
max_table_nesting: Optional[int] = None # destination can overwrite max table nesting
"""Destination supports CREATE TABLE ... CLONE ... statements"""
max_table_nesting: Optional[int] = None # destination can overwrite max table nesting

# do not allow to create default value, destination caps must be always explicitly inserted into container
can_create_default: ClassVar[bool] = False
Expand Down Expand Up @@ -75,3 +80,36 @@ def generic_capabilities(
caps.supports_transactions = True
caps.supports_multiple_statements = True
return caps


def merge_caps_file_formats(
    destination: str,
    staging: str,
    dest_caps: DestinationCapabilitiesContext,
    stage_caps: DestinationCapabilitiesContext,
) -> Tuple[TLoaderFileFormat, Sequence[TLoaderFileFormat]]:
    """Merges preferred and supported file formats from destination and staging.
    Returns new preferred file format and all possible formats.
    """
    if stage_caps:
        # loading through staging: only formats understood by both sides are possible
        if not dest_caps.supported_staging_file_formats:
            raise DestinationLoadingViaStagingNotSupported(destination)
        stage_formats = stage_caps.supported_loader_file_formats
        possible_file_formats = [
            fmt for fmt in dest_caps.supported_staging_file_formats if fmt in stage_formats
        ]
        if not possible_file_formats:
            raise DestinationIncompatibleLoaderFileFormatException(
                destination, staging, None, possible_file_formats
            )
        if dest_caps.preferred_staging_file_format in possible_file_formats:
            requested_file_format = dest_caps.preferred_staging_file_format
        else:
            # fall back to the first common format (list is non-empty at this point)
            requested_file_format = possible_file_formats[0] if possible_file_formats else None
    else:
        # direct loading: the destination must declare a preferred loader format
        possible_file_formats = dest_caps.supported_loader_file_formats
        if not dest_caps.preferred_loader_file_format:
            raise DestinationLoadingWithoutStagingNotSupported(destination)
        requested_file_format = dest_caps.preferred_loader_file_format
    return requested_file_format, possible_file_formats
4 changes: 1 addition & 3 deletions dlt/common/storages/data_item_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
class DataItemStorage(ABC):
def __init__(self, writer_spec: FileWriterSpec, *args: Any) -> None:
self.writer_spec = writer_spec
self.writer_cls = DataWriter.class_factory(
writer_spec.file_format, writer_spec.data_item_format
)
self.writer_cls = DataWriter.writer_class_from_spec(writer_spec)
self.buffered_writers: Dict[str, BufferedDataWriter[DataWriter]] = {}
super().__init__(*args)

Expand Down
Loading

0 comments on commit 0abad12

Please sign in to comment.