Skip to content

Commit

Permalink
finish contracts tests for now
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Jan 10, 2024
1 parent c661b35 commit 6ceeb15
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 6 deletions.
3 changes: 3 additions & 0 deletions dlt/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any:
result = f(self, *args, **kwargs)
plugins_ctx.on_step_end(f.__name__, self)

# ensure messages queue is completely processed
plugins_ctx.process_queue()

return result

return _wrap # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion tests/load/pipeline/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ def complex_data():
staging=destination_config.staging,
dataset_name="ds_" + uniq_id(),
)
print(info)

with dlt.pipeline().sql_client() as client:
complex_cn_table = client.make_qualified_table_name("complex_cn")
rows = select_data(dlt.pipeline(), f"SELECT cn FROM {complex_cn_table}")
Expand Down
81 changes: 76 additions & 5 deletions tests/pipeline/test_schema_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
from typing import Any, Callable, Iterator, Union, Optional, List, Dict, cast
from collections.abc import Iterable
from dlt.common.pipeline import SupportsPipeline
from dlt.common.schema.typing import TSchemaContract
from dlt.common.utils import uniq_id
from dlt.common.schema.exceptions import DataValidationError
Expand Down Expand Up @@ -123,9 +124,6 @@ def source() -> Iterator[DltResource]:
resource.table_name = resource.name
yield resource.with_name(resource.name + str(idx))

# run pipeline
if (plugin := pipeline.get_plugin("cp")) is not None:
plugin.reset()
pipeline.run(source(), schema_contract=settings.get("override"))

# check global settings
Expand All @@ -151,6 +149,10 @@ def __init__(self) -> None:
self.common_attributes: Dict[str, str] = None
self.calls_have_common_attributes = True

def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None:
    """Reset this plugin whenever a "run" step starts.

    Args:
        step: Name of the step that is starting. Only the exact value
            "run" triggers a reset; every other step name is ignored.
        pipeline: The pipeline the step belongs to. Not used here.
    """
    # Guard clause: nothing to do for any step other than "run".
    if step != "run":
        return
    self.reset()

def on_schema_contract_violation(self, error: DataValidationError, **kwargs: Any) -> None:
self.calls.append(error)
error_dict = error.__dict__.copy()
Expand All @@ -171,7 +173,6 @@ def row_count(self) -> int:
if (isinstance(call.data_item, Iterable) and not isinstance(call.data_item, dict))
else [call.data_item]
)
print(data_items)
for item in data_items:
if hasattr(item, "num_rows"):
count += item.num_rows
Expand All @@ -185,7 +186,7 @@ def first_row(self) -> Dict[str, Any]:
data_item = (
call.data_item
if not (isinstance(call.data_item, Iterable) and not isinstance(call.data_item, dict))
else call.data_item[0]
else call.data_item[0] # type: ignore
)
# arrow tables
if hasattr(data_item, "num_rows"):
Expand Down Expand Up @@ -426,6 +427,8 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
)
assert table_counts[ITEMS_TABLE] == 10
assert OLD_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"]
plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
assert len(plugin.calls) == 0

# subtable should work
run_resource(pipeline, items_with_subtable, full_settings)
Expand All @@ -434,6 +437,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
)
assert table_counts[ITEMS_TABLE] == 20
assert table_counts[SUBITEMS_TABLE] == 10
assert len(plugin.calls) == 0

# new should work
run_resource(pipeline, new_items, full_settings)
Expand All @@ -442,6 +446,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
)
assert table_counts[ITEMS_TABLE] == 20
assert table_counts[NEW_ITEMS_TABLE] == 10
assert len(plugin.calls) == 0

# test adding new column
run_resource(pipeline, items_with_new_column, full_settings)
Expand All @@ -450,6 +455,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
)
assert table_counts[ITEMS_TABLE] == 30
assert NEW_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"]
assert len(plugin.calls) == 0

# test adding variant column
with raises_frozen_exception(contract_setting == "freeze"):
Expand All @@ -466,6 +472,27 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
40 if contract_setting in ["evolve", "discard_value"] else 30
)

# check plugin calls
if contract_setting in ["evolve", "freeze"]:
assert len(plugin.calls) == 0
else:
# TODO: here we are sending the row coerced into variant fields; we should probably send the original row instead
assert plugin.row_count == 10
assert plugin.common_attributes == {
"schema_name": "freeze_tests",
"table_name": ITEMS_TABLE,
"column_name": "some_int__v_text",
"schema_entity": "columns",
"contract_mode": contract_setting,
"table_schema": None,
"schema_contract": {
"tables": "evolve",
"columns": "evolve",
"data_type": contract_setting,
},
}
assert plugin.first_row == {"id": 0, "name": "item 0", "some_int__v_text": "hello"}


def test_settings_precedence() -> None:
pipeline = get_pipeline()
Expand Down Expand Up @@ -659,6 +686,8 @@ def get_items():

pipeline.run([get_items()], schema_contract={"columns": "freeze", "tables": "evolve"})
assert pipeline.last_trace.last_normalize_info.row_counts["items"] == 2
plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
assert len(plugin.calls) == 0


@pytest.mark.parametrize("table_mode", ["discard_row", "evolve", "freeze"])
Expand Down Expand Up @@ -688,6 +717,40 @@ def get_items():
1 if table_mode == "evolve" else 0
)

plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
if table_mode in ["freeze", "evolve"]:
assert len(plugin.calls) == 0
else:
assert plugin.row_count == 2
assert plugin.calls[0].__dict__ == {
"schema_name": pipeline.default_schema.name,
"table_name": "one",
"column_name": None,
"schema_entity": "tables",
"contract_mode": table_mode,
"table_schema": None,
"data_item": {"id": 1, "tables": "one"},
"schema_contract": {
"tables": table_mode,
"columns": "evolve",
"data_type": "evolve",
},
}
assert plugin.calls[1].__dict__ == {
"schema_name": pipeline.default_schema.name,
"table_name": "two",
"column_name": None,
"schema_entity": "tables",
"contract_mode": table_mode,
"table_schema": None,
"data_item": {"id": 2, "tables": "two", "new_column": "some val"},
"schema_contract": {
"tables": table_mode,
"columns": "evolve",
"data_type": "evolve",
},
}


@pytest.mark.parametrize("column_mode", ["discard_row", "evolve", "freeze"])
def test_defined_column_in_new_table(column_mode: str) -> None:
Expand All @@ -702,6 +765,8 @@ def get_items():

pipeline.run([get_items()], schema_contract={"columns": column_mode})
assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 1
plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
assert len(plugin.calls) == 0


@pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"])
Expand All @@ -721,6 +786,8 @@ def get_items():

pipeline.run([get_items()], schema_contract={"columns": column_mode})
assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 1
plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
assert len(plugin.calls) == 0


@pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"])
Expand All @@ -743,6 +810,8 @@ def items():

pipeline.run([items()], schema_contract={"columns": column_mode})
assert pipeline.last_trace.last_normalize_info.row_counts["items"] == 2
plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
assert len(plugin.calls) == 0


@pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"])
Expand Down Expand Up @@ -775,3 +844,5 @@ def get_items():
# hints are applied to `items`, not to the original resource, so calling get_items() below removed them completely
pipeline.run(items)
assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 2
plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp"))
assert len(plugin.calls) == 0

0 comments on commit 6ceeb15

Please sign in to comment.