diff --git a/dlt/common/data_types/type_helpers.py b/dlt/common/data_types/type_helpers.py
index ffd355d86c..f42f81b06f 100644
--- a/dlt/common/data_types/type_helpers.py
+++ b/dlt/common/data_types/type_helpers.py
@@ -3,6 +3,7 @@ import datetime  # noqa: I251
 from collections.abc import Mapping as C_Mapping, Sequence as C_Sequence
 from typing import Any, Type, Literal, Union, cast
+from enum import Enum
 
 from dlt.common import pendulum, json, Decimal, Wei
 from dlt.common.json import custom_pua_remove
@@ -51,6 +52,13 @@ def py_type_to_sc_type(t: Type[Any]) -> TDataType:
         return "binary"
     if issubclass(t, (C_Mapping, C_Sequence)):
         return "complex"
+    # Enums are coerced to str or int respectively
+    if issubclass(t, Enum):
+        if issubclass(t, int):
+            return "bigint"
+        else:
+            # str subclasses and plain enums translate to text
+            return "text"
     raise TypeError(t)
@@ -83,6 +91,13 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any:
         if to_type == "complex":
             # complex types need custom encoding to be removed
             return map_nested_in_place(custom_pua_remove, value)
+        # Make sure we use the enum value instead of the object itself
+        # This check is faster than `isinstance(value, Enum)` for non-enum types
+        if hasattr(value, 'value'):
+            if to_type == "text":
+                return str(value.value)
+            elif to_type == "bigint":
+                return int(value.value)
         return value
 
     if to_type == "text":
@@ -91,7 +106,7 @@ def coerce_value(to_type: TDataType, from_type: TDataType, value: Any) -> Any:
         else:
             # use the same string encoding as in json
             try:
-                return json_custom_encode(value)
+                return str(json_custom_encode(value))
             except TypeError:
                 # for other types use internal conversion
                 return str(value)
diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py
index 0bceb9061e..c31d547efc 100644
--- a/dlt/common/json/__init__.py
+++ b/dlt/common/json/__init__.py
@@ -5,6 +5,7 @@ from typing import Any, Callable, List, Protocol, IO, Union
 from uuid import UUID
 from hexbytes import HexBytes
+from enum import Enum
 
 try:
     from pydantic import BaseModel as PydanticBaseModel
@@ -82,6 +83,8 @@ def custom_encode(obj: Any) -> str:
         return obj.dict()  # type: ignore[return-value]
     elif dataclasses.is_dataclass(obj):
         return dataclasses.asdict(obj)  # type: ignore
+    elif isinstance(obj, Enum):
+        return obj.value  # type: ignore[no-any-return]
     raise TypeError(repr(obj) + " is not JSON serializable")
@@ -145,6 +148,9 @@ def custom_pua_encode(obj: Any) -> str:
         return dataclasses.asdict(obj)  # type: ignore
     elif PydanticBaseModel and isinstance(obj, PydanticBaseModel):
         return obj.dict()  # type: ignore[return-value]
+    elif isinstance(obj, Enum):
+        # Enum value is just int or str
+        return obj.value  # type: ignore[no-any-return]
     raise TypeError(repr(obj) + " is not JSON serializable")
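
Taken together, the hunks above let `Enum` members round-trip as their underlying `str`/`int` values through type inference, coercion and JSON encoding. A quick sketch of the intended behavior, assuming a build with this patch applied (the enum classes below are illustrative, not from the diff):

from enum import Enum

from dlt.common.data_types.type_helpers import coerce_value, py_type_to_sc_type


class Color(str, Enum):
    red = "red"


class Status(int, Enum):
    ok = 200


# str-based enums map to "text", int-based enums to "bigint"
assert py_type_to_sc_type(Color) == "text"
assert py_type_to_sc_type(Status) == "bigint"

# coercion unwraps the member to its bare value, not the Enum instance
assert coerce_value("text", "text", Color.red) == "red"
assert coerce_value("bigint", "bigint", Status.ok) == 200
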
diff --git a/dlt/common/libs/pydantic.py b/dlt/common/libs/pydantic.py
index 057cbb57d5..c66d67f1f7 100644
--- a/dlt/common/libs/pydantic.py
+++ b/dlt/common/libs/pydantic.py
@@ -1,4 +1,4 @@
-from typing import Type, Union, get_type_hints, get_args
+from typing import Type, Union, get_type_hints, get_args, Any
 
 from dlt.common.exceptions import MissingDependencyException
 from dlt.common.schema.typing import TTableSchemaColumns
@@ -6,7 +6,7 @@ from dlt.common.typing import is_optional_type, extract_inner_type, is_list_generic_type, is_dict_generic_type, is_union
 
 try:
-    from pydantic import BaseModel, Field
+    from pydantic import BaseModel, Field, Json
 except ImportError:
     raise MissingDependencyException("DLT pydantic Helpers", ["pydantic"], "DLT Helpers for for pydantic.")
@@ -26,6 +26,10 @@ def pydantic_to_table_schema_columns(model: Union[BaseModel, Type[BaseModel]], skip_complex_types: bool = False) -> TTableSchemaColumns:
     fields = model.__fields__
     for field_name, field in fields.items():
         annotation = field.annotation
+        if inner_annotation := getattr(annotation, 'inner_type', None):
+            # This applies to pydantic.Json fields; the inner type is the type after JSON parsing
+            # (in pydantic 2 the outer annotation is the final type)
+            annotation = inner_annotation
         nullable = is_optional_type(annotation)
 
         if is_union(annotation):
@@ -33,6 +37,12 @@ def pydantic_to_table_schema_columns(model: Union[BaseModel, Type[BaseModel]], skip_complex_types: bool = False) -> TTableSchemaColumns:
         else:
             inner_type = extract_inner_type(annotation)
 
+        if inner_type is Json:  # Same as `field: Json[Any]`
+            inner_type = Any
+
+        if inner_type is Any:  # Any fields will be inferred from data
+            continue
+
         if is_list_generic_type(inner_type):
             inner_type = list
         elif is_dict_generic_type(inner_type) or issubclass(inner_type, BaseModel):
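
With the `Json`/`Any` handling above, schema inference now reads `pydantic.Json[...]` fields through their parsed inner type and skips `Any` fields entirely, leaving them to be inferred from data. A sketch under those assumptions (the `Event` model is illustrative; behavior mirrors the tests later in this diff):

from typing import Any, List

from pydantic import BaseModel, Json

from dlt.common.libs.pydantic import pydantic_to_table_schema_columns


class Event(BaseModel):
    id: int
    payload: Json[List[str]]  # typed via the inner (post-parse) type -> "complex"
    extra: Any                # Any fields are skipped and inferred from data


columns = pydantic_to_table_schema_columns(Event)
assert columns["payload"]["data_type"] == "complex"
assert "extra" not in columns
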
""" def make_resource(_name: str, _section: str, _data: Any, incremental: IncrementalResourceWrapper = None) -> DltResource: - schema_columns = ensure_table_schema_columns_hint(columns) if columns is not None else None table_template = DltResource.new_table_template( table_name or _name, write_disposition=write_disposition, - columns=schema_columns, + columns=columns, primary_key=primary_key, - merge_key=merge_key + merge_key=merge_key, ) return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, depends_on), incremental=incremental) diff --git a/dlt/extract/exceptions.py b/dlt/extract/exceptions.py index 79329f2107..4a4b17967d 100644 --- a/dlt/extract/exceptions.py +++ b/dlt/extract/exceptions.py @@ -3,6 +3,7 @@ from dlt.common.exceptions import DltException from dlt.common.utils import get_callable_name +from dlt.extract.typing import ValidateItem, TDataItems class ExtractorException(DltException): @@ -259,3 +260,11 @@ def __init__(self, source_name: str, schema_name: str) -> None: class IncrementalUnboundError(DltResourceException): def __init__(self, cursor_path: str) -> None: super().__init__("", f"The incremental definition with cursor path {cursor_path} is used without being bound to the resource. This most often happens when you create dynamic resource from a generator function that uses incremental. See https://dlthub.com/docs/general-usage/incremental-loading#incremental-loading-with-last-value for an example.") + + +class ValidationError(ValueError, DltException): + def __init__(self, validator: ValidateItem, data_item: TDataItems, original_exception: Exception) ->None: + self.original_exception = original_exception + self.validator = validator + self.data_item = data_item + super().__init__(f"Extracted data item could not be validated with {validator}. 
diff --git a/dlt/extract/pipe.py b/dlt/extract/pipe.py
index 9e5b20374f..2c1cebe177 100644
--- a/dlt/extract/pipe.py
+++ b/dlt/extract/pipe.py
@@ -143,7 +143,7 @@ def steps(self) -> List[TPipeStep]:
 
     def find(self, *step_type: AnyType) -> int:
         """Finds a step with object of type `step_type`"""
-        return next((i for i,v in enumerate(self._steps) if type(v) in step_type), -1)
+        return next((i for i,v in enumerate(self._steps) if isinstance(v, step_type)), -1)
 
     def __getitem__(self, i: int) -> TPipeStep:
         return self._steps[i]
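
The `find` change is behavioral, not cosmetic: `type(v) in step_type` only matched exact types, while `isinstance` also matches subclasses, which is what lets `Pipe.find(ValidateItem)` locate a `PydanticValidator` step. A self-contained sketch with stand-in classes:

class ValidateItem: ...                      # stand-in for dlt.extract.typing.ValidateItem
class PydanticValidator(ValidateItem): ...   # stand-in for the concrete validator

steps = [object(), PydanticValidator()]
# old predicate: exact type match misses the subclass
assert not any(type(v) in (ValidateItem,) for v in steps)
# new predicate: isinstance finds it at index 1
assert next((i for i, v in enumerate(steps) if isinstance(v, (ValidateItem,))), -1) == 1
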
diff --git a/dlt/extract/schema.py b/dlt/extract/schema.py
index 3149a37c12..80e9f6f32f 100644
--- a/dlt/extract/schema.py
+++ b/dlt/extract/schema.py
@@ -9,9 +9,10 @@
 from dlt.common.validation import validate_dict_ignoring_xkeys
 
 from dlt.extract.incremental import Incremental
-from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate
+from dlt.extract.typing import TFunHintTemplate, TTableHintTemplate, ValidateItem
 from dlt.extract.exceptions import DataItemRequiredForDynamicTableHints, InconsistentTableTemplate, TableNameMissing
 from dlt.extract.utils import ensure_table_schema_columns, ensure_table_schema_columns_hint
+from dlt.extract.validation import get_column_validator
 
 
 class TTableSchemaTemplate(TypedDict, total=False):
@@ -24,6 +25,7 @@ class TTableSchemaTemplate(TypedDict, total=False):
     primary_key: TTableHintTemplate[TColumnNames]
     merge_key: TTableHintTemplate[TColumnNames]
     incremental: Incremental[Any]
+    validator: ValidateItem
 
 
 class DltResourceSchema:
@@ -78,6 +80,7 @@ def compute_table_schema(self, item: TDataItem = None) -> TPartialTableSchema:
         # resolve
         resolved_template: TTableSchemaTemplate = {k: self._resolve_hint(item, v) for k, v in table_template.items()}  # type: ignore
         resolved_template.pop("incremental", None)
+        resolved_template.pop("validator", None)
         table_schema = self._merge_keys(resolved_template)
         table_schema["resource"] = self._name
         validate_dict_ignoring_xkeys(
@@ -129,6 +132,7 @@ def apply_hints(
         if write_disposition:
             t["write_disposition"] = write_disposition
         if columns is not None:
+            t['validator'] = get_column_validator(columns)
             # if callable then override existing
             if callable(columns) or callable(t["columns"]):
                 t["columns"] = ensure_table_schema_columns_hint(columns)
@@ -206,21 +210,28 @@ def new_table_template(
         write_disposition: TTableHintTemplate[TWriteDisposition] = None,
         columns: TTableHintTemplate[TAnySchemaColumns] = None,
         primary_key: TTableHintTemplate[TColumnNames] = None,
-        merge_key: TTableHintTemplate[TColumnNames] = None
+        merge_key: TTableHintTemplate[TColumnNames] = None,
     ) -> TTableSchemaTemplate:
         if not table_name:
             raise TableNameMissing()
         if columns is not None:
+            validator = get_column_validator(columns)
             columns = ensure_table_schema_columns_hint(columns)
             if not callable(columns):
                 columns = columns.values()  # type: ignore
+        else:
+            validator = None
         # create a table schema template where hints can be functions taking TDataItem
-        new_template: TTableSchemaTemplate = new_table(table_name, parent_table_name, write_disposition=write_disposition, columns=columns)  # type: ignore
+        new_template: TTableSchemaTemplate = new_table(
+            table_name, parent_table_name, write_disposition=write_disposition, columns=columns  # type: ignore
+        )
         if primary_key:
             new_template["primary_key"] = primary_key
         if merge_key:
             new_template["merge_key"] = merge_key
+        if validator:
+            new_template["validator"] = validator
         DltResourceSchema.validate_dynamic_hints(new_template)
         return new_template
@@ -228,5 +239,5 @@ def new_table_template(
     def validate_dynamic_hints(template: TTableSchemaTemplate) -> None:
         table_name = template["name"]
         # if any of the hints is a function then name must be as well
-        if any(callable(v) for k, v in template.items() if k not in ["name", "incremental"]) and not callable(table_name):
+        if any(callable(v) for k, v in template.items() if k not in ["name", "incremental", "validator"]) and not callable(table_name):
             raise InconsistentTableTemplate(f"Table name {table_name} must be a function if any other table hint is a function")
diff --git a/dlt/extract/source.py b/dlt/extract/source.py
index 52a0381dfe..39dcfc762c 100644
--- a/dlt/extract/source.py
+++ b/dlt/extract/source.py
@@ -3,7 +3,7 @@
 from copy import copy
 import makefun
 import inspect
-from typing import AsyncIterable, AsyncIterator, ClassVar, Callable, ContextManager, Dict, Iterable, Iterator, List, Sequence, Tuple, Union, Any
+from typing import AsyncIterable, AsyncIterator, ClassVar, Callable, ContextManager, Dict, Iterable, Iterator, List, Sequence, Tuple, Union, Any, Optional
 import types
 
 from dlt.common.configuration.resolve import inject_section
@@ -17,7 +17,7 @@
 from dlt.common.pipeline import PipelineContext, StateInjectableContext, SupportsPipelineRun, resource_state, source_state, pipeline_state
 from dlt.common.utils import graph_find_scc_nodes, flatten_list_or_items, get_callable_name, graph_edges_to_nodes, multi_context_manager, uniq_id
 
-from dlt.extract.typing import DataItemWithMeta, ItemTransformFunc, ItemTransformFunctionWithMeta, TDecompositionStrategy, TableNameMeta, FilterItem, MapItem, YieldMapItem
+from dlt.extract.typing import DataItemWithMeta, ItemTransformFunc, ItemTransformFunctionWithMeta, TDecompositionStrategy, TableNameMeta, FilterItem, MapItem, YieldMapItem, ValidateItem
 from dlt.extract.pipe import Pipe, ManagedPipeIterator, TPipeStep
 from dlt.extract.schema import DltResourceSchema, TTableSchemaTemplate
 from dlt.extract.incremental import Incremental, IncrementalResourceWrapper
@@ -135,6 +135,24 @@ def incremental(self) -> IncrementalResourceWrapper:
             incremental = self._pipe.steps[step_no]  # type: ignore
         return incremental
 
+    @property
+    def validator(self) -> Optional[ValidateItem]:
+        """Gets validator transform if it is in the pipe"""
+        validator: ValidateItem = None
+        step_no = self._pipe.find(ValidateItem)
+        if step_no >= 0:
+            validator = self._pipe.steps[step_no]  # type: ignore[assignment]
+        return validator
+
+    @validator.setter
+    def validator(self, validator: Optional[ValidateItem]) -> None:
+        """Add/remove or replace the validator in pipe"""
+        step_no = self._pipe.find(ValidateItem)
+        if step_no >= 0:
+            self._pipe.remove_step(step_no)
+        if validator:
+            self.add_step(validator, insert_at=step_no if step_no >= 0 else None)
+
     def pipe_data_from(self, data_from: Union["DltResource", Pipe]) -> None:
         """Replaces the parent in the transformer resource pipe from which the data is piped."""
         if self.is_transformer:
@@ -273,6 +291,9 @@ def set_template(self, table_schema_template: TTableSchemaTemplate) -> None:
             if primary_key is not None:
                 incremental.primary_key = primary_key
 
+        if table_schema_template.get('validator') is not None:
+            self.validator = table_schema_template['validator']
+
     def bind(self, *args: Any, **kwargs: Any) -> "DltResource":
         """Binds the parametrized resource to passed arguments. Modifies resource pipe in place.
         Does not evaluate generators or iterators."""
         if self._bound:
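
The new `validator` property keeps the pipe and the table template in sync: reading it searches the pipe for a `ValidateItem` step, and assigning replaces or removes that step in place. An illustrative use, assuming this patch (the resource and model are hypothetical):

import dlt
from pydantic import BaseModel

from dlt.extract.validation import PydanticValidator


class Item(BaseModel):
    a: int


@dlt.resource(columns=Item)
def data():
    yield {"a": 1}


resource = data()
assert isinstance(resource.validator, PydanticValidator)

resource.validator = None  # drop validation from the pipe
assert list(resource) == [{"a": 1}]
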
diff --git a/dlt/extract/typing.py b/dlt/extract/typing.py
index a8608021ba..5f32556f92 100644
--- a/dlt/extract/typing.py
+++ b/dlt/extract/typing.py
@@ -123,4 +123,12 @@ def __call__(self, item: TDataItems, meta: Any = None) -> Optional[TDataItems]:
         if self._f_meta:
             yield from self._f_meta(item, meta)
         else:
-            yield from self._f(item)
\ No newline at end of file
+            yield from self._f(item)
+
+
+class ValidateItem(ItemTransform[TDataItem]):
+    """Base class for validators of data items.
+
+    Subclasses should implement the `__call__` method to either return the data item(s) or raise `extract.exceptions.ValidationError`.
+    See `PydanticValidator` for a possible implementation.
+    """
diff --git a/dlt/extract/validation.py b/dlt/extract/validation.py
new file mode 100644
index 0000000000..c8e30d0eb2
--- /dev/null
+++ b/dlt/extract/validation.py
@@ -0,0 +1,46 @@
+from typing import Optional, Protocol, TypeVar, Generic, Type, Union, Any, List
+
+try:
+    from pydantic import BaseModel as PydanticBaseModel, ValidationError as PydanticValidationError, create_model
+except ModuleNotFoundError:
+    PydanticBaseModel = None  # type: ignore[misc]
+
+from dlt.extract.exceptions import ValidationError
+from dlt.common.typing import TDataItems
+from dlt.common.schema.typing import TAnySchemaColumns
+from dlt.extract.typing import TTableHintTemplate, ValidateItem
+
+
+_TPydanticModel = TypeVar("_TPydanticModel", bound=PydanticBaseModel)
+
+
+class PydanticValidator(ValidateItem, Generic[_TPydanticModel]):
+    model: Type[_TPydanticModel]
+
+    def __init__(self, model: Type[_TPydanticModel]) -> None:
+        self.model = model
+
+        # Create a model for validating list of items in batch
+        self.list_model = create_model(
+            "List" + model.__name__,
+            items=(List[model], ...)  # type: ignore[valid-type]
+        )
+
+    def __call__(self, item: TDataItems, meta: Any = None) -> Union[_TPydanticModel, List[_TPydanticModel]]:
+        """Validate a data item against the pydantic model"""
+        if item is None:
+            return None
+        try:
+            if isinstance(item, list):
+                return self.list_model(items=item).items  # type: ignore[attr-defined, no-any-return]
+            return self.model.parse_obj(item)
+        except PydanticValidationError as e:
+            raise ValidationError(self, item, e) from e
+
+    def __str__(self, *args: Any, **kwargs: Any) -> str:
+        return f"PydanticValidator(model={self.model.__qualname__})"
+
+
+def get_column_validator(columns: TTableHintTemplate[TAnySchemaColumns]) -> Optional[ValidateItem]:
+    if PydanticBaseModel is not None and isinstance(columns, type) and issubclass(columns, PydanticBaseModel):
+        return PydanticValidator(columns)
+    return None
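
`ValidateItem` is deliberately small: per its docstring, a subclass only needs a `__call__` that returns the item(s) or raises `ValidationError`. A hypothetical custom validator following that contract (not part of the diff; like `PydanticValidator`, it does not call `super().__init__`):

from typing import Any

from dlt.common.typing import TDataItems
from dlt.extract.exceptions import ValidationError
from dlt.extract.typing import ValidateItem


class RequireKeys(ValidateItem):
    """Fails extraction when a yielded dict is missing one of the given keys."""

    def __init__(self, *keys: str) -> None:
        self.keys = keys

    def __call__(self, item: TDataItems, meta: Any = None) -> TDataItems:
        # items may arrive one at a time or as a list batch
        for row in item if isinstance(item, list) else [item]:
            missing = [k for k in self.keys if k not in row]
            if missing:
                raise ValidationError(self, item, KeyError(str(missing)))
        return item
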
diff --git a/tests/common/schema/test_coercion.py b/tests/common/schema/test_coercion.py
index ed976f2af7..dd9dcd1ae0 100644
--- a/tests/common/schema/test_coercion.py
+++ b/tests/common/schema/test_coercion.py
@@ -4,6 +4,7 @@
 import pytest
 import datetime  # noqa: I251
 from hexbytes import HexBytes
+from enum import Enum
 
 from pendulum.tz import UTC
 
@@ -32,6 +33,24 @@ def test_coerce_type_to_text() -> None:
     assert coerce_value("text", "binary", b'binary string') == "YmluYXJ5IHN0cmluZw=="
     # HexBytes to text (hex with prefix)
     assert coerce_value("text", "binary", HexBytes(b'binary string')) == "0x62696e61727920737472696e67"
+    # Str enum value
+    class StrEnum(Enum):
+        a = "a_value"
+        b = "b_value"
+
+    str_enum_result = coerce_value("text", "text", StrEnum.b)
+    # Make sure we get the bare str value, not the enum instance
+    assert not isinstance(str_enum_result, Enum)
+    assert str_enum_result == "b_value"
+    # Mixed enum value
+    class MixedEnum(Enum):
+        a = "a_value"
+        b = 1
+
+    mixed_enum_result = coerce_value("text", "text", MixedEnum.b)
+    # Make sure we get the bare str value, not the enum instance
+    assert not isinstance(mixed_enum_result, Enum)
+    assert mixed_enum_result == "1"
 
 
 def test_coerce_type_to_bool() -> None:
@@ -93,6 +112,16 @@ def test_coerce_type_to_bigint() -> None:
     with pytest.raises(ValueError):
         coerce_value("bigint", "text", "912.12")
 
+    # Int enum value
+    class IntEnum(int, Enum):
+        a = 1
+        b = 2
+
+    int_enum_result = coerce_value("bigint", "bigint", IntEnum.b)
+    # Make sure we get the bare int value, not the enum instance
+    assert not isinstance(int_enum_result, Enum)
+    assert int_enum_result == 2
+
 
 @pytest.mark.parametrize("dec_cls,data_type", [
     (Decimal, "decimal"),
@@ -280,6 +309,22 @@ def test_py_type_to_sc_type() -> None:
     assert py_type_to_sc_type(Mapping) == "complex"
     assert py_type_to_sc_type(MutableSequence) == "complex"
 
+    class IntEnum(int, Enum):
+        a = 1
+        b = 2
+
+    class StrEnum(str, Enum):
+        a = "a"
+        b = "b"
+
+    class MixedEnum(Enum):
+        a = 1
+        b = "b"
+
+    assert py_type_to_sc_type(IntEnum) == "bigint"
+    assert py_type_to_sc_type(StrEnum) == "text"
+    assert py_type_to_sc_type(MixedEnum) == "text"
+
 
 def test_coerce_type_complex() -> None:
     # dicts and lists should be coerced into strings automatically
diff --git a/tests/common/test_pydantic.py b/tests/common/test_pydantic.py
index 66e7b13f0b..770fcce6e5 100644
--- a/tests/common/test_pydantic.py
+++ b/tests/common/test_pydantic.py
@@ -1,13 +1,33 @@
 import pytest
-from typing import Union, Optional, List, Dict
+from typing import Union, Optional, List, Dict, Any
+from enum import Enum
 from datetime import datetime, date, time  # noqa: I251
 
 from dlt.common import Decimal
+from dlt.common import json
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Json, AnyHttpUrl
 from dlt.common.libs.pydantic import pydantic_to_table_schema_columns
 
 
+class StrEnum(str, Enum):
+    a = "a_value"
+    b = "b_value"
+    c = "c_value"
+
+
+class IntEnum(int, Enum):
+    a = 0
+    b = 1
+    c = 2
+
+
+class MixedEnum(Enum):
+    a_int = 0
+    b_str = "b_value"
+    c_int = 2
+
+
 class NestedModel(BaseModel):
     nested_field: str
 
@@ -31,6 +51,20 @@ class Model(BaseModel):
     blank_dict_field: dict  # type: ignore[type-arg]
     parametrized_dict_field: Dict[str, int]
 
+    str_enum_field: StrEnum
+    int_enum_field: IntEnum
+    # Both of these should coerce to str
+    mixed_enum_int_field: MixedEnum
+    mixed_enum_str_field: MixedEnum
+
+    json_field: Json[List[str]]
+
+    url_field: AnyHttpUrl
+
+    any_field: Any
+    json_any_field: Json[Any]
+
 
 @pytest.mark.parametrize('instance', [True, False])
 def test_pydantic_model_to_columns(instance: bool) -> None:
@@ -44,7 +78,15 @@ def test_pydantic_model_to_columns(instance: bool) -> None:
             union_field=1,
             optional_field=None,
             blank_dict_field={},
-            parametrized_dict_field={"a": 1, "b": 2, "c": 3}
+            parametrized_dict_field={"a": 1, "b": 2, "c": 3},
+            str_enum_field=StrEnum.a,
+            int_enum_field=IntEnum.a,
+            mixed_enum_int_field=MixedEnum.a_int,
+            mixed_enum_str_field=MixedEnum.b_str,
+            json_field=json.dumps(["a", "b", "c"]),  # type: ignore[arg-type]
+            url_field="https://example.com",  # type: ignore[arg-type]
+            any_field="any_string",
+            json_any_field=json.dumps("any_string"),
         )
     else:
         model = Model  # type: ignore[assignment]
@@ -65,6 +107,16 @@ def test_pydantic_model_to_columns(instance: bool) -> None:
     assert result['optional_field']['nullable'] is True
     assert result['blank_dict_field']['data_type'] == 'complex'
     assert result['parametrized_dict_field']['data_type'] == 'complex'
+    assert result['str_enum_field']['data_type'] == 'text'
+    assert result['int_enum_field']['data_type'] == 'bigint'
+    assert result['mixed_enum_int_field']['data_type'] == 'text'
+    assert result['mixed_enum_str_field']['data_type'] == 'text'
+    assert result['json_field']['data_type'] == 'complex'
+    assert result['url_field']['data_type'] == 'text'
+
+    # Any type fields are excluded from schema
+    assert 'any_field' not in result
+    assert 'json_any_field' not in result
 
 
 def test_pydantic_model_skip_complex_types() -> None:
@@ -76,7 +128,7 @@ def test_pydantic_model_skip_complex_types() -> None:
     assert "list_field" not in result
     assert "blank_dict_field" not in result
     assert "parametrized_dict_field" not in result
+    assert "json_field" not in result
     assert result["bigint_field"]["data_type"] == "bigint"
     assert result["text_field"]["data_type"] == "text"
     assert result["timestamp_field"]["data_type"] == "timestamp"
-
diff --git a/tests/extract/test_validation.py b/tests/extract/test_validation.py
new file mode 100644
index 0000000000..64e06bcecc
--- /dev/null
+++ b/tests/extract/test_validation.py
@@ -0,0 +1,152 @@
+"""Tests for resource validation with pydantic schema
+"""
+import typing as t
+
+import pytest
+import dlt
+from dlt.extract.typing import ValidateItem
+from dlt.common.typing import TDataItems
+from dlt.extract.validation import PydanticValidator
+from dlt.extract.exceptions import ValidationError, ResourceExtractionError
+
+from pydantic import BaseModel
+
+
+class SimpleModel(BaseModel):
+    a: int
+    b: str
+
+
+@pytest.mark.parametrize("yield_list", [True, False])
+def test_validator_model_in_decorator(yield_list: bool) -> None:
+    # model passed in decorator
+    @dlt.resource(columns=SimpleModel)
+    def some_data() -> t.Iterator[TDataItems]:
+        items = [{"a": 1, "b": "2"}, {"a": 2, "b": "3"}]
+        if yield_list:
+            yield items
+        else:
+            yield from items
+
+    # Items are passed through model
+    data = list(some_data())
+    assert data == [SimpleModel(a=1, b="2"), SimpleModel(a=2, b="3")]
+
+
+@pytest.mark.parametrize("yield_list", [True, False])
+def test_validator_model_in_apply_hints(yield_list: bool) -> None:
+    # model passed in apply_hints
+
+    @dlt.resource
+    def some_data() -> t.Iterator[TDataItems]:
+        items = [{"a": 1, "b": "2"}, {"a": 2, "b": "3"}]
+        if yield_list:
+            yield items
+        else:
+            yield from items
+
+    resource = some_data()
+    resource.apply_hints(columns=SimpleModel)
+
+    # Items are passed through model
+    data = list(resource)
+    assert data == [SimpleModel(a=1, b="2"), SimpleModel(a=2, b="3")]
+
+
+@pytest.mark.parametrize("yield_list", [True, False])
+def test_remove_validator(yield_list: bool) -> None:
+
+    @dlt.resource(columns=SimpleModel)
+    def some_data() -> t.Iterator[TDataItems]:
+        items = [{"a": 1, "b": "2"}, {"a": 2, "b": "3"}]
+        if yield_list:
+            yield items
+        else:
+            yield from items
+
+    resource = some_data()
+    resource.validator = None
+
+    data = list(resource)
+    assert data == [{"a": 1, "b": "2"}, {"a": 2, "b": "3"}]
+
+
+@pytest.mark.parametrize("yield_list", [True, False])
+def test_replace_validator_model(yield_list: bool) -> None:
+
+    @dlt.resource(columns=SimpleModel)
+    def some_data() -> t.Iterator[TDataItems]:
+        items = [{"a": 1, "b": "2"}, {"a": 2, "b": "3"}]
+        if yield_list:
+            yield items
+        else:
+            yield from items
+
+    resource = some_data()
+
+    class AnotherModel(BaseModel):
+        a: int
+        b: str
+        c: float = 0.5
+
+    # Use apply_hints to replace the validator
+    resource.apply_hints(columns=AnotherModel)
+
+    data = list(resource)
+    # Items are validated with the new model
+    assert data == [AnotherModel(a=1, b="2", c=0.5), AnotherModel(a=2, b="3", c=0.5)]
+
+    # Ensure only one validator is applied in steps
+    steps = resource._pipe.steps
+    assert len(steps) == 2
+
+    assert isinstance(steps[-1], ValidateItem)
+    assert steps[-1].model is AnotherModel  # type: ignore[attr-defined]
+
+
+@pytest.mark.parametrize("yield_list", [True, False])
+def test_validator_property_setter(yield_list: bool) -> None:
+
+    @dlt.resource(columns=SimpleModel)
+    def some_data() -> t.Iterator[TDataItems]:
+        items = [{"a": 1, "b": "2"}, {"a": 2, "b": "3"}]
+        if yield_list:
+            yield items
+        else:
+            yield from items
+
+    resource = some_data()
+
+    assert isinstance(resource.validator, PydanticValidator) and resource.validator.model is SimpleModel
+
+    class AnotherModel(BaseModel):
+        a: int
+        b: str
+        c: float = 0.5
+
+    resource.validator = PydanticValidator(AnotherModel)
+
+    assert resource.validator and resource.validator.model is AnotherModel
+
+    data = list(resource)
+    # Items are validated with the new model
+    assert data == [AnotherModel(a=1, b="2", c=0.5), AnotherModel(a=2, b="3", c=0.5)]
+
+
+@pytest.mark.parametrize("yield_list", [True, False])
+def test_failed_validation(yield_list: bool) -> None:
+    @dlt.resource(columns=SimpleModel)
+    def some_data() -> t.Iterator[TDataItems]:
+        # yield item that fails schema validation
+        items = [{"a": 1, "b": "z"}, {"a": "not_int", "b": "x"}]
+        if yield_list:
+            yield items
+        else:
+            yield from items
+
+    # extraction fails with ValidationError
+    with pytest.raises(ResourceExtractionError) as exinfo:
+        list(some_data())
+
+    assert isinstance(exinfo.value.__cause__, ValidationError)
+    assert str(PydanticValidator(SimpleModel)) in str(exinfo.value)