Skip to content

Commit

Permalink
Pydantic validation (#638)
Browse files Browse the repository at this point in the history
* Coerce and encode enum types

* Add pydantic validator step

* Test enum coercion

* Use isinstance in pipe.find to support finding subclasses

* Wrap pydantic ValidationError

* Test extract validation

* Pydantic json type, test special string types

* Validate lists

* Pydantic skip Any fields

* Move validator base class

* Subclass validator from ItemTransform

* Imrpvoe validation error message

* Cleanup
  • Loading branch information
steinitzu authored Sep 19, 2023
1 parent 4e7ef02 commit 390519d
Show file tree
Hide file tree
Showing 13 changed files with 396 additions and 20 deletions.
17 changes: 16 additions & 1 deletion dlt/common/data_types/type_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime # noqa: I251
from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence
from typing import Any, Type, Literal, Union, cast
from enum import Enum

from dlt.common import pendulum, json, Decimal, Wei
from dlt.common.json import custom_pua_remove
Expand Down Expand Up @@ -51,6 +52,13 @@ def py_type_to_sc_type(t: Type[Any]) -> TDataType:
return "binary"
if issubclass(t, (C_Mapping, C_Sequence)):
return "complex"
# Enum is coerced to str or int respectively
if issubclass(t, Enum):
if issubclass(t, int):
return "bigint"
else:
# str subclass and unspecified enum type translates to text
return "text"

raise TypeError(t)

Expand Down Expand Up @@ -83,6 +91,13 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any:
if to_type == "complex":
# complex types need custom encoding to be removed
return map_nested_in_place(custom_pua_remove, value)
# Make sure we use enum value instead of the object itself
# This check is faster than `isinstance(value, Enum)` for non-enum types
if hasattr(value, 'value'):
if to_type == "text":
return str(value.value)
elif to_type == "bigint":
return int(value.value)
return value

if to_type == "text":
Expand All @@ -91,7 +106,7 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any:
else:
# use the same string encoding as in json
try:
return json_custom_encode(value)
return str(json_custom_encode(value))
except TypeError:
# for other types use internal conversion
return str(value)
Expand Down
6 changes: 6 additions & 0 deletions dlt/common/json/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, List, Protocol, IO, Union
from uuid import UUID
from hexbytes import HexBytes
from enum import Enum

try:
from pydantic import BaseModel as PydanticBaseModel
Expand Down Expand Up @@ -82,6 +83,8 @@ def custom_encode(obj: Any) -> str:
return obj.dict() # type: ignore[return-value]
elif dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj) # type: ignore
elif isinstance(obj, Enum):
return obj.value # type: ignore[no-any-return]
raise TypeError(repr(obj) + " is not JSON serializable")


Expand Down Expand Up @@ -145,6 +148,9 @@ def custom_pua_encode(obj: Any) -> str:
return dataclasses.asdict(obj) # type: ignore
elif PydanticBaseModel and isinstance(obj, PydanticBaseModel):
return obj.dict() # type: ignore[return-value]
elif isinstance(obj, Enum):
# Enum value is just int or str
return obj.value # type: ignore[no-any-return]
raise TypeError(repr(obj) + " is not JSON serializable")


Expand Down
14 changes: 12 additions & 2 deletions dlt/common/libs/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Type, Union, get_type_hints, get_args
from typing import Type, Union, get_type_hints, get_args, Any

from dlt.common.exceptions import MissingDependencyException
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.data_types import py_type_to_sc_type, TDataType
from dlt.common.typing import is_optional_type, extract_inner_type, is_list_generic_type, is_dict_generic_type, is_union

try:
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, Json
except ImportError:
raise MissingDependencyException("DLT pydantic Helpers", ["pydantic"], "DLT Helpers for for pydantic.")

Expand All @@ -26,13 +26,23 @@ def pydantic_to_table_schema_columns(model: Union[BaseModel, Type[BaseModel]], s
fields = model.__fields__
for field_name, field in fields.items():
annotation = field.annotation
if inner_annotation := getattr(annotation, 'inner_type', None):
# This applies to pydantic.Json fields, the inner type is the type after json parsing
# (In pydantic 2 the outer annotation is the final type)
annotation = inner_annotation
nullable = is_optional_type(annotation)

if is_union(annotation):
inner_type = get_args(annotation)[0]
else:
inner_type = extract_inner_type(annotation)

if inner_type is Json: # Same as `field: Json[Any]`
inner_type = Any

if inner_type is Any: # Any fields will be inferred from data
continue

if is_list_generic_type(inner_type):
inner_type = list
elif is_dict_generic_type(inner_type) or issubclass(inner_type, BaseModel):
Expand Down
11 changes: 6 additions & 5 deletions dlt/extract/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,10 @@ def resource(
write_disposition (Literal["skip", "append", "replace", "merge"], optional): Controls how to write data to a table. `append` will always add new data at the end of the table. `replace` will replace existing data with new data. `skip` will prevent data from loading. "merge" will deduplicate and merge data based on "primary_key" and "merge_key" hints. Defaults to "append".
This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes.
columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas. Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema.
This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes.
columns (Sequence[TAnySchemaColumns], optional): A list, dict or pydantic model of column schemas.
Typed dictionary describing column names, data types, write disposition and performance hints that gives you full control over the created table schema.
This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes.
When the argument is a pydantic model, the model will be used to validate the data yielded by the resource as well.
primary_key (str | Sequence[str]): A column name or a list of column names that comprise a private key. Typically used with "merge" write disposition to deduplicate loaded data.
This argument also accepts a callable that is used to dynamically create tables for stream-like resources yielding many datatypes.
Expand All @@ -305,13 +307,12 @@ def resource(
DltResource instance which may be loaded, iterated or combined with other resources into a pipeline.
"""
def make_resource(_name: str, _section: str, _data: Any, incremental: IncrementalResourceWrapper = None) -> DltResource:
schema_columns = ensure_table_schema_columns_hint(columns) if columns is not None else None
table_template = DltResource.new_table_template(
table_name or _name,
write_disposition=write_disposition,
columns=schema_columns,
columns=columns,
primary_key=primary_key,
merge_key=merge_key
merge_key=merge_key,
)
return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, depends_on), incremental=incremental)

Expand Down
9 changes: 9 additions & 0 deletions dlt/extract/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from dlt.common.exceptions import DltException
from dlt.common.utils import get_callable_name
from dlt.extract.typing import ValidateItem, TDataItems


class ExtractorException(DltException):
Expand Down Expand Up @@ -259,3 +260,11 @@ def __init__(self, source_name: str, schema_name: str) -> None:
class IncrementalUnboundError(DltResourceException):
def __init__(self, cursor_path: str) -> None:
super().__init__("", f"The incremental definition with cursor path {cursor_path} is used without being bound to the resource. This most often happens when you create dynamic resource from a generator function that uses incremental. See https://dlthub.com/docs/general-usage/incremental-loading#incremental-loading-with-last-value for an example.")


class ValidationError(ValueError, DltException):
def __init__(self, validator: ValidateItem, data_item: TDataItems, original_exception: Exception) ->None:
self.original_exception = original_exception
self.validator = validator
self.data_item = data_item
super().__init__(f"Extracted data item could not be validated with {validator}. Original message: {original_exception}")
2 changes: 1 addition & 1 deletion dlt/extract/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def steps(self) -> List[TPipeStep]:

def find(self, *step_type: AnyType) -> int:
"""Finds a step with object of type `step_type`"""
return next((i for i,v in enumerate(self._steps) if type(v) in step_type), -1)
return next((i for i,v in enumerate(self._steps) if isinstance(v, step_type)), -1)

def __getitem__(self, i: int) -> TPipeStep:
return self._steps[i]
Expand Down
19 changes: 15 additions & 4 deletions dlt/extract/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
from dlt.common.validation import validate_dict_ignoring_xkeys

from dlt.extract.incremental import Incremental
from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate
from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate, ValidateItem
from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints, InconsistentTableTemplate, TableNameMissing
from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint
from dlt.extract.validation import get_column_validator


class TTableSchemaTemplate(TypedDict, total=False):
Expand All @@ -24,6 +25,7 @@ class TTableSchemaTemplate(TypedDict, total=False):
primary_key: TTableHintTemplate[TColumnNames]
merge_key: TTableHintTemplate[TColumnNames]
incremental: Incremental[Any]
validator: ValidateItem


class DltResourceSchema:
Expand Down Expand Up @@ -78,6 +80,7 @@ def compute_table_schema(self, item: TDataItem = None) -> TPartialTableSchema:
# resolve
resolved_template: TTableSchemaTemplate = {k: self._resolve_hint(item, v) for k, v in table_template.items()} # type: ignore
resolved_template.pop("incremental", None)
resolved_template.pop("validator", None)
table_schema = self._merge_keys(resolved_template)
table_schema["resource"] = self._name
validate_dict_ignoring_xkeys(
Expand Down Expand Up @@ -129,6 +132,7 @@ def apply_hints(
if write_disposition:
t["write_disposition"] = write_disposition
if columns is not None:
t['validator'] = get_column_validator(columns)
# if callable then override existing
if callable(columns) or callable(t["columns"]):
t["columns"] = ensure_table_schema_columns_hint(columns)
Expand Down Expand Up @@ -206,27 +210,34 @@ def new_table_template(
write_disposition: TTableHintTemplate[TWriteDisposition] = None,
columns: TTableHintTemplate[TAnySchemaColumns] = None,
primary_key: TTableHintTemplate[TColumnNames] = None,
merge_key: TTableHintTemplate[TColumnNames] = None
merge_key: TTableHintTemplate[TColumnNames] = None,
) -> TTableSchemaTemplate:
if not table_name:
raise TableNameMissing()

if columns is not None:
validator = get_column_validator(columns)
columns = ensure_table_schema_columns_hint(columns)
if not callable(columns):
columns = columns.values() # type: ignore
else:
validator = None
# create a table schema template where hints can be functions taking TDataItem
new_template: TTableSchemaTemplate = new_table(table_name, parent_table_name, write_disposition=write_disposition, columns=columns) # type: ignore
new_template: TTableSchemaTemplate = new_table(
table_name, parent_table_name, write_disposition=write_disposition, columns=columns # type: ignore
)
if primary_key:
new_template["primary_key"] = primary_key
if merge_key:
new_template["merge_key"] = merge_key
if validator:
new_template["validator"] = validator
DltResourceSchema.validate_dynamic_hints(new_template)
return new_template

@staticmethod
def validate_dynamic_hints(template: TTableSchemaTemplate) -> None:
table_name = template["name"]
# if any of the hints is a function then name must be as well
if any(callable(v) for k, v in template.items() if k not in ["name", "incremental"]) and not callable(table_name):
if any(callable(v) for k, v in template.items() if k not in ["name", "incremental", "validator"]) and not callable(table_name):
raise InconsistentTableTemplate(f"Table name {table_name} must be a function if any other table hint is a function")
25 changes: 23 additions & 2 deletions dlt/extract/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from copy import copy
import makefun
import inspect
from typing import AsyncIterable, AsyncIterator, ClassVar, Callable, ContextManager, Dict, Iterable, Iterator, List, Sequence, Tuple, Union, Any
from typing import AsyncIterable, AsyncIterator, ClassVar, Callable, ContextManager, Dict, Iterable, Iterator, List, Sequence, Tuple, Union, Any, Optional
import types

from dlt.common.configuration.resolve import inject_section
Expand All @@ -17,7 +17,7 @@
from dlt.common.pipeline import PipelineContext, StateInjectableContext, SupportsPipelineRun, resource_state, source_state, pipeline_state
from dlt.common.utils import graph_find_scc_nodes, flatten_list_or_items, get_callable_name, graph_edges_to_nodes, multi_context_manager, uniq_id

from dlt.extract.typing import DataItemWithMeta, ItemTransformFunc, ItemTransformFunctionWithMeta, TDecompositionStrategy, TableNameMeta, FilterItem, MapItem, YieldMapItem
from dlt.extract.typing import DataItemWithMeta, ItemTransformFunc, ItemTransformFunctionWithMeta, TDecompositionStrategy, TableNameMeta, FilterItem, MapItem, YieldMapItem, ValidateItem
from dlt.extract.pipe import Pipe, ManagedPipeIterator, TPipeStep
from dlt.extract.schema import DltResourceSchema, TTableSchemaTemplate
from dlt.extract.incremental import Incremental, IncrementalResourceWrapper
Expand Down Expand Up @@ -135,6 +135,24 @@ def incremental(self) -> IncrementalResourceWrapper:
incremental = self._pipe.steps[step_no] # type: ignore
return incremental

@property
def validator(self) -> Optional[ValidateItem]:
"""Gets validator transform if it is in the pipe"""
validator: ValidateItem = None
step_no = self._pipe.find(ValidateItem)
if step_no >= 0:
validator = self._pipe.steps[step_no] # type: ignore[assignment]
return validator

@validator.setter
def validator(self, validator: Optional[ValidateItem]) -> None:
"""Add/remove or replace the validator in pipe"""
step_no = self._pipe.find(ValidateItem)
if step_no >= 0:
self._pipe.remove_step(step_no)
if validator:
self.add_step(validator, insert_at=step_no if step_no >= 0 else None)

def pipe_data_from(self, data_from: Union["DltResource", Pipe]) -> None:
"""Replaces the parent in the transformer resource pipe from which the data is piped."""
if self.is_transformer:
Expand Down Expand Up @@ -273,6 +291,9 @@ def set_template(self, table_schema_template: TTableSchemaTemplate) -> None:
if primary_key is not None:
incremental.primary_key = primary_key

if table_schema_template.get('validator') is not None:
self.validator = table_schema_template['validator']

def bind(self, *args: Any, **kwargs: Any) -> "DltResource":
"""Binds the parametrized resource to passed arguments. Modifies resource pipe in place. Does not evaluate generators or iterators."""
if self._bound:
Expand Down
10 changes: 9 additions & 1 deletion dlt/extract/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,12 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:
if self._f_meta:
yield from self._f_meta(item, meta)
else:
yield from self._f(item)
yield from self._f(item)


class ValidateItem(ItemTransform[TDataItem]):
"""Base class for validators of data items.
Subclass should implement the `__call__` method to either return the data item(s) or raise `extract.exceptions.ValidationError`.
See `PydanticValidator` for possible implementation.
"""
46 changes: 46 additions & 0 deletions dlt/extract/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional, Protocol, TypeVar, Generic, Type, Union, Any, List

try:
from pydantic import BaseModel as PydanticBaseModel, ValidationError as PydanticValidationError, create_model
except ModuleNotFoundError:
PydanticBaseModel = None # type: ignore[misc]

from dlt.extract.exceptions import ValidationError
from dlt.common.typing import TDataItems
from dlt.common.schema.typing import TAnySchemaColumns
from dlt.extract.typing import TTableHintTemplate, ValidateItem


_TPydanticModel = TypeVar("_TPydanticModel", bound=PydanticBaseModel)


class PydanticValidator(ValidateItem, Generic[_TPydanticModel]):
model: Type[_TPydanticModel]
def __init__(self, model: Type[_TPydanticModel]) -> None:
self.model = model

# Create a model for validating list of items in batch
self.list_model = create_model(
"List" + model.__name__,
items=(List[model], ...) # type: ignore[valid-type]
)

def __call__(self, item: TDataItems, meta: Any = None) -> Union[_TPydanticModel, List[_TPydanticModel]]:
"""Validate a data item against the pydantic model"""
if item is None:
return None
try:
if isinstance(item, list):
return self.list_model(items=item).items # type: ignore[attr-defined, no-any-return]
return self.model.parse_obj(item)
except PydanticValidationError as e:
raise ValidationError(self, item, e) from e

def __str__(self, *args: Any, **kwargs: Any) -> str:
return f"PydanticValidator(model={self.model.__qualname__})"


def get_column_validator(columns: TTableHintTemplate[TAnySchemaColumns]) -> Optional[ValidateItem]:
if PydanticBaseModel is not None and isinstance(columns, type) and issubclass(columns, PydanticBaseModel):
return PydanticValidator(columns)
return None
Loading

0 comments on commit 390519d

Please sign in to comment.