From 61ab9970c0820ae92235175b8539f1d9040cd32a Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Fri, 9 Aug 2024 09:53:24 -0400 Subject: [PATCH] Raise/warn on incomplete columns in normalize (#1504) * Raise/warn on incomplete columns in normalize Raise on not-nullable columns to catch e.g. misspelled merge/primary key key * Update error msg * Test for null values * Lint * Delete now invalid tests * Fix common test --- dlt/common/destination/utils.py | 11 +-- dlt/common/schema/exceptions.py | 24 +++++ dlt/common/schema/utils.py | 13 +++ dlt/normalize/normalize.py | 2 + dlt/normalize/schema.py | 17 ++++ tests/load/pipeline/test_merge_disposition.py | 93 +++++++++++-------- .../load/pipeline/test_replace_disposition.py | 2 +- tests/pipeline/test_pipeline_trace.py | 2 +- 8 files changed, 113 insertions(+), 51 deletions(-) create mode 100644 dlt/normalize/schema.py diff --git a/dlt/common/destination/utils.py b/dlt/common/destination/utils.py index 2c5e97df14..931413126c 100644 --- a/dlt/common/destination/utils.py +++ b/dlt/common/destination/utils.py @@ -6,7 +6,6 @@ from dlt.common.schema.exceptions import ( SchemaIdentifierNormalizationCollision, ) -from dlt.common.schema.utils import is_complete_column from dlt.common.typing import DictStrStr from .capabilities import DestinationCapabilitiesContext @@ -25,7 +24,6 @@ def verify_schema_capabilities( * Checks if schema has collisions due to case sensitivity of the identifiers """ - log = logger.warning if warnings else logger.info # collect all exceptions to show all problems in the schema exception_log: List[Exception] = [] # combined casing function @@ -79,7 +77,7 @@ def verify_schema_capabilities( ) column_name_lookup: DictStrStr = {} - for column_name, column in dict(table["columns"]).items(): + for column_name in dict(table["columns"]): # detect table name conflict cased_column_name = case_identifier(column_name) if cased_column_name in column_name_lookup: @@ -105,11 +103,4 @@ def verify_schema_capabilities( capabilities.max_column_identifier_length, ) ) - if not is_complete_column(column): - log( - f"A column {column_name} in table {table_name} in schema" - f" {schema.name} is incomplete. It was not bound to the data during" - " normalizations stage and its data type is unknown. Did you add this" - " column manually in code ie. as a merge key?" - ) return exception_log diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 2f016577ce..1055163942 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -8,6 +8,7 @@ TSchemaEvolutionMode, ) from dlt.common.normalizers.naming import NamingConvention +from dlt.common.schema.typing import TColumnSchema, TColumnSchemaBase class SchemaException(DltException): @@ -231,3 +232,26 @@ def __init__( class ColumnNameConflictException(SchemaException): pass + + +class UnboundColumnException(SchemaException): + def __init__(self, schema_name: str, table_name: str, column: TColumnSchemaBase) -> None: + self.column = column + self.schema_name = schema_name + self.table_name = table_name + nullable: bool = column.get("nullable", False) + key_type: str = "" + if column.get("merge_key"): + key_type = "merge key" + elif column.get("primary_key"): + key_type = "primary key" + + msg = f"The column {column['name']} in table {table_name} did not receive any data during this load. " + if key_type or not nullable: + msg += f"It is marked as non-nullable{' '+key_type} and it must have values. 
" + + msg += ( + "This can happen if you specify the column manually, for example using the 'merge_key', 'primary_key' or 'columns' argument " + "but it does not exist in the data." + ) + super().__init__(schema_name, msg) diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index aa5de9611c..d879c21b3c 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -352,6 +352,19 @@ def is_complete_column(col: TColumnSchemaBase) -> bool: return bool(col.get("name")) and bool(col.get("data_type")) +def is_nullable_column(col: TColumnSchemaBase) -> bool: + """Returns true if column is nullable""" + return col.get("nullable", True) + + +def find_incomplete_columns(tables: List[TTableSchema]) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: + """Yields (table_name, column, nullable) for all incomplete columns in `tables`""" + for table in tables: + for col in table["columns"].values(): + if not is_complete_column(col): + yield table["name"], col, is_nullable_column(col) + + def compare_complete_columns(a: TColumnSchema, b: TColumnSchema) -> bool: """Compares mandatory fields of complete columns""" assert is_complete_column(a) diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 98154cd5cf..e80931605c 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -34,6 +34,7 @@ from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV +from dlt.normalize.schema import verify_normalized_schema # normalize worker wrapping function signature @@ -195,6 +196,7 @@ def spool_files( x_normalizer["seen-data"] = True # schema is updated, save it to schema volume if schema.is_modified: + verify_normalized_schema(schema) logger.info( f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" ) diff --git a/dlt/normalize/schema.py b/dlt/normalize/schema.py new file mode 100644 index 0000000000..4967fab18f --- /dev/null +++ b/dlt/normalize/schema.py @@ -0,0 +1,17 @@ +from dlt.common.schema import Schema +from dlt.common.schema.utils import find_incomplete_columns +from dlt.common.schema.exceptions import UnboundColumnException +from dlt.common import logger + +def verify_normalized_schema(schema: Schema) -> None: + """Verify the schema is valid for next stage after normalization. + + 1. Log warning if any incomplete nullable columns are in any data tables + 2. Raise `UnboundColumnException` on incomplete non-nullable columns (e.g. 
missing merge/primary key) + """ + for table_name, column, nullable in find_incomplete_columns(schema.data_tables(seen_data_only=True)): + exc = UnboundColumnException(schema.name, table_name, column) + if nullable: + logger.warning(str(exc)) + else: + raise exc diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index 63188d4f5e..b2197dd273 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -11,7 +11,11 @@ from dlt.common.configuration.container import Container from dlt.common.pipeline import StateInjectableContext from dlt.common.schema.utils import has_table_seen_data -from dlt.common.schema.exceptions import SchemaCorruptedException +from dlt.common.schema.exceptions import ( + SchemaCorruptedException, + UnboundColumnException, + CannotCoerceNullException, +) from dlt.common.schema.typing import TLoaderMergeStrategy from dlt.common.typing import StrAny from dlt.common.utils import digest128 @@ -20,6 +24,7 @@ from dlt.extract import DltResource from dlt.sources.helpers.transform import skip_first, take_first from dlt.pipeline.exceptions import PipelineStepFailed +from dlt.normalize.exceptions import NormalizeJobFailed from tests.pipeline.utils import ( assert_load_info, @@ -445,44 +450,6 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - assert github_1_counts["issues"] == 100 - 45 + 10 -@pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name -) -def test_merge_keys_non_existing_columns(destination_config: DestinationTestConfiguration) -> None: - p = destination_config.setup_pipeline("github_3", dev_mode=True) - github_data = github() - # set keys names that do not exist in the data - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) - # skip first 45 rows - github_data.load_issues.add_filter(skip_first(45)) - info = p.run(github_data, loader_file_format=destination_config.file_format) - assert_load_info(info) - github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) - assert github_1_counts["issues"] == 100 - 45 - assert ( - p.default_schema.tables["issues"]["columns"]["m_a1"].items() - > {"merge_key": True, "nullable": False}.items() - ) - - # for non merge destinations we just check that the run passes - if not destination_config.supports_merge: - return - - # all the keys are invalid so the merge falls back to append - github_data = github() - github_data.load_issues.apply_hints(merge_key=("mA1", "Ma2"), primary_key=("123-x",)) - github_data.load_issues.add_filter(take_first(1)) - info = p.run(github_data, loader_file_format=destination_config.file_format) - assert_load_info(info) - github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) - assert github_2_counts["issues"] == 100 - 45 + 1 - with p._sql_job_client(p.default_schema) as job_c: - _, storage_cols = job_c.get_storage_table("issues") - storage_cols = normalize_storage_table_cols("issues", storage_cols, p.default_schema) - assert "url" in storage_cols - assert "m_a1" not in storage_cols # unbound columns were not created - - @pytest.mark.parametrize( "destination_config", destinations_configs(default_sql_configs=True, file_format="parquet"), @@ -1242,3 +1209,51 @@ def r(): with pytest.raises(PipelineStepFailed) as pip_ex: p.run(r()) assert isinstance(pip_ex.value.__context__, SchemaCorruptedException) + + 
+@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_missing_merge_key_column(destination_config: DestinationTestConfiguration) -> None: + """Merge key is not present in data, error is raised""" + + @dlt.resource(merge_key="not_a_column", write_disposition={"disposition": "merge"}) + def merging_test_table(): + yield {"foo": "bar"} + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + with pytest.raises(PipelineStepFailed) as pip_ex: + p.run(merging_test_table()) + + ex = pip_ex.value + assert ex.step == "normalize" + assert isinstance(ex.__context__, UnboundColumnException) + + assert "not_a_column" in str(ex) + assert "merge key" in str(ex) + assert "merging_test_table" in str(ex) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb"]), + ids=lambda x: x.name, +) +def test_merge_key_null_values(destination_config: DestinationTestConfiguration) -> None: + """Merge key is present in data, but some rows have null values""" + + @dlt.resource(merge_key="id", write_disposition={"disposition": "merge"}) + def r(): + yield [{"id": 1}, {"id": None}, {"id": 2}] + + p = destination_config.setup_pipeline("abstract", full_refresh=True) + with pytest.raises(PipelineStepFailed) as pip_ex: + p.run(r()) + + ex = pip_ex.value + assert ex.step == "normalize" + + assert isinstance(ex.__context__, NormalizeJobFailed) + assert isinstance(ex.__context__.__context__, CannotCoerceNullException) diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index 12bc69abe0..d49ce2904f 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -58,7 +58,7 @@ def norm_table_counts(counts: Dict[str, int], *child_tables: str) -> Dict[str, i offset = 1000 # keep merge key with unknown column to test replace SQL generator - @dlt.resource(name="items", write_disposition="replace", primary_key="id", merge_key="NA") + @dlt.resource(name="items", write_disposition="replace", primary_key="id") def load_items(): # will produce 3 jobs for the main table with 40 items each # 6 jobs for the sub_items diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index 7122b4a4c6..3239e01bab 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -46,7 +46,7 @@ def inject_tomls( ): @dlt.resource(write_disposition="replace", primary_key="id") def data(): - yield [1, 2, 3] + yield [{"id": 1}, {"id": 2}, {"id": 3}] return data()
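
Illustration (not part of the patch): a minimal sketch of the user-facing behaviour this change introduces, mirroring test_missing_merge_key_column above. It assumes a local duckdb destination is installed; the pipeline name and the demo resource are made up for the example.

    import dlt
    from dlt.pipeline.exceptions import PipelineStepFailed
    from dlt.common.schema.exceptions import UnboundColumnException

    # Hypothetical resource: "not_a_column" is declared as a merge key but never
    # appears in the yielded data, so normalize leaves it as an incomplete,
    # non-nullable column.
    @dlt.resource(merge_key="not_a_column", write_disposition="merge")
    def items():
        yield {"foo": "bar"}

    pipeline = dlt.pipeline(pipeline_name="unbound_column_demo", destination="duckdb")
    try:
        pipeline.run(items())
    except PipelineStepFailed as ex:
        # With this patch the failure surfaces in the normalize step instead of
        # silently falling back to append at load time.
        assert ex.step == "normalize"
        assert isinstance(ex.__context__, UnboundColumnException)

Incomplete columns that are nullable (for example a column declared only through the 'columns' argument) are not fatal: verify_normalized_schema logs the same message as a warning instead of raising.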