Skip to content

Commit

Permalink
custom destination fixes (#1119)
Browse files Browse the repository at this point in the history
* always register SPEC for f when injecting, fixes base tests

* always synthesizes a spec even if fields are not added, keeps the base class fields if sig not annotated

* fixes how base and spec are used in sink factory

* don't fail on destination instantiation if no callable arg provided

* add docstring for decorator
rename config spec
format generic destination example

* allow destination decorator to be used without args
remove some unneeded things from the destination tests

* add tests for base spec of custom destination
fix tests for source decorator

* improves sink spec test

---------

Co-authored-by: Dave <[email protected]>
  • Loading branch information
rudolfix and sh-rp authored Mar 21, 2024
1 parent f52e2e4 commit 1f2b4ce
Show file tree
Hide file tree
Showing 12 changed files with 357 additions and 89 deletions.
14 changes: 10 additions & 4 deletions dlt/common/configuration/inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,25 @@ def with_config(
def decorator(f: TFun) -> TFun:
SPEC: Type[BaseConfiguration] = None
sig: Signature = inspect.signature(f)
signature_fields: Dict[str, Any]
kwargs_arg = next(
(p for p in sig.parameters.values() if p.kind == Parameter.VAR_KEYWORD), None
)
spec_arg: Parameter = None
pipeline_name_arg: Parameter = None
if spec is None:
SPEC = spec_from_signature(f, sig, include_defaults, base=base)
SPEC, signature_fields = spec_from_signature(f, sig, include_defaults, base=base)
else:
SPEC = spec
signature_fields = SPEC.get_resolvable_fields()

if SPEC is None:
# if no signature fields were added we will not wrap `f` for injection
if len(signature_fields) == 0:
# always register new function
_FUNC_SPECS[id(f)] = SPEC
return f

spec_arg: Parameter = None
pipeline_name_arg: Parameter = None

for p in sig.parameters.values():
# for all positional parameters that do not have default value, set default
# if hasattr(SPEC, p.name) and p.default == Parameter.empty:
Expand Down
39 changes: 28 additions & 11 deletions dlt/common/reflection/spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import inspect
from typing import Dict, List, Type, Any, Optional, NewType
from typing import Dict, List, Tuple, Type, Any, Optional, NewType
from inspect import Signature, Parameter

from dlt.common.typing import AnyType, AnyFun, TSecretValue
Expand Down Expand Up @@ -30,14 +30,27 @@ def spec_from_signature(
sig: Signature,
include_defaults: bool = True,
base: Type[BaseConfiguration] = BaseConfiguration,
) -> Type[BaseConfiguration]:
) -> Tuple[Type[BaseConfiguration], Dict[str, Any]]:
"""Creates a SPEC on base `base` for a function `f` with signature `sig`.
All the arguments in `sig` that are valid SPEC hints and have defaults will be part of the SPEC.
Special markers for required SPEC fields `dlt.secrets.value` and `dlt.config.value` are parsed using
module source code, which is a hack and will not work for modules not imported from a file.
The name of a SPEC type is inferred from qualname of `f` and type will refer to `f` module and is unique
for a module. NOTE: the SPECS are cached in the module by using name as an id.
Return value is a tuple of SPEC and SPEC fields created from a `sig`.
"""
name = _get_spec_name_from_f(f)
module = inspect.getmodule(f)
base_fields = base.get_resolvable_fields()

# check if spec for that function exists
spec_id = name # f"SPEC_{name}_kw_only_{kw_only}"
if hasattr(module, spec_id):
return getattr(module, spec_id) # type: ignore
MOD_SPEC: Type[BaseConfiguration] = getattr(module, spec_id)
return MOD_SPEC, MOD_SPEC.get_resolvable_fields()

# find all the arguments that have following defaults
literal_defaults: Dict[str, str] = None
Expand All @@ -62,7 +75,8 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType:
return None

# synthesize configuration from the signature
fields: Dict[str, Any] = {}
new_fields: Dict[str, Any] = {}
sig_base_fields: Dict[str, Any] = {}
annotations: Dict[str, Any] = {}

for p in sig.parameters.values():
Expand All @@ -72,6 +86,10 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType:
"cls",
]:
field_type = AnyType if p.annotation == Parameter.empty else p.annotation
# keep the base fields if sig not annotated
if p.name in base_fields and field_type is AnyType and p.default is None:
sig_base_fields[p.name] = base_fields[p.name]
continue
# only valid hints and parameters with defaults are eligible
if is_valid_hint(field_type) and p.default != Parameter.empty:
# try to get type from default
Expand Down Expand Up @@ -102,18 +120,17 @@ def dlt_config_literal_to_type(arg_name: str) -> AnyType:
# set annotations
annotations[p.name] = field_type
# set field with default value
fields[p.name] = p.default
new_fields[p.name] = p.default

if not fields:
return None
signature_fields = {**sig_base_fields, **new_fields}

# new type goes to the module where sig was declared
fields["__module__"] = module.__name__
new_fields["__module__"] = module.__name__
# set annotations so they are present in __dict__
fields["__annotations__"] = annotations
new_fields["__annotations__"] = annotations
# synthesize type
T: Type[BaseConfiguration] = type(name, (base,), fields)
T: Type[BaseConfiguration] = type(name, (base,), new_fields)
SPEC = configspec()(T)
# add to the module
setattr(module, spec_id, SPEC)
return SPEC
return SPEC, signature_fields
45 changes: 40 additions & 5 deletions dlt/destinations/decorators.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import functools

from typing import Any, Type, Optional, Callable, Union
from typing import Any, Type, Optional, Callable, Union, cast
from typing_extensions import Concatenate
from dlt.common.typing import AnyFun

from functools import wraps

from dlt.common import logger
from dlt.destinations.impl.destination.factory import destination as _destination
from dlt.destinations.impl.destination.configuration import (
TDestinationCallableParams,
GenericDestinationClientConfiguration,
CustomDestinationClientConfiguration,
)
from dlt.common.destination import TLoaderFileFormat
from dlt.common.destination.reference import Destination
Expand All @@ -18,18 +19,47 @@


def destination(
*,
func: Optional[AnyFun] = None,
/,
loader_file_format: TLoaderFileFormat = None,
batch_size: int = 10,
name: str = None,
naming_convention: str = "direct",
skip_dlt_columns_and_tables: bool = True,
max_table_nesting: int = 0,
spec: Type[GenericDestinationClientConfiguration] = GenericDestinationClientConfiguration,
spec: Type[CustomDestinationClientConfiguration] = None,
) -> Callable[
[Callable[Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any]],
Callable[TDestinationCallableParams, _destination],
]:
"""A decorator that transforms a function that takes two positional arguments "items" and "table" and any number of keyword arguments with defaults
into a callable that will create a custom destination. The function does not return anything, the keyword arguments can be configuration and secrets values.
#### Example Usage with Configuration and Secrets:
>>> @dlt.destination(batch_size=100, loader_file_format="parquet")
>>> def my_destination(items, table, api_url: str = dlt.config.value, api_secret = dlt.secrets.value):
>>> print(table["name"])
>>> print(items)
>>>
>>> p = dlt.pipeline("chess_pipeline", destination=my_destination)
Here all incoming data will be sent to the destination function with the items in the requested format and the dlt table schema.
The config and secret values will be resolved from the path destination.my_destination.api_url and destination.my_destination.api_secret.
#### Args:
batch_size: defines how many items per function call are batched together and sent as an array. If you set a batch-size of 0, instead of passing in actual dataitems, you will receive one call per load job with the path of the file as the items argument. You can then open and process that file in any way you like.
loader_file_format: defines in which format files are stored in the load package before being sent to the destination function, this can be puae-jsonl or parquet.
name: defines the name of the destination that gets created by the destination decorator, defaults to the name of the function
naming_convention: defines the naming convention used by the destination that gets created by the destination decorator. This controls how table and column names are normalized. The default is direct which will keep all names the same.
max_table_nesting: defines how deep the normalizer will go to normalize complex fields on your data to create subtables. This overwrites any settings on your source and is set to zero to not create any nested tables by default.
skip_dlt_columns_and_tables: defines whether internal tables and columns will be fed into the custom destination function. This is set to True by default.
spec: defines a configuration spec that will be used to inject arguments into the decorated functions. Arguments not in spec will not be injected
Returns:
A callable that can be used to create a dlt custom destination instance
"""

def decorator(
destination_callable: Callable[
Concatenate[Union[TDataItems, str], TTableSchema, TDestinationCallableParams], Any
Expand Down Expand Up @@ -58,4 +88,9 @@ def wrapper(

return wrapper

return decorator
if func is None:
# we're called with parens.
return decorator

# we're called as @destination without parens.
return decorator(func) # type: ignore
2 changes: 1 addition & 1 deletion dlt/destinations/impl/destination/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


@configspec
class GenericDestinationClientConfiguration(DestinationClientConfiguration):
class CustomDestinationClientConfiguration(DestinationClientConfiguration):
destination_type: Final[str] = "destination" # type: ignore
destination_callable: Optional[Union[str, TDestinationCallable]] = None # noqa: A003
loader_file_format: TLoaderFileFormat = "puae-jsonl"
Expand Down
8 changes: 4 additions & 4 deletions dlt/destinations/impl/destination/destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from dlt.destinations.impl.destination import capabilities
from dlt.destinations.impl.destination.configuration import (
GenericDestinationClientConfiguration,
CustomDestinationClientConfiguration,
TDestinationCallable,
)

Expand All @@ -36,7 +36,7 @@ def __init__(
self,
table: TTableSchema,
file_path: str,
config: GenericDestinationClientConfiguration,
config: CustomDestinationClientConfiguration,
schema: Schema,
destination_state: Dict[str, int],
destination_callable: TDestinationCallable,
Expand Down Expand Up @@ -140,9 +140,9 @@ class DestinationClient(JobClientBase):

capabilities: ClassVar[DestinationCapabilitiesContext] = capabilities()

def __init__(self, schema: Schema, config: GenericDestinationClientConfiguration) -> None:
def __init__(self, schema: Schema, config: CustomDestinationClientConfiguration) -> None:
super().__init__(schema, config)
self.config: GenericDestinationClientConfiguration = config
self.config: CustomDestinationClientConfiguration = config
# create pre-resolved callable to avoid multiple config resolutions during execution of the jobs
self.destination_callable = create_resolved_partial(
cast(AnyFun, self.config.destination_callable), self.config
Expand Down
47 changes: 37 additions & 10 deletions dlt/destinations/impl/destination/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
from dlt.common.typing import AnyFun

from dlt.common.destination import Destination, DestinationCapabilitiesContext
from dlt.destinations.exceptions import DestinationTransientException
from dlt.common.configuration import known_sections, with_config, get_fun_spec
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.common import logger

from dlt.destinations.impl.destination.configuration import (
GenericDestinationClientConfiguration,
CustomDestinationClientConfiguration,
TDestinationCallable,
)
from dlt.destinations.impl.destination import capabilities
from dlt.common.data_writers import TLoaderFileFormat
from dlt.common.utils import get_callable_name
from dlt.common.utils import get_callable_name, is_inner_callable

if t.TYPE_CHECKING:
from dlt.destinations.impl.destination.destination import DestinationClient
Expand All @@ -24,7 +26,7 @@
class DestinationInfo(t.NamedTuple):
"""Runtime information on a discovered destination"""

SPEC: t.Type[GenericDestinationClientConfiguration]
SPEC: t.Type[CustomDestinationClientConfiguration]
f: AnyFun
module: ModuleType

Expand All @@ -33,7 +35,7 @@ class DestinationInfo(t.NamedTuple):
"""A registry of all the decorated destinations"""


class destination(Destination[GenericDestinationClientConfiguration, "DestinationClient"]):
class destination(Destination[CustomDestinationClientConfiguration, "DestinationClient"]):
def capabilities(self) -> DestinationCapabilitiesContext:
return capabilities(
preferred_loader_file_format=self.config_params.get("loader_file_format", "puae-jsonl"),
Expand All @@ -42,7 +44,7 @@ def capabilities(self) -> DestinationCapabilitiesContext:
)

@property
def spec(self) -> t.Type[GenericDestinationClientConfiguration]:
def spec(self) -> t.Type[CustomDestinationClientConfiguration]:
"""A spec of destination configuration resolved from the sink function signature"""
return self._spec

Expand All @@ -60,9 +62,14 @@ def __init__(
loader_file_format: TLoaderFileFormat = None,
batch_size: int = 10,
naming_convention: str = "direct",
spec: t.Type[GenericDestinationClientConfiguration] = GenericDestinationClientConfiguration,
spec: t.Type[CustomDestinationClientConfiguration] = None,
**kwargs: t.Any,
) -> None:
if spec and not issubclass(spec, CustomDestinationClientConfiguration):
raise ValueError(
"A SPEC for a sink destination must use CustomDestinationClientConfiguration as a"
" base."
)
# resolve callable
if callable(destination_callable):
pass
Expand All @@ -81,7 +88,22 @@ def __init__(
f"Could not find callable function at {destination_callable}"
) from e

if not callable(destination_callable):
# provide dummy callable for cases where no callable is provided
# this is needed for cli commands to work
if not destination_callable:
logger.warning(
"No destination callable provided, providing dummy callable which will fail on"
" load."
)

def dummy_callable(*args: t.Any, **kwargs: t.Any) -> None:
raise DestinationTransientException(
"You tried to load to a custom destination without a valid callable."
)

destination_callable = dummy_callable

elif not callable(destination_callable):
raise ConfigurationValueError("Resolved Sink destination callable is not a callable.")

# resolve destination name
Expand All @@ -93,16 +115,21 @@ def __init__(
destination_sections = (known_sections.DESTINATION, destination_name)
conf_callable = with_config(
destination_callable,
spec=spec,
sections=destination_sections,
include_defaults=True,
base=spec,
base=None if spec else CustomDestinationClientConfiguration,
)

# save destination in registry
resolved_spec = t.cast(
t.Type[GenericDestinationClientConfiguration], get_fun_spec(conf_callable)
t.Type[CustomDestinationClientConfiguration], get_fun_spec(conf_callable)
)
_DESTINATIONS[callable.__qualname__] = DestinationInfo(resolved_spec, callable, func_module)
# register only standalone destinations, no inner
if not is_inner_callable(destination_callable):
_DESTINATIONS[destination_callable.__qualname__] = DestinationInfo(
resolved_spec, destination_callable, func_module
)

# remember spec
self._spec = resolved_spec or spec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# format: "your-project.your_dataset.your_table"
BIGQUERY_TABLE_ID = "chat-analytics-rasa-ci.ci_streaming_insert.natural-disasters"


# dlt sources
@dlt.resource(name="natural_disasters")
def resource(url: str):
Expand All @@ -38,6 +39,7 @@ def resource(url: str):
)
yield table


# dlt bigquery custom destination
# we can use the dlt provided credentials class
# to retrieve the gcp credentials from the secrets
Expand All @@ -58,6 +60,7 @@ def bigquery_insert(
load_job = client.load_table_from_file(f, BIGQUERY_TABLE_ID, job_config=job_config)
load_job.result() # Waits for the job to complete.


if __name__ == "__main__":
# run the pipeline and print load results
pipeline = dlt.pipeline(
Expand All @@ -68,4 +71,4 @@ def bigquery_insert(
)
load_info = pipeline.run(resource(url=OWID_DISASTERS_URL))

print(load_info)
print(load_info)
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,3 @@ def bigquery_insert(
print(load_info)
# @@@DLT_SNIPPET_END example
assert_load_info(load_info)

Loading

0 comments on commit 1f2b4ce

Please sign in to comment.