From e707580597c37c525bf833eec12950473ad8d591 Mon Sep 17 00:00:00 2001
From: Dave
Date: Wed, 13 Sep 2023 13:28:14 +0200
Subject: [PATCH] update schema management

---
 dlt/common/schema/schema.py                 | 12 ++---
 dlt/common/typing.py                        |  2 +-
 dlt/extract/decorators.py                   |  2 +-
 dlt/extract/source.py                       | 10 ++++-
 dlt/normalize/normalize.py                  | 29 ++++++------
 dlt/pipeline/pipeline.py                    |  6 ++-
 tests/load/test_freeze_and_data_contract.py | 50 ++++++++++-----------
 7 files changed, 61 insertions(+), 50 deletions(-)

diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py
index 3d1603b13a..8a225eb875 100644
--- a/dlt/common/schema/schema.py
+++ b/dlt/common/schema/schema.py
@@ -194,7 +194,7 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[D
 
         return new_row, updated_table_partial
 
-    def resolve_evolution_settings_for_table(self, parent_table: str, table_name: str, schema_contract_settings_override: TSchemaContractSettings) -> TSchemaContractModes:
+    def resolve_contract_settings_for_table(self, parent_table: str, table_name: str) -> TSchemaContractModes:
 
         def resolve_single(settings: TSchemaContractSettings) -> TSchemaContractModes:
             settings = settings or {}
@@ -208,15 +208,14 @@ def resolve_single(settings: TSchemaContractSettings) -> TSchemaContractModes:
         # modes
         table_contract_modes = resolve_single(self.tables.get(table_with_settings, {}).get("schema_contract_settings", {}))
         schema_contract_modes = resolve_single(self._settings.get("schema_contract_settings", {}))
-        overide_modes = resolve_single(schema_contract_settings_override)
 
         # resolve to correct settings dict
-        settings = cast(TSchemaContractModes, {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_contract_modes, **table_contract_modes, **overide_modes})
+        settings = cast(TSchemaContractModes, {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_contract_modes, **table_contract_modes})
 
         return settings
 
-    def check_schema_update(self, contract_modes: TSchemaContractModes, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema, schema_contract_settings_override: TSchemaContractSettings) -> Tuple[DictStrAny, TPartialTableSchema]:
+    def check_schema_update(self, contract_modes: TSchemaContractModes, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema) -> Tuple[DictStrAny, TPartialTableSchema]:
         """Checks if schema update mode allows for the requested changes, filter row or reject update, depending on the mode"""
 
         assert partial_table
@@ -450,8 +449,11 @@ def update_normalizers(self) -> None:
             normalizers["json"] = normalizers["json"] or self._normalizers_config["json"]
         self._configure_normalizers(normalizers)
 
-    def set_schema_contract_settings(self, settings: TSchemaContractSettings) -> None:
+    def set_schema_contract_settings(self, settings: TSchemaContractSettings, update_table_settings: bool = False) -> None:
         self._settings["schema_contract_settings"] = settings
+        if update_table_settings:
+            for table in self.tables.values():
+                table["schema_contract_settings"] = settings
 
     def _infer_column(self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False) -> TColumnSchema:
         column_schema = TColumnSchema(
diff --git a/dlt/common/typing.py b/dlt/common/typing.py
index 2f1cecc093..607ee15d68 100644
--- a/dlt/common/typing.py
+++ b/dlt/common/typing.py
@@ -77,7 +77,7 @@ def is_final_type(t: Type[Any]) -> bool:
 
 def extract_union_types(t: Type[Any], no_none: bool = False) -> List[Any]:
     if no_none:
-        return [arg for arg in get_args(t) if arg is not type(None)]
+        return [arg for arg in get_args(t) if arg is not type(None)]  # noqa: E721
     return list(get_args(t))
 
 def is_literal_type(hint: Type[Any]) -> bool:
diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py
index 1134dd9cc3..653d57d13f 100644
--- a/dlt/extract/decorators.py
+++ b/dlt/extract/decorators.py
@@ -170,13 +170,13 @@ def _wrap(*args: Any, **kwargs: Any) -> DltSource:
 
         # prepare schema
         schema = schema.clone(update_normalizers=True)
-        schema.set_schema_contract_settings(schema_contract_settings)
 
         # convert to source
         s = DltSource.from_data(name, source_section, schema.clone(update_normalizers=True), rv)
 
         # apply hints
         if max_table_nesting is not None:
             s.max_table_nesting = max_table_nesting
+        s.schema_contract_settings = schema_contract_settings
         # enable root propagation
         s.root_key = root_key
         return s
diff --git a/dlt/extract/source.py b/dlt/extract/source.py
index 52a0381dfe..da7119d3f4 100644
--- a/dlt/extract/source.py
+++ b/dlt/extract/source.py
@@ -11,7 +11,7 @@
 from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
 from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer, RelationalNormalizerConfigPropagation
 from dlt.common.schema import Schema
-from dlt.common.schema.typing import TColumnName
+from dlt.common.schema.typing import TColumnName, TSchemaContractSettings
 from dlt.common.typing import AnyFun, StrAny, TDataItem, TDataItems, NoneType
 from dlt.common.configuration.container import Container
 from dlt.common.pipeline import PipelineContext, StateInjectableContext, SupportsPipelineRun, resource_state, source_state, pipeline_state
@@ -605,6 +605,14 @@ def max_table_nesting(self) -> int:
     def max_table_nesting(self, value: int) -> None:
         RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value})
 
+    @property
+    def schema_contract_settings(self) -> TSchemaContractSettings:
+        return self.schema.settings["schema_contract_settings"]
+
+    @schema_contract_settings.setter
+    def schema_contract_settings(self, settings: TSchemaContractSettings) -> None:
+        self.schema.set_schema_contract_settings(settings)
+
     @property
     def exhausted(self) -> bool:
         """check all selected pipes wether one of them has started. if so, the source is exhausted."""
diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py
index ae3e2e658c..73a192e6e1 100644
--- a/dlt/normalize/normalize.py
+++ b/dlt/normalize/normalize.py
@@ -11,7 +11,7 @@
 from dlt.common.runners import TRunMetrics, Runnable
 from dlt.common.runtime import signals
 from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
-from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns, TSchemaContractSettings, TSchemaContractModes
+from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns, TSchemaContractModes
 from dlt.common.schema.utils import merge_schema_updates
 from dlt.common.storages.exceptions import SchemaNotFoundError
 from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage, LoadStorageConfiguration, NormalizeStorageConfiguration
@@ -26,7 +26,7 @@
 # normalize worker wrapping function (map_parallel, map_single) return type
 TMapFuncRV = Tuple[Sequence[TSchemaUpdate], TRowCount]
 # normalize worker wrapping function signature
-TMapFuncType = Callable[[Schema, str, Sequence[str], TSchemaContractSettings], TMapFuncRV]  # input parameters: (schema name, load_id, list of files to process)
+TMapFuncType = Callable[[Schema, str, Sequence[str]], TMapFuncRV]  # input parameters: (schema name, load_id, list of files to process)
 # tuple returned by the worker
 TWorkerRV = Tuple[List[TSchemaUpdate], int, List[str], TRowCount]
 
@@ -34,7 +34,7 @@ class Normalize(Runnable[ProcessPool]):
 
     @with_config(spec=NormalizeConfiguration, sections=(known_sections.NORMALIZE,))
-    def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = config.value, schema_contract_settings: TSchemaContractSettings = None) -> None:
+    def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = config.value) -> None:
         self.config = config
         self.collector = collector
         self.pool: ProcessPool = None
@@ -42,7 +42,6 @@ def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: Schema
         self.load_storage: LoadStorage = None
         self.schema_storage: SchemaStorage = None
         self._row_counts: TRowCount = {}
-        self.schema_contract_settings = schema_contract_settings
 
         # setup storages
         self.create_storages()
@@ -73,8 +72,7 @@ def w_normalize_files(
             destination_caps: DestinationCapabilitiesContext,
             stored_schema: TStoredSchema,
             load_id: str,
-            extracted_items_files: Sequence[str],
-            schema_contract_settings: TSchemaContractSettings
+            extracted_items_files: Sequence[str]
     ) -> TWorkerRV:
         schema_updates: List[TSchemaUpdate] = []
         total_items = 0
@@ -99,7 +97,7 @@ def w_normalize_files(
                     items_count = 0
                     for line_no, line in enumerate(f):
                         items: List[TDataItem] = json.loads(line)
-                        partial_update, items_count, r_counts = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items, schema_contract_settings)
+                        partial_update, items_count, r_counts = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items)
                         schema_updates.append(partial_update)
                         total_items += items_count
                         merge_row_count(row_counts, r_counts)
@@ -128,7 +126,7 @@ def w_normalize_files(
         return schema_updates, total_items, load_storage.closed_files(), row_counts
 
     @staticmethod
-    def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem], schema_contract_settings: TSchemaContractSettings) -> Tuple[TSchemaUpdate, int, TRowCount]:
+    def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem]) -> Tuple[TSchemaUpdate, int, TRowCount]:
         column_schemas: Dict[str, TTableSchemaColumns] = {}  # quick access to column schema for writers below
         schema_update: TSchemaUpdate = {}
         schema_name = schema.name
@@ -139,7 +137,7 @@ def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str,
         for item in items:
             for (table_name, parent_table), row in schema.normalize_data_item(item, load_id, root_table_name):
                 if not schema_contract_modes:
-                    schema_contract_modes = schema.resolve_evolution_settings_for_table(parent_table, table_name, schema_contract_settings)
+                    schema_contract_modes = schema.resolve_contract_settings_for_table(parent_table, table_name)
 
                 # filter row, may eliminate some or all fields
                 row = schema.filter_row(table_name, row)
@@ -153,7 +151,7 @@ def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str,
                     row, partial_table = schema.coerce_row(table_name, parent_table, row)
                     # if we detect a migration, the check update
                     if partial_table:
-                        row, partial_table = schema.check_schema_update(schema_contract_modes, table_name, row, partial_table, schema_contract_settings)
+                        row, partial_table = schema.check_schema_update(schema_contract_modes, table_name, row, partial_table)
 
                     if not row:
                         continue
@@ -204,12 +202,12 @@ def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[st
                 l_idx = idx + 1
         return chunk_files
 
-    def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str], schema_contract_settings: TSchemaContractSettings) -> TMapFuncRV:
+    def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV:
        workers = self.pool._processes  # type: ignore
        chunk_files = self.group_worker_files(files, workers)
        schema_dict: TStoredSchema = schema.to_dict()
        config_tuple = (self.normalize_storage.config, self.load_storage.config, self.config.destination_capabilities, schema_dict)
-       param_chunk = [[*config_tuple, load_id, files, schema_contract_settings] for files in chunk_files]
+       param_chunk = [[*config_tuple, load_id, files] for files in chunk_files]
        tasks: List[Tuple[AsyncResult[TWorkerRV], List[Any]]] = []
        row_counts: TRowCount = {}
@@ -259,15 +257,14 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str], schem
 
         return schema_updates, row_counts
 
-    def map_single(self, schema: Schema, load_id: str, files: Sequence[str], schema_contract_settings: TSchemaContractSettings) -> TMapFuncRV:
+    def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV:
         result = Normalize.w_normalize_files(
             self.normalize_storage.config,
             self.load_storage.config,
             self.config.destination_capabilities,
             schema.to_dict(),
             load_id,
-            files,
-            schema_contract_settings
+            files
         )
         self.update_schema(schema, result[0])
         self.collector.update("Files", len(result[2]))
@@ -278,7 +275,7 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files
         schema = Normalize.load_or_create_schema(self.schema_storage, schema_name)
 
         # process files in parallel or in single thread, depending on map_f
-        schema_updates, row_counts = map_f(schema, load_id, files, self.schema_contract_settings)
+        schema_updates, row_counts = map_f(schema, load_id, files)
 
         # logger.metrics("Normalize metrics", extra=get_logging_extras([self.schema_version_gauge.labels(schema_name)]))
         if len(schema_updates) > 0:
             logger.info(f"Saving schema {schema_name} with version {schema.version}, writing manifest files")
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index 567f286ba1..1789a02e0f 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -304,6 +304,10 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No
         if not self.default_schema_name:
             return None
 
+        # update global schema contract settings, could be moved into def normalize()
+        if schema_contract_settings is not None:
+            self.default_schema.set_schema_contract_settings(schema_contract_settings, True)
+
         # make sure destination capabilities are available
         self._get_destination_capabilities()
         # create default normalize config
@@ -316,7 +320,7 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No
         # run with destination context
         with self._maybe_destination_capabilities(loader_file_format=loader_file_format):
             # shares schema storage with the pipeline so we do not need to install
-            normalize = Normalize(collector=self.collector, config=normalize_config, schema_storage=self._schema_storage, schema_contract_settings=schema_contract_settings)
+            normalize = Normalize(collector=self.collector, config=normalize_config, schema_storage=self._schema_storage)
             try:
                 with signals.delayed_signals():
                     runner.run_pool(normalize.config, normalize)
diff --git a/tests/load/test_freeze_and_data_contract.py b/tests/load/test_freeze_and_data_contract.py
index 5ef936bd66..2460679737 100644
--- a/tests/load/test_freeze_and_data_contract.py
+++ b/tests/load/test_freeze_and_data_contract.py
@@ -113,25 +113,25 @@ def source() -> DltResource:
     pipeline.run(source(), schema_contract_settings=settings.get("override"))
 
     # check updated schema
-    assert pipeline.default_schema._settings["schema_contract_settings"] == settings.get("source")
+    assert pipeline.default_schema._settings["schema_contract_settings"] == (settings.get("override") or settings.get("source"))
     # check items table settings
-    assert pipeline.default_schema.tables["items"]["schema_contract_settings"] == settings.get("resource")
+    assert pipeline.default_schema.tables["items"]["schema_contract_settings"] == (settings.get("override") or settings.get("resource"))
 
 
 def get_pipeline():
     import duckdb
     return dlt.pipeline(pipeline_name=uniq_id(), destination='duckdb', credentials=duckdb.connect(':memory:'), full_refresh=True)
 
 
-@pytest.mark.parametrize("evolution_setting", schema_contract_settings)
+@pytest.mark.parametrize("contract_setting", schema_contract_settings)
 @pytest.mark.parametrize("setting_location", LOCATIONS)
-def test_freeze_new_tables(evolution_setting: str, setting_location: str) -> None:
+def test_freeze_new_tables(contract_setting: str, setting_location: str) -> None:
     pipeline = get_pipeline()
     full_settings = {
         setting_location: {
-        "table": evolution_setting
+        "table": contract_setting
     }}
     run_resource(pipeline, items, {})
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()])
@@ -149,7 +149,7 @@ def test_freeze_new_tables(evolution_setting: str, setting_location: str) -> Non
     assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]
 
     # test adding new subtable
-    if evolution_setting == "freeze":
+    if contract_setting == "freeze":
         with pytest.raises(PipelineStepFailed) as py_ex:
             run_resource(pipeline, items_with_subtable, full_settings)
         assert isinstance(py_ex.value.__context__, SchemaFrozenException)
@@ -157,27 +157,27 @@ def test_freeze_new_tables(evolution_setting: str, setting_location: str) -> Non
     else:
         run_resource(pipeline, items_with_subtable, full_settings)
 
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()])
-    assert table_counts["items"] == 30 if evolution_setting in ["freeze"] else 40
-    assert table_counts.get(SUBITEMS_TABLE, 0) == (10 if evolution_setting in ["evolve"] else 0)
+    assert table_counts["items"] == 30 if contract_setting in ["freeze"] else 40
+    assert table_counts.get(SUBITEMS_TABLE, 0) == (10 if contract_setting in ["evolve"] else 0)
 
     # test adding new table
-    if evolution_setting == "freeze":
+    if contract_setting == "freeze":
         with pytest.raises(PipelineStepFailed) as py_ex:
             run_resource(pipeline, new_items, full_settings)
         assert isinstance(py_ex.value.__context__, SchemaFrozenException)
     else:
         run_resource(pipeline, new_items, full_settings)
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()])
-    assert table_counts.get("new_items", 0) == (10 if evolution_setting in ["evolve"] else 0)
+    assert table_counts.get("new_items", 0) == (10 if contract_setting in ["evolve"] else 0)
 
 
-@pytest.mark.parametrize("evolution_setting", schema_contract_settings)
+@pytest.mark.parametrize("contract_setting", schema_contract_settings)
 @pytest.mark.parametrize("setting_location", LOCATIONS)
-def test_freeze_new_columns(evolution_setting: str, setting_location: str) -> None:
+def test_freeze_new_columns(contract_setting: str, setting_location: str) -> None:
     full_settings = {
         setting_location: {
-        "column": evolution_setting
+        "column": contract_setting
     }}
     pipeline = get_pipeline()
@@ -199,43 +199,43 @@ def test_freeze_new_columns(evolution_setting: str, setting_location: str) -> No
     assert table_counts[NEW_ITEMS_TABLE] == 10
 
     # test adding new column
-    if evolution_setting == "freeze":
+    if contract_setting == "freeze":
         with pytest.raises(PipelineStepFailed) as py_ex:
             run_resource(pipeline, items_with_new_column, full_settings)
         assert isinstance(py_ex.value.__context__, SchemaFrozenException)
     else:
         run_resource(pipeline, items_with_new_column, full_settings)
 
-    if evolution_setting == "evolve":
+    if contract_setting == "evolve":
         assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]
     else:
         assert NEW_COLUMN_NAME not in pipeline.default_schema.tables["items"]["columns"]
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()])
-    assert table_counts["items"] == (30 if evolution_setting in ["evolve", "discard-value"] else 20)
+    assert table_counts["items"] == (30 if contract_setting in ["evolve", "discard-value"] else 20)
 
     # test adding variant column
-    if evolution_setting == "freeze":
+    if contract_setting == "freeze":
         with pytest.raises(PipelineStepFailed) as py_ex:
             run_resource(pipeline, items_with_variant, full_settings)
         assert isinstance(py_ex.value.__context__, SchemaFrozenException)
     else:
         run_resource(pipeline, items_with_variant, full_settings)
 
-    if evolution_setting == "evolve":
+    if contract_setting == "evolve":
         assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]
     else:
         assert VARIANT_COLUMN_NAME not in pipeline.default_schema.tables["items"]["columns"]
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()])
-    assert table_counts["items"] == (40 if evolution_setting in ["evolve", "discard-value"] else 20)
+    assert table_counts["items"] == (40 if contract_setting in ["evolve", "discard-value"] else 20)
 
 
-@pytest.mark.parametrize("evolution_setting", schema_contract_settings)
+@pytest.mark.parametrize("contract_setting", schema_contract_settings)
 @pytest.mark.parametrize("setting_location", LOCATIONS)
-def test_freeze_variants(evolution_setting: str, setting_location: str) -> None:
+def test_freeze_variants(contract_setting: str, setting_location: str) -> None:
     full_settings = {
         setting_location: {
-        "data_type": evolution_setting
+        "data_type": contract_setting
     }}
     pipeline = get_pipeline()
     run_resource(pipeline, items, {})
@@ -262,19 +262,19 @@ def test_freeze_variants(evolution_setting: str, setting_location: str) -> None:
     assert NEW_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]
 
     # test adding variant column
-    if evolution_setting == "freeze":
+    if contract_setting == "freeze":
         with pytest.raises(PipelineStepFailed) as py_ex:
             run_resource(pipeline, items_with_variant, full_settings)
         assert isinstance(py_ex.value.__context__, SchemaFrozenException)
     else:
         run_resource(pipeline, items_with_variant, full_settings)
 
-    if evolution_setting == "evolve":
+    if contract_setting == "evolve":
         assert VARIANT_COLUMN_NAME in pipeline.default_schema.tables["items"]["columns"]
     else:
         assert VARIANT_COLUMN_NAME not in pipeline.default_schema.tables["items"]["columns"]
     table_counts = load_table_counts(pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()])
-    assert table_counts["items"] == (40 if evolution_setting in ["evolve", "discard-value"] else 30)
+    assert table_counts["items"] == (40 if contract_setting in ["evolve", "discard-value"] else 30)
 
 
 def test_settings_precedence() -> None:
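
Reviewer note (not part of the patch): with the `schema_contract_settings_override` layer removed from resolution, contract modes now come from exactly three places, merged so that later entries win: the defaults, the schema-level settings, and the table-level settings. A run-level override still takes effect because `Pipeline.normalize()` now writes it into the default schema (and, with `update_table_settings=True`, into every table) before the normalize step runs. Below is a minimal sketch of that merge; it assumes `DEFAULT_SCHEMA_CONTRACT_MODE` maps every entity to `"evolve"` (the mode names match the tests, the data is made up):

```python
# Sketch of the precedence established by resolve_contract_settings_for_table().
# Assumption: DEFAULT_SCHEMA_CONTRACT_MODE defaults everything to "evolve".
from typing import Dict

DEFAULT_SCHEMA_CONTRACT_MODE: Dict[str, str] = {
    "table": "evolve",
    "column": "evolve",
    "data_type": "evolve",
}

def resolve_contract_modes(
    schema_modes: Dict[str, str], table_modes: Dict[str, str]
) -> Dict[str, str]:
    # Later dicts win, mirroring the {**defaults, **schema, **table} merge in schema.py.
    return {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_modes, **table_modes}

# Schema-level freeze, but this table explicitly re-enables column evolution:
assert resolve_contract_modes({"column": "freeze"}, {"column": "evolve"})["column"] == "evolve"
# No table-level setting: the schema-level "freeze" applies.
assert resolve_contract_modes({"column": "freeze"}, {})["column"] == "freeze"
# Nothing configured anywhere: the defaults apply.
assert resolve_contract_modes({}, {})["table"] == "evolve"
```

This is also why the updated test asserts `settings.get("override") or settings.get("source")`: an override passed to `pipeline.run(...)` no longer competes inside the resolver, it simply replaces the stored schema and table settings before normalization.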