From 302d90959a2ab953c4d88ba83b3bffdd1ed5ed06 Mon Sep 17 00:00:00 2001
From: Dave
Date: Tue, 17 Oct 2023 22:08:43 +0200
Subject: [PATCH] post merge code updates

---
 dlt/common/schema/schema.py                   | 21 +++--
 dlt/extract/decorators.py                     |  7 +-
 dlt/extract/extract.py                        | 59 ++++++++++----
 dlt/normalize/items_normalizers.py            | 78 +++++++++++--------
 dlt/pipeline/pipeline.py                      | 17 ++--
 ...e_functions.py => test_schema_contract.py} |  6 +-
 ...a_contract.py => test_schema_contracts.py} |  2 +-
 7 files changed, 122 insertions(+), 68 deletions(-)
 rename tests/common/schema/{test_contract_mode_functions.py => test_schema_contract.py} (99%)
 rename tests/load/{test_freeze_and_data_contract.py => test_schema_contracts.py} (99%)

diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py
index 8ff475d732..262f12bb2c 100644
--- a/dlt/common/schema/schema.py
+++ b/dlt/common/schema/schema.py
@@ -195,7 +195,8 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[D
 
         return new_row, updated_table_partial
 
-    def apply_schema_contract(self, contract_modes: TSchemaContractDict, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema) -> Tuple[DictStrAny, TPartialTableSchema]:
+    @staticmethod
+    def apply_schema_contract(schema: "Schema", contract_modes: TSchemaContractDict, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema) -> Tuple[DictStrAny, TPartialTableSchema]:
         """
         Checks if the contract mode allows the requested changes to the data and the schema. It will allow all changes to pass, filter out the row, filter out columns for both the data and the schema_update, or reject the update completely, depending on the mode. Example settings could be:
@@ -219,31 +220,31 @@ def apply_schema_contract(self, contract_modes: TSchemaContractDict, table_name:
 
         if contract_modes == DEFAULT_SCHEMA_CONTRACT_MODE:
             return row, partial_table
 
-        is_new_table = (table_name not in self.tables) or (not self.tables[table_name]["columns"])
+        is_new_table = not schema or (table_name not in schema.tables) or (not schema.tables[table_name]["columns"])
 
         # check case where we have a new table
         if is_new_table:
             if contract_modes["tables"] in ["discard_row", "discard_value"]:
                 return None, None
             if contract_modes["tables"] == "freeze":
-                raise SchemaFrozenException(self.name, table_name, f"Trying to add table {table_name} but new tables are frozen.")
+                raise SchemaFrozenException(schema.name if schema else "", table_name, f"Trying to add table {table_name} but new tables are frozen.")
 
         # in case we only check table creation in pipeline
         if not row:
             return row, partial_table
 
         # if evolve once is set, allow all column changes
-        evolve_once = (table_name in self.tables) and self.tables[table_name].get("x-normalizer", {}).get("evolve_once", False)  # type: ignore[attr-defined]
+        evolve_once = (table_name in schema.tables) and schema.tables[table_name].get("x-normalizer", {}).get("evolve_once", False)  # type: ignore[attr-defined]
         if evolve_once:
             return row, partial_table
 
         # check columns
         for item in list(row.keys()):
             # dlt cols may always be added
-            if item.startswith(self._dlt_tables_prefix):
+            if item.startswith(schema._dlt_tables_prefix):
                 continue
             # if this is a new column for an existing table...
-            if not is_new_table and (item not in self.tables[table_name]["columns"] or not utils.is_complete_column(self.tables[table_name]["columns"][item])):
+            if not is_new_table and (item not in schema.tables[table_name]["columns"] or not utils.is_complete_column(schema.tables[table_name]["columns"][item])):
                 is_variant = (item in partial_table["columns"]) and partial_table["columns"][item].get("variant")
                 if contract_modes["columns"] == "discard_value" or (is_variant and contract_modes["data_type"] == "discard_value"):
                     row.pop(item)
@@ -251,9 +252,9 @@ def apply_schema_contract(self, contract_modes: TSchemaContractDict, table_name:
                 elif contract_modes["columns"] == "discard_row" or (is_variant and contract_modes["data_type"] == "discard_row"):
                     return None, None
                 elif is_variant and contract_modes["data_type"] == "freeze":
-                    raise SchemaFrozenException(self.name, table_name, f"Trying to create new variant column {item} to table {table_name} data_types are frozen.")
+                    raise SchemaFrozenException(schema.name, table_name, f"Trying to create new variant column {item} in table {table_name} but data_types are frozen.")
                 elif contract_modes["columns"] == "freeze":
-                    raise SchemaFrozenException(self.name, table_name, f"Trying to add column {item} to table {table_name} but columns are frozen.")
+                    raise SchemaFrozenException(schema.name, table_name, f"Trying to add column {item} to table {table_name} but columns are frozen.")
 
         return row, partial_table
 
@@ -463,6 +464,8 @@ def update_normalizers(self) -> None:
         self._configure_normalizers(normalizers)
 
     def set_schema_contract(self, settings: TSchemaContract, update_table_settings: bool = False) -> None:
+        if not settings:
+            return
         self._settings["schema_contract"] = settings
         if update_table_settings:
             for table in self.tables.values():
@@ -666,6 +669,8 @@ def __repr__(self) -> str:
 
 def resolve_contract_settings_for_table(parent_table: str, table_name: str, current_schema: Schema, incoming_schema: Schema = None, incoming_table: TTableSchema = None) -> TSchemaContractDict:
     """Resolve the exact applicable schema contract settings for the table during the normalization stage."""
+    current_schema = current_schema or incoming_schema
+
     def resolve_single(settings: TSchemaContract) -> TSchemaContractDict:
         settings = settings or {}
         if isinstance(settings, str):
diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py
index fe1fd16ab0..2d53258a65 100644
--- a/dlt/extract/decorators.py
+++ b/dlt/extract/decorators.py
@@ -69,7 +69,7 @@ def source(
     root_key: bool = False,
     schema: Schema = None,
     schema_contract: TSchemaContract = None,
-    spec: Type[BaseConfiguration] = None
+    spec: Type[BaseConfiguration] = None,
     _impl_cls: Type[TDltSourceImpl] = DltSource  # type: ignore[assignment]
 ) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, TDltSourceImpl]]:
     ...
@@ -83,9 +83,8 @@ def source(
     root_key: bool = False,
     schema: Schema = None,
     schema_contract: TSchemaContract = None,
-    spec: Type[BaseConfiguration] = None
+    spec: Type[BaseConfiguration] = None,
     _impl_cls: Type[TDltSourceImpl] = DltSource  # type: ignore[assignment]
->>>>>>> devel
 ) -> Any:
     """A decorator that transforms a function returning one or more `dlt resources` into a `dlt source` in order to load it with `dlt`.
@@ -358,7 +357,7 @@ def make_resource(_name: str, _section: str, _data: Any, incremental: Incrementa
         columns=columns,
         primary_key=primary_key,
         merge_key=merge_key,
-        schema_contract=schema_contract
+        schema_contract=schema_contract,
         table_format=table_format
     )
     return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, data_from), incremental=incremental)
diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py
index c4a9bc8c8c..dfd2a43e6b 100644
--- a/dlt/extract/extract.py
+++ b/dlt/extract/extract.py
@@ -1,6 +1,6 @@
 import contextlib
 import os
 from typing import ClassVar, List, Set, Dict, Type, Any, Sequence, Optional
 from collections import defaultdict
 
 from dlt.common.configuration.container import Container
@@ -116,7 +116,8 @@ def __init__(
         schema: Schema,
         resources_with_items: Set[str],
         dynamic_tables: TSchemaUpdate,
-        collector: Collector = NULL_COLLECTOR
+        collector: Collector = NULL_COLLECTOR,
+        pipeline_schema: Schema = None
     ) -> None:
         self._storage = storage
         self.schema = schema
@@ -124,6 +125,8 @@ def __init__(
         self.collector = collector
         self.resources_with_items = resources_with_items
         self.extract_id = extract_id
+        self.disallowed_tables: Set[str] = set()
+        self.pipeline_schema = pipeline_schema
 
     @property
     def storage(self) -> ExtractorItemStorage:
@@ -148,7 +151,6 @@ def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> No
         if isinstance(meta, TableNameMeta):
             table_name = meta.table_name
             self._write_static_table(resource, table_name, items)
-            self._write_item(table_name, resource.name, items)
         else:
             if resource._table_name_hint_fun:
                 if isinstance(items, list):
@@ -160,7 +162,6 @@ def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> No
                 # write item belonging to table with static name
                 table_name = resource.table_name  # type: ignore[assignment]
                 self._write_static_table(resource, table_name, items)
-                self._write_item(table_name, resource.name, items)
 
     def write_empty_file(self, table_name: str) -> None:
         table_name = self.schema.naming.normalize_table_identifier(table_name)
@@ -179,7 +180,8 @@ def _write_dynamic_table(self, resource: DltResource, item: TDataItem) -> None:
         table_name = resource._table_name_hint_fun(item)
         existing_table = self.dynamic_tables.get(table_name)
         if existing_table is None:
-            self.dynamic_tables[table_name] = [resource.compute_table_schema(item)]
+            if not self._add_dynamic_table(resource, data_item=item):
+                return
         else:
             # quick check if deep table merge is required
             if resource._table_has_other_dynamic_hints:
@@ -195,9 +197,40 @@ def _write_dynamic_table(self, resource: DltResource, item: TDataItem) -> None:
 
     def _write_static_table(self, resource: DltResource, table_name: str, items: TDataItems) -> None:
         existing_table = self.dynamic_tables.get(table_name)
         if existing_table is None:
-            static_table = resource.compute_table_schema()
-            static_table["name"] = table_name
-            self.dynamic_tables[table_name] = [static_table]
+            if not self._add_dynamic_table(resource, table_name=table_name):
+                return
+        self._write_item(table_name, resource.name, items)
+
+    def _add_dynamic_table(self, resource: DltResource, data_item: TDataItem = None, table_name: Optional[str] = None) -> bool:
+        """
+        Computes the new table schema and runs the contract checks
+        """
+        # TODO: We have to normalize table identifiers here
+        table = resource.compute_table_schema(data_item)
+        if table_name:
+            table["name"] = table_name
+
+        # fast exit if we already evaluated this table
+        if table["name"] in self.disallowed_tables:
+            return False
+
+        # this is a new table so allow evolve once
+        # TODO: is this the correct check for a new table, should a table with only incomplete columns be new too?
+        is_new_table = (self.pipeline_schema is None) or (table["name"] not in self.pipeline_schema.tables) or (not self.pipeline_schema.tables[table["name"]]["columns"])
+        if is_new_table:
+            table["x-normalizer"] = {"evolve_once": True}  # type: ignore[typeddict-unknown-key]
+
+        # apply the schema contract against the pipeline schema
+        # here we only check that the table may be created
+        schema_contract = resolve_contract_settings_for_table(None, table["name"], self.pipeline_schema, self.schema, table)
+        _, checked_table = Schema.apply_schema_contract(self.pipeline_schema, schema_contract, table["name"], None, table)
+
+        if not checked_table:
+            self.disallowed_tables.add(table["name"])
+            return False
+
+        self.dynamic_tables[checked_table["name"]] = [checked_table]
+        return True
 
 
 class JsonLExtractor(Extractor):
@@ -238,6 +271,7 @@ def extract(
     storage: ExtractorStorage,
     collector: Collector = NULL_COLLECTOR,
     *,
+    pipeline_schema: Schema = None,
     max_parallel_items: int = None,
     workers: int = None,
     futures_poll_interval: float = None
@@ -247,10 +281,10 @@ def extract(
     resources_with_items: Set[str] = set()
     extractors: Dict[TLoaderFileFormat, Extractor] = {
         "puae-jsonl": JsonLExtractor(
-            extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector
+            extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector, pipeline_schema=pipeline_schema
        ),
         "arrow": ArrowExtractor(
-            extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector
+            extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector, pipeline_schema=pipeline_schema
        )
     }
     last_item_format: Optional[TLoaderFileFormat] = None
@@ -318,11 +352,10 @@ def extract_with_schema(
     with contextlib.suppress(DataItemRequiredForDynamicTableHints):
         if resource.write_disposition == "replace":
             reset_resource_state(resource.name)
-
-    extractor = extract(extract_id, source, storage, collector, pipeline_schema, max_parallel_items=max_parallel_items, workers=workers)
+    extractor = extract(extract_id, source, storage, collector, max_parallel_items=max_parallel_items, workers=workers, pipeline_schema=pipeline_schema)
     # iterate over all items in the pipeline and update the schema if dynamic table hints were present
     for _, partials in extractor.items():
         for partial in partials:
-            schema.update_table(schema.normalize_table_identifiers(partial))
+            source.schema.update_table(source.schema.normalize_table_identifiers(partial))
 
     return extract_id
diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py
index 2f613f4b40..526345bc89 100644
--- a/dlt/normalize/items_normalizers.py
+++ b/dlt/normalize/items_normalizers.py
@@ -5,12 +5,12 @@
 from dlt.common import json, logger
 from dlt.common.json import custom_pua_decode
 from dlt.common.runtime import signals
-from dlt.common.schema.typing import TTableSchemaColumns
-from dlt.common.storages import NormalizeStorage, LoadStorage, NormalizeStorageConfiguration, FileStorage
+from dlt.common.schema.typing import TTableSchemaColumns, TSchemaContractDict
+from dlt.common.storages import NormalizeStorage, LoadStorage, FileStorage
 from dlt.common.typing import TDataItem
 from dlt.common.schema import TSchemaUpdate, Schema
 from dlt.common.utils import TRowCount, merge_row_count, increase_row_count
-
+from dlt.common.schema.schema import resolve_contract_settings_for_table
 
 class ItemsNormalizer(Protocol):
     def __call__(
@@ -41,45 +41,57 @@ def _normalize_chunk(
         schema_name = schema.name
         items_count = 0
         row_counts: TRowCount = {}
+        schema_contract: TSchemaContractDict = None
 
         for item in items:
             for (table_name, parent_table), row in schema.normalize_data_item(
                 item, load_id, root_table_name
             ):
+                if not schema_contract:
+                    schema_contract = resolve_contract_settings_for_table(parent_table, table_name, schema)
                 # filter row, may eliminate some or all fields
                 row = schema.filter_row(table_name, row)
                 # do not process empty rows
-                if row:
-                    # decode pua types
-                    for k, v in row.items():
-                        row[k] = custom_pua_decode(v)  # type: ignore
-                    # coerce row of values into schema table, generating partial table with new columns if any
-                    row, partial_table = schema.coerce_row(
-                        table_name, parent_table, row
-                    )
-                    # theres a new table or new columns in existing table
-                    if partial_table:
-                        # update schema and save the change
-                        schema.update_table(partial_table)
-                        table_updates = schema_update.setdefault(table_name, [])
-                        table_updates.append(partial_table)
-                        # update our columns
-                        column_schemas[table_name] = schema.get_table_columns(
-                            table_name
-                        )
-                    # get current columns schema
-                    columns = column_schemas.get(table_name)
-                    if not columns:
-                        columns = schema.get_table_columns(table_name)
-                        column_schemas[table_name] = columns
-                    # store row
-                    # TODO: it is possible to write to single file from many processes using this: https://gitlab.com/warsaw/flufl.lock
-                    load_storage.write_data_item(
-                        load_id, schema_name, table_name, row, columns
-                    )
-                    # count total items
-                    items_count += 1
-                    increase_row_count(row_counts, table_name, 1)
+                if not row:
+                    continue
+
+                # decode pua types
+                for k, v in row.items():
+                    row[k] = custom_pua_decode(v)  # type: ignore
+                # coerce row of values into schema table, generating partial table with new columns if any
+                row, partial_table = schema.coerce_row(
+                    table_name, parent_table, row
+                )
+
+                # if we detect a schema migration, check it against the contract
+                if partial_table:
+                    row, partial_table = Schema.apply_schema_contract(schema, schema_contract, table_name, row, partial_table)
+                    if not row:
+                        continue
+
+                # there is a new table or new columns in an existing table
+                if partial_table:
+                    # update schema and save the change
+                    schema.update_table(partial_table)
+                    table_updates = schema_update.setdefault(table_name, [])
+                    table_updates.append(partial_table)
+                    # update our columns
+                    column_schemas[table_name] = schema.get_table_columns(
+                        table_name
+                    )
+                # get current columns schema
+                columns = column_schemas.get(table_name)
+                if not columns:
+                    columns = schema.get_table_columns(table_name)
+                    column_schemas[table_name] = columns
+                # store row
+                # TODO: it is possible to write to a single file from many processes using this: https://gitlab.com/warsaw/flufl.lock
+                load_storage.write_data_item(
+                    load_id, schema_name, table_name, row, columns
+                )
+                # count total items
+                items_count += 1
+                increase_row_count(row_counts, table_name, 1)
 
             signals.raise_if_signalled()
 
         return schema_update, items_count, row_counts
diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py
index 1005af0f1e..c2c9ea9fc8 100644
--- a/dlt/pipeline/pipeline.py
+++ b/dlt/pipeline/pipeline.py
@@ -863,20 +863,25 @@ def _extract_source(self, storage: ExtractorStorage, source: DltSource, max_para
         source_schema = source.schema
         source_schema.update_normalizers()
 
+        # discover the existing pipeline schema
+        pipeline_schema = self._schema_storage[source_schema.name] if source_schema.name in self._schema_storage else None
+
         # extract into pipeline schema
-        extract_id = extract_with_schema(storage, source, source_schema, self.collector, max_parallel_items, workers)
+        source.schema.set_schema_contract(global_contract, True)
+        extract_id = extract_with_schema(storage, source, pipeline_schema, self.collector, max_parallel_items, workers)
 
         # save import with fully discovered schema
         self._schema_storage.save_import_schema_if_not_exists(source_schema)
 
-        # if source schema does not exist in the pipeline
-        if source_schema.name not in self._schema_storage:
-            # create new schema
+        # save schema if not present in store
+        if not pipeline_schema:
             self._schema_storage.save_schema(source_schema)
+            pipeline_schema = source_schema
 
-        # update pipeline schema (do contract checks here)
-        pipeline_schema = self._schema_storage[source_schema.name]
+        # update pipeline schema
         pipeline_schema.update_schema(source_schema)
+        pipeline_schema.set_schema_contract(global_contract, True)
 
         # set as default if this is first schema in pipeline
         if not self.default_schema_name:
diff --git a/tests/common/schema/test_contract_mode_functions.py b/tests/common/schema/test_schema_contract.py
similarity index 99%
rename from tests/common/schema/test_contract_mode_functions.py
rename to tests/common/schema/test_schema_contract.py
index 95f2d0e267..d635983367 100644
--- a/tests/common/schema/test_contract_mode_functions.py
+++ b/tests/common/schema/test_schema_contract.py
@@ -245,7 +245,7 @@ def test_check_adding_new_columns(base_settings) -> None:
         "column_2": 123
     }
     data_with_new_row = {
-        **data,  # type: ignore
+        **data,
         "new_column": "some string"
     }
     table_update: TTableSchema = {
@@ -276,7 +276,7 @@ def test_check_adding_new_columns(base_settings) -> None:
         "column_2": 123,
     }
     data_with_new_row = {
-        **data,  # type: ignore
+        **data,
         "incomplete_column_1": "some other string",
     }
     table_update = {
@@ -312,7 +312,7 @@ def test_check_adding_new_variant() -> None:
         "column_2": 123
     }
     data_with_new_row = {
-        **data,  # type: ignore
+        **data,
         "column_2_variant": 345345
     }
     table_update: TTableSchema = {
diff --git a/tests/load/test_freeze_and_data_contract.py b/tests/load/test_schema_contracts.py
similarity index 99%
rename from tests/load/test_freeze_and_data_contract.py
rename to tests/load/test_schema_contracts.py
index 7b2beee84c..da2e1b2568 100644
--- a/tests/load/test_freeze_and_data_contract.py
+++ b/tests/load/test_schema_contracts.py
@@ -123,7 +123,7 @@ def source() -> DltResource:
     pipeline.run(source(), schema_contract=settings.get("override"))
 
     # check updated schema
-    assert pipeline.default_schema._settings.get("schema_contract", {}) == (settings.get("override") or settings.get("source"))
+    assert pipeline.default_schema._settings.get("schema_contract", None) == (settings.get("override") or settings.get("source"))
     # check items table settings
     assert pipeline.default_schema.tables["items"].get("schema_contract", {}) == (settings.get("override") or settings.get("resource") or {})
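
---
Reviewer note (illustration, not part of the patch): the end-to-end behavior these changes implement can be sketched against the public dlt API. Below is a minimal sketch of a frozen column contract; the pipeline name, the sample rows, and the duckdb destination are made up for the example, and the "evolve" mode for tables is assumed (only "freeze", "discard_row", and "discard_value" appear in the patch itself).

    import dlt

    contract = {"tables": "evolve", "columns": "freeze", "data_type": "freeze"}

    @dlt.resource(name="items", schema_contract=contract)
    def items():
        # first run: "items" is a new table, so the evolve_once marker lets all columns through
        yield {"id": 1, "name": "first"}

    pipeline = dlt.pipeline(pipeline_name="contract_demo", destination="duckdb")
    pipeline.run(items())

    @dlt.resource(name="items", schema_contract=contract)
    def items_new_column():
        # second run: "extra" is a new column on an existing table and violates "columns": "freeze"
        yield {"id": 2, "name": "second", "extra": "surprise"}

    try:
        pipeline.run(items_new_column())
    except Exception as ex:
        # normalize surfaces the SchemaFrozenException raised in apply_schema_contract
        print(f"run rejected: {ex}")

On the first run apply_schema_contract passes everything because the freshly created table carries the "evolve_once" flag set in _add_dynamic_table; on the second run the pipeline schema already contains "items", so the new column hits the frozen contract and the load is rejected.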