diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 9d94def2f3..77320ecfc2 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -189,6 +189,9 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: result = f(self, *args, **kwargs) plugins_ctx.on_step_end(f.__name__, self) + # ensure messages queue is completely processed + plugins_ctx.process_queue() + return result return _wrap # type: ignore diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index d170fd553b..62668a6316 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -415,7 +415,7 @@ def complex_data(): staging=destination_config.staging, dataset_name="ds_" + uniq_id(), ) - print(info) + with dlt.pipeline().sql_client() as client: complex_cn_table = client.make_qualified_table_name("complex_cn") rows = select_data(dlt.pipeline(), f"SELECT cn FROM {complex_cn_table}") diff --git a/tests/pipeline/test_schema_contracts.py b/tests/pipeline/test_schema_contracts.py index fa1d015509..bf2dc1ba31 100644 --- a/tests/pipeline/test_schema_contracts.py +++ b/tests/pipeline/test_schema_contracts.py @@ -2,6 +2,7 @@ import contextlib from typing import Any, Callable, Iterator, Union, Optional, List, Dict, cast from collections.abc import Iterable +from dlt.common.pipeline import SupportsPipeline from dlt.common.schema.typing import TSchemaContract from dlt.common.utils import uniq_id from dlt.common.schema.exceptions import DataValidationError @@ -123,9 +124,6 @@ def source() -> Iterator[DltResource]: resource.table_name = resource.name yield resource.with_name(resource.name + str(idx)) - # run pipeline - if (plugin := pipeline.get_plugin("cp")) is not None: - plugin.reset() pipeline.run(source(), schema_contract=settings.get("override")) # check global settings @@ -151,6 +149,10 @@ def __init__(self) -> None: self.common_attributes: Dict[str, str] = None self.calls_have_common_attributes = True + def on_step_start(self, step: str, pipeline: SupportsPipeline) -> None: + if step == "run": + self.reset() + def on_schema_contract_violation(self, error: DataValidationError, **kwargs: Any) -> None: self.calls.append(error) error_dict = error.__dict__.copy() @@ -171,7 +173,6 @@ def row_count(self) -> int: if (isinstance(call.data_item, Iterable) and not isinstance(call.data_item, dict)) else [call.data_item] ) - print(data_items) for item in data_items: if hasattr(item, "num_rows"): count += item.num_rows @@ -185,7 +186,7 @@ def first_row(self) -> Dict[str, Any]: data_item = ( call.data_item if not (isinstance(call.data_item, Iterable) and not isinstance(call.data_item, dict)) - else call.data_item[0] + else call.data_item[0] # type: ignore ) # arrow tables if hasattr(data_item, "num_rows"): @@ -426,6 +427,8 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: ) assert table_counts[ITEMS_TABLE] == 10 assert OLD_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + assert len(plugin.calls) == 0 # subtable should work run_resource(pipeline, items_with_subtable, full_settings) @@ -434,6 +437,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: ) assert table_counts[ITEMS_TABLE] == 20 assert table_counts[SUBITEMS_TABLE] == 10 + assert len(plugin.calls) == 0 # new should work run_resource(pipeline, new_items, full_settings) @@ -442,6 +446,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: ) assert table_counts[ITEMS_TABLE] == 20 assert table_counts[NEW_ITEMS_TABLE] == 10 + assert len(plugin.calls) == 0 # test adding new column run_resource(pipeline, items_with_new_column, full_settings) @@ -450,6 +455,7 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: ) assert table_counts[ITEMS_TABLE] == 30 assert NEW_COLUMN_NAME in pipeline.default_schema.tables[ITEMS_TABLE]["columns"] + assert len(plugin.calls) == 0 # test adding variant column with raises_frozen_exception(contract_setting == "freeze"): @@ -466,6 +472,27 @@ def test_freeze_variants(contract_setting: str, setting_location: str) -> None: 40 if contract_setting in ["evolve", "discard_value"] else 30 ) + # check plugin calls + if contract_setting in ["evolve", "freeze"]: + assert len(plugin.calls) == 0 + else: + # TODO: here we are sending the row coerced into variant fields, we should probably send the original row + assert plugin.row_count == 10 + assert plugin.common_attributes == { + "schema_name": "freeze_tests", + "table_name": ITEMS_TABLE, + "column_name": "some_int__v_text", + "schema_entity": "columns", + "contract_mode": contract_setting, + "table_schema": None, + "schema_contract": { + "tables": "evolve", + "columns": "evolve", + "data_type": contract_setting, + }, + } + assert plugin.first_row == {"id": 0, "name": "item 0", "some_int__v_text": "hello"} + def test_settings_precedence() -> None: pipeline = get_pipeline() @@ -659,6 +686,8 @@ def get_items(): pipeline.run([get_items()], schema_contract={"columns": "freeze", "tables": "evolve"}) assert pipeline.last_trace.last_normalize_info.row_counts["items"] == 2 + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + assert len(plugin.calls) == 0 @pytest.mark.parametrize("table_mode", ["discard_row", "evolve", "freeze"]) @@ -688,6 +717,40 @@ def get_items(): 1 if table_mode == "evolve" else 0 ) + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + if table_mode in ["freeze", "evolve"]: + assert len(plugin.calls) == 0 + else: + assert plugin.row_count == 2 + assert plugin.calls[0].__dict__ == { + "schema_name": pipeline.default_schema.name, + "table_name": "one", + "column_name": None, + "schema_entity": "tables", + "contract_mode": table_mode, + "table_schema": None, + "data_item": {"id": 1, "tables": "one"}, + "schema_contract": { + "tables": table_mode, + "columns": "evolve", + "data_type": "evolve", + }, + } + assert plugin.calls[1].__dict__ == { + "schema_name": pipeline.default_schema.name, + "table_name": "two", + "column_name": None, + "schema_entity": "tables", + "contract_mode": table_mode, + "table_schema": None, + "data_item": {"id": 2, "tables": "two", "new_column": "some val"}, + "schema_contract": { + "tables": table_mode, + "columns": "evolve", + "data_type": "evolve", + }, + } + @pytest.mark.parametrize("column_mode", ["discard_row", "evolve", "freeze"]) def test_defined_column_in_new_table(column_mode: str) -> None: @@ -702,6 +765,8 @@ def get_items(): pipeline.run([get_items()], schema_contract={"columns": column_mode}) assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 1 + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + assert len(plugin.calls) == 0 @pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"]) @@ -721,6 +786,8 @@ def get_items(): pipeline.run([get_items()], schema_contract={"columns": column_mode}) assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 1 + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + assert len(plugin.calls) == 0 @pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"]) @@ -743,6 +810,8 @@ def items(): pipeline.run([items()], schema_contract={"columns": column_mode}) assert pipeline.last_trace.last_normalize_info.row_counts["items"] == 2 + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + assert len(plugin.calls) == 0 @pytest.mark.parametrize("column_mode", ["freeze", "discard_row", "evolve"]) @@ -775,3 +844,5 @@ def get_items(): # apply hints apply to `items` not the original resource, so doing get_items() below removed them completely pipeline.run(items) assert pipeline.last_trace.last_normalize_info.row_counts.get("items", 0) == 2 + plugin = cast(ContractsViolationPlugin, pipeline.get_plugin("cp")) + assert len(plugin.calls) == 0