diff --git a/tests/libs/pyarrow/__init__.py b/tests/libs/pyarrow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/libs/pyarrow/test_pyarrow.py b/tests/libs/pyarrow/test_pyarrow.py new file mode 100644 index 0000000000..f81b3d1b99 --- /dev/null +++ b/tests/libs/pyarrow/test_pyarrow.py @@ -0,0 +1,191 @@ +from datetime import timezone, datetime, timedelta # noqa: I251 +from copy import deepcopy +from typing import List, Any + +import pytest +import pyarrow as pa + +from dlt.common import pendulum +from dlt.common.libs.pyarrow import ( + py_arrow_to_table_schema_columns, + from_arrow_scalar, + get_py_arrow_timestamp, + to_arrow_scalar, + get_py_arrow_datatype, + remove_null_columns, + remove_columns, + append_column, + rename_columns, + is_arrow_item, +) +from dlt.common.destination import DestinationCapabilitiesContext +from tests.cases import TABLE_UPDATE_COLUMNS_SCHEMA + + +def test_py_arrow_to_table_schema_columns(): + dlt_schema = deepcopy(TABLE_UPDATE_COLUMNS_SCHEMA) + + caps = DestinationCapabilitiesContext.generic_capabilities() + # The arrow schema will add precision + dlt_schema["col4"]["precision"] = caps.timestamp_precision + dlt_schema["col6"]["precision"], dlt_schema["col6"]["scale"] = caps.decimal_precision + dlt_schema["col11"]["precision"] = caps.timestamp_precision + dlt_schema["col4_null"]["precision"] = caps.timestamp_precision + dlt_schema["col6_null"]["precision"], dlt_schema["col6_null"]["scale"] = caps.decimal_precision + dlt_schema["col11_null"]["precision"] = caps.timestamp_precision + + # Ignoring wei as we can't distinguish from decimal + dlt_schema["col8"]["precision"], dlt_schema["col8"]["scale"] = (76, 0) + dlt_schema["col8"]["data_type"] = "decimal" + dlt_schema["col8_null"]["precision"], dlt_schema["col8_null"]["scale"] = (76, 0) + dlt_schema["col8_null"]["data_type"] = "decimal" + # No json type + dlt_schema["col9"]["data_type"] = "text" + del dlt_schema["col9"]["variant"] + 
dlt_schema["col9_null"]["data_type"] = "text" + del dlt_schema["col9_null"]["variant"] + + # arrow string fields don't have precision + del dlt_schema["col5_precision"]["precision"] + + # Convert to arrow schema + arrow_schema = pa.schema( + [ + pa.field( + column["name"], + get_py_arrow_datatype(column, caps, "UTC"), + nullable=column["nullable"], + ) + for column in dlt_schema.values() + ] + ) + + result = py_arrow_to_table_schema_columns(arrow_schema) + + # Resulting schema should match the original + assert result == dlt_schema + + +def test_to_arrow_scalar() -> None: + naive_dt = get_py_arrow_timestamp(6, tz=None) + # print(naive_dt) + # naive datetimes are converted as UTC when time aware python objects are used + assert to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt).as_py() == datetime( + 2021, 1, 1, 5, 2, 32 + ) + assert to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc), naive_dt + ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + assert to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), naive_dt + ).as_py() == datetime(2021, 1, 1, 5, 2, 32) + timedelta(hours=8) + + # naive datetimes are treated like UTC + utc_dt = get_py_arrow_timestamp(6, tz="UTC") + dt_converted = to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), utc_dt + ).as_py() + assert dt_converted.utcoffset().seconds == 0 + assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) + + berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin") + dt_converted = to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt + ).as_py() + # no dst + assert dt_converted.utcoffset().seconds == 60 * 60 + assert dt_converted == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) + + +def test_from_arrow_scalar() -> None: + naive_dt = get_py_arrow_timestamp(6, tz=None) + sc_dt = to_arrow_scalar(datetime(2021, 1, 1, 5, 2, 32), naive_dt) + + # this 
value is like UTC + py_dt = from_arrow_scalar(sc_dt) + assert isinstance(py_dt, pendulum.DateTime) + # and we convert to explicit UTC + assert py_dt == datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone.utc) + + # converts to UTC + berlin_dt = get_py_arrow_timestamp(6, tz="Europe/Berlin") + sc_dt = to_arrow_scalar( + datetime(2021, 1, 1, 5, 2, 32, tzinfo=timezone(timedelta(hours=-8))), berlin_dt + ) + py_dt = from_arrow_scalar(sc_dt) + assert isinstance(py_dt, pendulum.DateTime) + assert py_dt.tzname() == "UTC" + assert py_dt == datetime(2021, 1, 1, 13, 2, 32, tzinfo=timezone.utc) + + +def _row_at_index(table: pa.Table, index: int) -> List[Any]: + return [table.column(column_name)[index].as_py() for column_name in table.column_names] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_remove_null_columns(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": None}, + {"a": 1, "b": None, "c": None}, + ] + ) + result = remove_null_columns(table) + assert result.column_names == ["a", "b"] + assert _row_at_index(result, 0) == [1, 2] + assert _row_at_index(result, 1) == [1, None] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_remove_columns(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": 5}, + {"a": 1, "b": 3, "c": 4}, + ] + ) + result = remove_columns(table, ["b"]) + assert result.column_names == ["a", "c"] + assert _row_at_index(result, 0) == [1, 5] + assert _row_at_index(result, 1) == [1, 4] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_append_column(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2}, + {"a": 1, "b": 3}, + ] + ) + result = append_column(table, "c", pa.array([5, 6])) + assert result.column_names == ["a", "b", "c"] + assert _row_at_index(result, 0) == [1, 2, 5] + assert _row_at_index(result, 1) == [1, 3, 6] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def 
test_rename_column(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": 5}, + {"a": 1, "b": 3, "c": 4}, + ] + ) + result = rename_columns(table, ["one", "two", "three"]) + assert result.column_names == ["one", "two", "three"] + assert _row_at_index(result, 0) == [1, 2, 5] + assert _row_at_index(result, 1) == [1, 3, 4] + + +@pytest.mark.parametrize("pa_type", [pa.Table, pa.RecordBatch]) +def test_is_arrow_item(pa_type: Any) -> None: + table = pa_type.from_pylist( + [ + {"a": 1, "b": 2, "c": 5}, + {"a": 1, "b": 3, "c": 4}, + ] + ) + assert is_arrow_item(table) + assert not is_arrow_item(table.to_pydict()) + assert not is_arrow_item("hello") diff --git a/tests/libs/pyarrow/test_pyarrow_normalizer.py b/tests/libs/pyarrow/test_pyarrow_normalizer.py new file mode 100644 index 0000000000..25871edd45 --- /dev/null +++ b/tests/libs/pyarrow/test_pyarrow_normalizer.py @@ -0,0 +1,111 @@ +from typing import List, Any + +import pyarrow as pa +import pytest + +from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationClash +from dlt.common.normalizers import explicit_normalizers, import_normalizers +from dlt.common.schema.utils import new_column, TColumnSchema +from dlt.common.destination import DestinationCapabilitiesContext + + +def _normalize(table: pa.Table, columns: List[TColumnSchema]) -> pa.Table: + _, naming, _ = import_normalizers(explicit_normalizers()) + caps = DestinationCapabilitiesContext() + columns_schema = {c["name"]: c for c in columns} + return normalize_py_arrow_item(table, columns_schema, naming, caps) + + +def _row_at_index(table: pa.Table, index: int) -> List[Any]: + return [table.column(column_name)[index].as_py() for column_name in table.column_names] + + +def test_quick_return_if_nothing_to_do() -> None: + table = pa.Table.from_pylist( + [ + {"a": 1, "b": 2}, + ] + ) + columns = [new_column("a", "bigint"), new_column("b", "bigint")] + result = _normalize(table, columns) + # same object returned + assert 
result == table + + +def test_pyarrow_reorder_columns() -> None: + table = pa.Table.from_pylist( + [ + {"col_new": "hello", "col1": 1, "col2": "a"}, + ] + ) + columns = [new_column("col2", "text"), new_column("col1", "bigint")] + result = _normalize(table, columns) + # new columns appear at the end + assert result.column_names == ["col2", "col1", "col_new"] + assert _row_at_index(result, 0) == ["a", 1, "hello"] + + +def test_pyarrow_add_empty_types() -> None: + table = pa.Table.from_pylist( + [ + {"col1": 1}, + ] + ) + columns = [new_column("col1", "bigint"), new_column("col2", "text")] + result = _normalize(table, columns) + # new columns appear at the end + assert result.column_names == ["col1", "col2"] + assert _row_at_index(result, 0) == [1, None] + assert result.schema.field(1).type == "string" + + +def test_field_normalization_clash() -> None: + table = pa.Table.from_pylist( + [ + {"col^New": "hello", "col_new": 1}, + ] + ) + with pytest.raises(NameNormalizationClash): + _normalize(table, []) + + +def test_field_normalization() -> None: + table = pa.Table.from_pylist( + [ + {"col^New": "hello", "col2": 1}, + ] + ) + result = _normalize(table, []) + assert result.column_names == ["col_new", "col2"] + assert _row_at_index(result, 0) == ["hello", 1] + + +def test_default_dlt_columns_not_added() -> None: + table = pa.Table.from_pylist( + [ + {"col1": 1}, + ] + ) + columns = [ + new_column("_dlt_something", "bigint"), + new_column("_dlt_id", "text"), + new_column("_dlt_load_id", "text"), + new_column("col2", "text"), + new_column("col1", "text"), + ] + result = _normalize(table, columns) + # no dlt_id or dlt_load_id columns + assert result.column_names == ["_dlt_something", "col2", "col1"] + assert _row_at_index(result, 0) == [None, None, 1] + + +@pytest.mark.skip(reason="Somehow this does not fail, should we add an exception??") +def test_fails_if_adding_non_nullable_column() -> None: + table = pa.Table.from_pylist( + [ + {"col1": 1}, + ] + ) + columns = 
[new_column("col1", "bigint"), new_column("col2", "text", nullable=False)] + with pytest.raises(Exception): + _normalize(table, columns) diff --git a/tests/normalize/test_normalize.py b/tests/normalize/test_normalize.py index 91997a921e..3891c667c3 100644 --- a/tests/normalize/test_normalize.py +++ b/tests/normalize/test_normalize.py @@ -6,6 +6,7 @@ from dlt.common import json from dlt.common.destination.capabilities import TLoaderFileFormat from dlt.common.schema.schema import Schema +from dlt.common.schema.utils import new_table from dlt.common.storages.exceptions import SchemaNotFoundError from dlt.common.typing import StrAny from dlt.common.data_types import TDataType @@ -601,7 +602,7 @@ def extract_and_normalize_cases(normalize: Normalize, cases: Sequence[str]) -> s return normalize_pending(normalize) -def normalize_pending(normalize: Normalize) -> str: +def normalize_pending(normalize: Normalize, schema: Schema = None) -> str: # pool not required for map_single load_ids = normalize.normalize_storage.extracted_packages.list_packages() assert len(load_ids) == 1, "Only one package allowed or rewrite tests" @@ -609,7 +610,7 @@ def normalize_pending(normalize: Normalize) -> str: normalize._step_info_start_load_id(load_id) normalize.load_storage.new_packages.create_package(load_id) # read schema from package - schema = normalize.normalize_storage.extracted_packages.load_schema(load_id) + schema = schema or normalize.normalize_storage.extracted_packages.load_schema(load_id) # get files schema_files = normalize.normalize_storage.extracted_packages.list_new_jobs(load_id) # normalize without pool @@ -708,3 +709,40 @@ def assert_timestamp_data_type(load_storage: LoadStorage, data_type: TDataType) event_schema = load_storage.normalized_packages.load_schema(loads[0]) # in raw normalize timestamp column must not be coerced to timestamp assert event_schema.get_table_columns("event")["timestamp"]["data_type"] == data_type + + +def 
test_removal_of_normalizer_schema_section_and_add_seen_data(raw_normalize: Normalize) -> None: + extract_cases( + raw_normalize, + [ + "event.event.user_load_1", + ], + ) + load_ids = raw_normalize.normalize_storage.extracted_packages.list_packages() + assert len(load_ids) == 1 + extracted_schema = raw_normalize.normalize_storage.extracted_packages.load_schema(load_ids[0]) + + # add some normalizer blocks + extracted_schema.tables["event"] = new_table("event") + extracted_schema.tables["event__parse_data__intent_ranking"] = new_table( + "event__parse_data__intent_ranking" + ) + extracted_schema.tables["event__random_table"] = new_table("event__random_table") + + # add x-normalizer info (and other block to control) + extracted_schema.tables["event"]["x-normalizer"] = {"evolve-columns-once": True} # type: ignore + extracted_schema.tables["event"]["x-other-info"] = "blah" # type: ignore + extracted_schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] = {"seen-data": True, "random-entry": 1234} # type: ignore + extracted_schema.tables["event__random_table"]["x-normalizer"] = {"evolve-columns-once": True} # type: ignore + + normalize_pending(raw_normalize, extracted_schema) + schema = raw_normalize.schema_storage.load_schema("event") + # seen data gets added, schema settings get removed + assert schema.tables["event"]["x-normalizer"] == {"seen-data": True} # type: ignore + assert schema.tables["event__parse_data__intent_ranking"]["x-normalizer"] == { # type: ignore + "seen-data": True, + "random-entry": 1234, + } + # no data seen here, so seen-data is not set and evolve settings stays until first data is seen + assert schema.tables["event__random_table"]["x-normalizer"] == {"evolve-columns-once": True} # type: ignore + assert "x-other-info" in schema.tables["event"] diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index 2dba9d7f6d..579a6289cf 100644 --- a/tests/pipeline/test_schema_contracts.py +++ 
b/tests/pipeline/test_schema_contracts.py @@ -1,6 +1,6 @@ import dlt, os, pytest import contextlib -from typing import Any, Callable, Iterator, Union, Optional +from typing import Any, Callable, Iterator, Union, Optional, Type from dlt.common.schema.typing import TSchemaContract from dlt.common.utils import uniq_id @@ -9,6 +9,7 @@ from dlt.extract import DltResource from dlt.pipeline.pipeline import Pipeline from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.extract.exceptions import ResourceExtractionError from tests.load.pipeline.utils import load_table_counts from tests.utils import ( @@ -20,23 +21,33 @@ skip_if_not_active("duckdb") -schema_contract = ["evolve", "discard_value", "discard_row", "freeze"] +SCHEMA_CONTRACT = ["evolve", "discard_value", "discard_row", "freeze"] LOCATIONS = ["source", "resource", "override"] SCHEMA_ELEMENTS = ["tables", "columns", "data_type"] +OLD_COLUMN_NAME = "name" +NEW_COLUMN_NAME = "new_col" +VARIANT_COLUMN_NAME = "some_int__v_text" +SUBITEMS_TABLE = "items__sub_items" +NEW_ITEMS_TABLE = "new_items" +ITEMS_TABLE = "items" + + @contextlib.contextmanager -def raises_frozen_exception(check_raise: bool = True) -> Any: +def raises_step_exception(check_raise: bool = True, expected_nested_error: Type[Any] = None) -> Any: + expected_nested_error = expected_nested_error or DataValidationError if not check_raise: yield return with pytest.raises(PipelineStepFailed) as py_exc: yield if py_exc.value.step == "extract": - assert isinstance(py_exc.value.__context__, DataValidationError) + print(type(py_exc.value.__context__)) + assert isinstance(py_exc.value.__context__, expected_nested_error) else: # normalize - assert isinstance(py_exc.value.__context__.__context__, DataValidationError) + assert isinstance(py_exc.value.__context__.__context__, expected_nested_error) def items(settings: TSchemaContract) -> Any: @@ -74,26 +85,51 @@ def load_items(): yield { "id": index, "name": f"item {index}", - "sub_items": [{"id": index + 
1000, "name": f"sub item {index + 1000}"}], + "sub_items": [ + {"id": index + 1000, "SomeInt": 5, "name": f"sub item {index + 1000}"} + ], } return load_items -def new_items(settings: TSchemaContract) -> Any: - @dlt.resource(name="new_items", write_disposition="append", schema_contract=settings) +def items_with_new_column_in_subtable(settings: TSchemaContract) -> Any: + @dlt.resource(name="Items", write_disposition="append", schema_contract=settings) def load_items(): for _, index in enumerate(range(0, 10), 1): - yield {"id": index, "some_int": 1, "name": f"item {index}"} + yield { + "id": index, + "name": f"item {index}", + "sub_items": [ + {"id": index + 1000, "name": f"sub item {index + 1000}", "New^Col": "hello"} + ], + } return load_items -OLD_COLUMN_NAME = "name" -NEW_COLUMN_NAME = "new_col" -VARIANT_COLUMN_NAME = "some_int__v_text" -SUBITEMS_TABLE = "items__sub_items" -NEW_ITEMS_TABLE = "new_items" +def items_with_variant_in_subtable(settings: TSchemaContract) -> Any: + @dlt.resource(name="Items", write_disposition="append", schema_contract=settings) + def load_items(): + for _, index in enumerate(range(0, 10), 1): + yield { + "id": index, + "name": f"item {index}", + "sub_items": [ + {"id": index + 1000, "name": f"sub item {index + 1000}", "SomeInt": "hello"} + ], + } + + return load_items + + +def new_items(settings: TSchemaContract) -> Any: + @dlt.resource(name=NEW_ITEMS_TABLE, write_disposition="append", schema_contract=settings) + def load_items(): + for _, index in enumerate(range(0, 10), 1): + yield {"id": index, "some_int": 1, "name": f"item {index}"} + + return load_items def run_resource( @@ -106,10 +142,10 @@ def run_resource( for item in settings.keys(): assert item in LOCATIONS ev_settings = settings[item] - if ev_settings in schema_contract: + if ev_settings in SCHEMA_CONTRACT: continue for key, val in ev_settings.items(): - assert val in schema_contract + assert val in SCHEMA_CONTRACT assert key in SCHEMA_ELEMENTS 
@dlt.source(name="freeze_tests", schema_contract=settings.get("source")) @@ -130,7 +166,7 @@ def source() -> Iterator[DltResource]: ) # check items table settings - # assert pipeline.default_schema.tables["items"].get("schema_contract", {}) == (settings.get("resource") or {}) + # assert pipeline.default_schema.tables[ITEMS_TABLE].get("schema_contract", {}) == (settings.get("resource") or {}) # check effective table settings # assert resolve_contract_settings_for_table(None, "items", pipeline.default_schema) == expand_schema_contract_settings(settings.get("resource") or settings.get("override") or "evolve") @@ -147,7 +183,7 @@ def get_pipeline(): ) -@pytest.mark.parametrize("contract_setting", schema_contract) +@pytest.mark.parametrize("contract_setting", SCHEMA_CONTRACT) @pytest.mark.parametrize("setting_location", LOCATIONS) @pytest.mark.parametrize("item_format", ALL_TEST_DATA_ITEM_FORMATS) def test_new_tables( @@ -160,23 +196,23 @@ def test_new_tables( table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 - assert OLD_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert table_counts[ITEMS_TABLE] == 10 + assert OLD_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] run_resource(pipeline, items_with_new_column, full_settings, item_format) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 20 - assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert table_counts[ITEMS_TABLE] == 20 + assert NEW_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] # test adding new table - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, new_items, full_settings, item_format) table_counts = load_table_counts( pipeline, *[t["name"] 
for t in pipeline.default_schema.data_tables()] ) - assert table_counts.get("new_items", 0) == (10 if contract_setting in ["evolve"] else 0) + assert table_counts.get(NEW_ITEMS_TABLE, 0) == (10 if contract_setting in ["evolve"] else 0) # delete extracted files if left after exception pipeline.drop_pending_packages() @@ -187,21 +223,21 @@ def test_new_tables( table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 30 - assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert table_counts[ITEMS_TABLE] == 30 + assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] # test adding new subtable - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, items_with_subtable, full_settings) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 30 if contract_setting in ["freeze"] else 40 + assert table_counts[ITEMS_TABLE] == 30 if contract_setting in ["freeze"] else 40 assert table_counts.get(SUBITEMS_TABLE, 0) == (10 if contract_setting in ["evolve"] else 0) -@pytest.mark.parametrize("contract_setting", schema_contract) +@pytest.mark.parametrize("contract_setting", SCHEMA_CONTRACT) @pytest.mark.parametrize("setting_location", LOCATIONS) @pytest.mark.parametrize("item_format", ALL_TEST_DATA_ITEM_FORMATS) def test_new_columns( @@ -214,8 +250,8 @@ def test_new_columns( table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 - assert OLD_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert table_counts[ITEMS_TABLE] == 10 + assert OLD_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] # new should work run_resource(pipeline, new_items, 
full_settings, item_format) @@ -223,24 +259,24 @@ def test_new_columns( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) expected_items_count = 10 - assert table_counts["items"] == expected_items_count + assert table_counts[ITEMS_TABLE] == expected_items_count assert table_counts[NEW_ITEMS_TABLE] == 10 # test adding new column twice: filter will try to catch it before it is added for the second time - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, items_with_new_column, full_settings, item_format, duplicates=2) # delete extracted files if left after exception pipeline.drop_pending_packages() if contract_setting == "evolve": - assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert NEW_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] else: - assert NEW_COLUMN_NAME not in pipeline.default_schema.tables["items"]["columns"] + assert NEW_COLUMN_NAME not in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) expected_items_count += 20 if contract_setting in ["evolve", "discard_value"] else 0 - assert table_counts["items"] == expected_items_count + assert table_counts[ITEMS_TABLE] == expected_items_count # NOTE: arrow / pandas do not support variants and subtables so we must skip if item_format == "object": @@ -250,46 +286,85 @@ def test_new_columns( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) expected_items_count += 10 - assert table_counts["items"] == expected_items_count - assert table_counts[SUBITEMS_TABLE] == 10 + expected_subtable_items_count = 10 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count # test adding variant column run_resource(pipeline, items_with_variant, full_settings) # 
variants are not new columns and should be able to always evolve - assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) expected_items_count += 10 - assert table_counts["items"] == expected_items_count + assert table_counts[ITEMS_TABLE] == expected_items_count + # test adding new column in subtable (subtable exists already) + with raises_step_exception(contract_setting == "freeze"): + run_resource(pipeline, items_with_new_column_in_subtable, full_settings, item_format) + # delete extracted files if left after exception + pipeline.drop_pending_packages() + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + # main table only does not get loaded on freeze exception + expected_items_count += 0 if contract_setting in ["freeze"] else 10 + # subtable gets loaded on evolve and discard + expected_subtable_items_count += ( + 10 if contract_setting in ["evolve", "discard_value"] else 0 + ) + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count + # new column may only appear in evolve mode + if contract_setting == "evolve": + assert NEW_COLUMN_NAME in pipeline.default_schema.tables[SUBITEMS_TABLE]["columns"] + else: + assert NEW_COLUMN_NAME not in pipeline.default_schema.tables[SUBITEMS_TABLE]["columns"] + + # loading variant column will always work in subtable + run_resource(pipeline, items_with_variant_in_subtable, full_settings, item_format) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + expected_subtable_items_count += 10 + expected_items_count += 10 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == 
expected_subtable_items_count + assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables[SUBITEMS_TABLE]["columns"] -@pytest.mark.parametrize("contract_setting", schema_contract) + +@pytest.mark.parametrize("contract_setting", SCHEMA_CONTRACT) @pytest.mark.parametrize("setting_location", LOCATIONS) -def test_freeze_variants(contract_setting: str, setting_location: str) -> None: +def test_variant_columns(contract_setting: str, setting_location: str) -> None: full_settings = {setting_location: {"data_type": contract_setting}} pipeline = get_pipeline() run_resource(pipeline, items, {}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 - assert OLD_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + expected_items_count = 10 + expected_subtable_items_count = 0 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert OLD_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] # subtable should work run_resource(pipeline, items_with_subtable, full_settings) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 20 - assert table_counts[SUBITEMS_TABLE] == 10 + expected_items_count += 10 + expected_subtable_items_count += 10 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count # new should work run_resource(pipeline, new_items, full_settings) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 20 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count assert table_counts[NEW_ITEMS_TABLE] == 10 # test adding new column @@ -297,21 +372,54 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: 
table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 30 - assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + expected_items_count += 10 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count + assert NEW_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] # test adding variant column - with raises_frozen_exception(contract_setting == "freeze"): + with raises_step_exception(contract_setting == "freeze"): run_resource(pipeline, items_with_variant, full_settings) + pipeline.drop_pending_packages() if contract_setting == "evolve": - assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"] + assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] else: - assert VARIANT_COLUMN_NAME not in pipeline.default_schema.tables["items"]["columns"] + assert VARIANT_COLUMN_NAME not in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == (40 if contract_setting in ["evolve", "discard_value"] else 30) + expected_items_count += 10 if contract_setting in ["evolve", "discard_value"] else 0 + assert table_counts[ITEMS_TABLE] == expected_items_count + + # test adding new column in subtable (subtable exists already) + run_resource(pipeline, items_with_new_column_in_subtable, full_settings) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + expected_items_count += 10 + expected_subtable_items_count += 10 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count + assert NEW_COLUMN_NAME in pipeline.default_schema.tables[SUBITEMS_TABLE]["columns"] + + # loading 
variant column will always work in subtable + with raises_step_exception(contract_setting == "freeze"): + run_resource(pipeline, items_with_variant_in_subtable, full_settings) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + # main table only does not get loaded on freeze exception + expected_items_count += 0 if contract_setting in ["freeze"] else 10 + # subtable gets loaded on evolve and discard + expected_subtable_items_count += 10 if contract_setting in ["evolve", "discard_value"] else 0 + assert table_counts[ITEMS_TABLE] == expected_items_count + assert table_counts[SUBITEMS_TABLE] == expected_subtable_items_count + # new column may only appear in evolve mode + if contract_setting == "evolve": + assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables[SUBITEMS_TABLE]["columns"] + else: + assert VARIANT_COLUMN_NAME not in pipeline.default_schema.tables[SUBITEMS_TABLE]["columns"] def test_settings_precedence() -> None: @@ -339,14 +447,14 @@ def test_settings_precedence_2() -> None: table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 + assert table_counts[ITEMS_TABLE] == 10 # trying to add variant when forbidden on source will fail run_resource(pipeline, items_with_variant, {"source": {"data_type": "discard_row"}}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 + assert table_counts[ITEMS_TABLE] == 10 # if allowed on resource it will pass run_resource( @@ -357,7 +465,7 @@ def test_settings_precedence_2() -> None: table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 20 + assert table_counts[ITEMS_TABLE] == 20 # if allowed on override it will also pass run_resource( @@ -372,7 +480,7 @@ def test_settings_precedence_2() -> None: 
table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 30 + assert table_counts[ITEMS_TABLE] == 30 @pytest.mark.parametrize("setting_location", LOCATIONS) @@ -384,21 +492,21 @@ def test_change_mode(setting_location: str) -> None: table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 + assert table_counts[ITEMS_TABLE] == 10 # trying to add variant when forbidden will fail run_resource(pipeline, items_with_variant, {setting_location: {"data_type": "discard_row"}}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 + assert table_counts[ITEMS_TABLE] == 10 # now allow run_resource(pipeline, items_with_variant, {setting_location: {"data_type": "evolve"}}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 20 + assert table_counts[ITEMS_TABLE] == 20 @pytest.mark.parametrize("setting_location", LOCATIONS) @@ -409,29 +517,29 @@ def test_single_settings_value(setting_location: str) -> None: table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 + assert table_counts[ITEMS_TABLE] == 10 # trying to add variant when forbidden will fail run_resource(pipeline, items_with_variant, {setting_location: "discard_row"}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 + assert table_counts[ITEMS_TABLE] == 10 # trying to add new column will fail run_resource(pipeline, items_with_new_column, {setting_location: "discard_row"}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] 
== 10 + assert table_counts[ITEMS_TABLE] == 10 # trying to add new table will fail run_resource(pipeline, new_items, {setting_location: "discard_row"}) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) - assert table_counts["items"] == 10 - assert "new_items" not in table_counts + assert table_counts[ITEMS_TABLE] == 10 + assert NEW_ITEMS_TABLE not in table_counts def test_data_contract_interaction() -> None: @@ -472,10 +580,6 @@ def get_items_with_model(): def get_items_new_col(): yield from [{"id": 5, "name": "dave", "amount": 6, "new_col": "hello"}] - @dlt.resource(name="items") - def get_items_subtable(): - yield from [{"id": 5, "name": "dave", "amount": 6, "sub": [{"hello": "dave"}]}] - # test valid object pipeline = get_pipeline() # items with model work @@ -485,7 +589,7 @@ def get_items_subtable(): # loading once with pydantic will freeze the cols pipeline = get_pipeline() pipeline.run([get_items_with_model()]) - with raises_frozen_exception(True): + with raises_step_exception(True): pipeline.run([get_items_new_col()]) # it is possible to override contract when there are new columns @@ -505,7 +609,7 @@ def get_items(): yield {"id": 2, "name": "dave", "amount": 50, "new_column": "some val"} pipeline.run([get_items()], schema_contract={"columns": "freeze", "tables": "evolve"}) - assert pipeline.last_trace.last_normalize_info.row_counts["items"] == 2 + assert pipeline.last_trace.last_normalize_info.row_counts[ITEMS_TABLE] == 2 @pytest.mark.parametrize("table_mode", ["discard_row", "evolve", "freeze"]) @@ -524,7 +628,7 @@ def get_items(): } yield {"id": 2, "tables": "two", "new_column": "some val"} - with raises_frozen_exception(table_mode == "freeze"): + with raises_step_exception(table_mode == "freeze"): pipeline.run([get_items()], schema_contract={"tables": table_mode}) if table_mode != "freeze": @@ -589,7 +693,7 @@ def items(): } pipeline.run([items()], schema_contract={"columns": column_mode}) - assert 
pipeline.last_trace.last_normalize_info.row_counts["items"] == 2 + assert pipeline.last_trace.last_normalize_info.row_counts[ITEMS_TABLE] == 2 @pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"]) @@ -622,3 +726,93 @@ def get_items(): # apply hints apply to `items` not the original resource, so doing get_items() below removed them completely pipeline.run(items) assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 2 + + +@pytest.mark.parametrize("contract_setting", SCHEMA_CONTRACT) +@pytest.mark.parametrize("as_list", [True, False]) +def test_pydantic_contract_implementation(contract_setting: str, as_list: bool) -> None: + from pydantic import BaseModel + + class Items(BaseModel): + id: int # noqa: A003 + name: str + + def get_items(as_list: bool = False): + items = [ + { + "id": 5, + "name": "dave", + } + ] + if as_list: + yield items + else: + yield from items + + def get_items_extra_attribute(as_list: bool = False): + items = [{"id": 5, "name": "dave", "blah": "blubb"}] + if as_list: + yield items + else: + yield from items + + def get_items_extra_variant(as_list: bool = False): + items = [ + { + "id": "five", + "name": "dave", + } + ] + if as_list: + yield items + else: + yield from items + + # test columns complying to model + pipeline = get_pipeline() + pipeline.run( + [get_items(as_list)], + schema_contract={"columns": contract_setting}, + columns=Items, + table_name="items", + ) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert table_counts[ITEMS_TABLE] == 1 + + # test columns extra attribute + with raises_step_exception( + contract_setting in ["freeze"], + expected_nested_error=( + ResourceExtractionError if contract_setting == "freeze" else NotImplementedError + ), + ): + pipeline.run( + [get_items_extra_attribute(as_list)], + schema_contract={"columns": contract_setting}, + columns=Items, + table_name="items", + ) + table_counts = 
load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert table_counts[ITEMS_TABLE] == (1 if (contract_setting in ["freeze", "discard_row"]) else 2) + + # test columns with variant + with raises_step_exception( + contract_setting in ["freeze", "discard_value"], + expected_nested_error=( + ResourceExtractionError if contract_setting == "freeze" else NotImplementedError + ), + ): + pipeline.run( + [get_items_extra_variant(as_list)], + schema_contract={"data_type": contract_setting}, + columns=Items, + table_name="items", + ) + table_counts = load_table_counts( + pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] + ) + assert table_counts[ITEMS_TABLE] == (1 if (contract_setting in ["freeze", "discard_row"]) else 3) diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index 036154b582..569ab69bfc 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -12,7 +12,7 @@ from dlt.common.typing import DictStrAny from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.pipeline.exceptions import SqlClientNotAvailable - +from dlt.common.storages import FileStorage from tests.utils import TEST_STORAGE_ROOT PIPELINE_TEST_CASES_PATH = "./tests/pipeline/cases/"