post merge code updates
sh-rp committed Oct 17, 2023
1 parent 96785c4 commit 302d909
Showing 7 changed files with 122 additions and 68 deletions.
21 changes: 13 additions & 8 deletions dlt/common/schema/schema.py
@@ -195,7 +195,8 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[D

return new_row, updated_table_partial

def apply_schema_contract(self, contract_modes: TSchemaContractDict, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema) -> Tuple[DictStrAny, TPartialTableSchema]:
@staticmethod
def apply_schema_contract(schema: "Schema", contract_modes: TSchemaContractDict, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema) -> Tuple[DictStrAny, TPartialTableSchema]:
"""
Checks if contract mode allows for the requested changes to the data and the schema. It will allow all changes to pass, filter out the row, filter out
columns for both the data and the schema_update, or reject the update completely, depending on the mode. Example settings could be:
Expand All @@ -219,41 +220,41 @@ def apply_schema_contract(self, contract_modes: TSchemaContractDict, table_name:
if contract_modes == DEFAULT_SCHEMA_CONTRACT_MODE:
return row, partial_table

is_new_table = (table_name not in self.tables) or (not self.tables[table_name]["columns"])
is_new_table = not schema or (table_name not in schema.tables) or (not schema.tables[table_name]["columns"])

# check case where we have a new table
if is_new_table:
if contract_modes["tables"] in ["discard_row", "discard_value"]:
return None, None
if contract_modes["tables"] == "freeze":
raise SchemaFrozenException(self.name, table_name, f"Trying to add table {table_name} but new tables are frozen.")
raise SchemaFrozenException(schema.name if schema else "", table_name, f"Trying to add table {table_name} but new tables are frozen.")

# in case we only check table creation in pipeline
if not row:
return row, partial_table

# if evolve once is set, allow all column changes
evolve_once = (table_name in self.tables) and self.tables[table_name].get("x-normalizer", {}).get("evolve_once", False) # type: ignore[attr-defined]
evolve_once = (table_name in schema.tables) and schema.tables[table_name].get("x-normalizer", {}).get("evolve_once", False) # type: ignore[attr-defined]
if evolve_once:
return row, partial_table

# check columns
for item in list(row.keys()):
# dlt cols may always be added
if item.startswith(self._dlt_tables_prefix):
if item.startswith(schema._dlt_tables_prefix):
continue
# if this is a new column for an existing table...
if not is_new_table and (item not in self.tables[table_name]["columns"] or not utils.is_complete_column(self.tables[table_name]["columns"][item])):
if not is_new_table and (item not in schema.tables[table_name]["columns"] or not utils.is_complete_column(schema.tables[table_name]["columns"][item])):
is_variant = (item in partial_table["columns"]) and partial_table["columns"][item].get("variant")
if contract_modes["columns"] == "discard_value" or (is_variant and contract_modes["data_type"] == "discard_value"):
row.pop(item)
partial_table["columns"].pop(item)
elif contract_modes["columns"] == "discard_row" or (is_variant and contract_modes["data_type"] == "discard_row"):
return None, None
elif is_variant and contract_modes["data_type"] == "freeze":
raise SchemaFrozenException(self.name, table_name, f"Trying to create new variant column {item} to table {table_name} data_types are frozen.")
raise SchemaFrozenException(schema.name, table_name, f"Trying to create new variant column {item} to table {table_name} data_types are frozen.")
elif contract_modes["columns"] == "freeze":
raise SchemaFrozenException(self.name, table_name, f"Trying to add column {item} to table {table_name} but columns are frozen.")
raise SchemaFrozenException(schema.name, table_name, f"Trying to add column {item} to table {table_name} but columns are frozen.")

return row, partial_table
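
The docstring above introduces the contract modes, but the example it refers to is not visible in this diff; below is a minimal, hypothetical sketch of such settings and of the new static call shape, using only the keys and modes this hunk actually checks (table name, schema, row and partial_table are placeholders).

from dlt.common.schema import Schema

# hypothetical settings; keys and mode names are the ones checked above
contract_modes = {
    "tables": "evolve",        # new tables may be created
    "columns": "discard_row",  # a row that would add a new column to an existing table is dropped
    "data_type": "freeze",     # a variant (type-conflicting) column raises SchemaFrozenException
}

# schema, row and partial_table are assumed to come from Schema.coerce_row on an existing pipeline schema
row, partial_table = Schema.apply_schema_contract(schema, contract_modes, "users", row, partial_table)
# outcomes: both returned unchanged ("evolve"), offending keys popped from row and from
# partial_table["columns"] ("discard_value"), (None, None) when the row or a new table is
# discarded ("discard_row"), or SchemaFrozenException for the "freeze" settings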

@@ -463,6 +464,8 @@ def update_normalizers(self) -> None:
self._configure_normalizers(normalizers)

def set_schema_contract(self, settings: TSchemaContract, update_table_settings: bool = False) -> None:
if not settings:
return
self._settings["schema_contract"] = settings
if update_table_settings:
for table in self.tables.values():
Expand Down Expand Up @@ -666,6 +669,8 @@ def __repr__(self) -> str:
def resolve_contract_settings_for_table(parent_table: str, table_name: str, current_schema: Schema, incoming_schema: Schema = None, incoming_table: TTableSchema = None) -> TSchemaContractDict:
"""Resolve the exact applicable schema contract settings for the table during the normalization stage."""

current_schema = current_schema or incoming_schema

def resolve_single(settings: TSchemaContract) -> TSchemaContractDict:
settings = settings or {}
if isinstance(settings, str):
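For orientation, a hedged sketch of how the resolver is called after this change; the fallback to the incoming schema is the line added above, while the expansion of a shorthand string is assumed from resolve_single (its body is not shown in this hunk).

# hypothetical call; schema objects and the table name are placeholders
settings = resolve_contract_settings_for_table(
    None,                  # parent_table (None for a root table)
    "users",               # table_name being checked
    pipeline_schema,       # current_schema; may be None on the very first run ...
    source_schema,         # ... in which case the incoming schema supplies the settings
    incoming_table=table,  # per-table "schema_contract" hints, if any
)
# settings is a fully expanded TSchemaContractDict such as
# {"tables": "evolve", "columns": "evolve", "data_type": "evolve"};
# a shorthand string like "freeze" is assumed to expand to all three keys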
7 changes: 3 additions & 4 deletions dlt/extract/decorators.py
@@ -69,7 +69,7 @@ def source(
root_key: bool = False,
schema: Schema = None,
schema_contract: TSchemaContract = None,
spec: Type[BaseConfiguration] = None
spec: Type[BaseConfiguration] = None,
_impl_cls: Type[TDltSourceImpl] = DltSource # type: ignore[assignment]
) -> Callable[[Callable[TSourceFunParams, Any]], Callable[TSourceFunParams, TDltSourceImpl]]:
...
Expand All @@ -83,9 +83,8 @@ def source(
root_key: bool = False,
schema: Schema = None,
schema_contract: TSchemaContract = None,
spec: Type[BaseConfiguration] = None
spec: Type[BaseConfiguration] = None,
_impl_cls: Type[TDltSourceImpl] = DltSource # type: ignore[assignment]
>>>>>>> devel
) -> Any:
"""A decorator that transforms a function returning one or more `dlt resources` into a `dlt source` in order to load it with `dlt`.
@@ -358,7 +357,7 @@ def make_resource(_name: str, _section: str, _data: Any, incremental: Incrementa
columns=columns,
primary_key=primary_key,
merge_key=merge_key,
schema_contract=schema_contract
schema_contract=schema_contract,
table_format=table_format
)
return DltResource.from_data(_data, _name, _section, table_template, selected, cast(DltResource, data_from), incremental=incremental)
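Beyond the trailing-comma cleanup, the signatures above keep threading schema_contract through the source decorator; a hedged sketch of a user-facing declaration, assuming a shorthand mode string is accepted wherever TSchemaContract is:

import dlt

# hypothetical source and resource names; schema_contract is the keyword shown in the signatures above
@dlt.source(schema_contract="freeze")  # "freeze" assumed to apply to tables, columns and data_type alike
def my_source():
    @dlt.resource(write_disposition="append")
    def users():
        yield {"id": 1, "name": "alice"}
    return users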
59 changes: 46 additions & 13 deletions dlt/extract/extract.py
@@ -1,6 +1,6 @@
import contextlib
import os
from typing import ClassVar, List, Set, Dict, Type, Any, Sequence, Optional
from typing import ClassVar, List, Set, Dict, Type, Any, Sequence, Optional, Set
from collections import defaultdict

from dlt.common.configuration.container import Container
@@ -116,14 +116,17 @@ def __init__(
schema: Schema,
resources_with_items: Set[str],
dynamic_tables: TSchemaUpdate,
collector: Collector = NULL_COLLECTOR
collector: Collector = NULL_COLLECTOR,
pipeline_schema: Schema = None
) -> None:
self._storage = storage
self.schema = schema
self.dynamic_tables = dynamic_tables
self.collector = collector
self.resources_with_items = resources_with_items
self.extract_id = extract_id
self.disallowed_tables: Set[str] = set()
self.pipeline_schema = pipeline_schema

@property
def storage(self) -> ExtractorItemStorage:
Expand All @@ -148,7 +151,6 @@ def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> No
if isinstance(meta, TableNameMeta):
table_name = meta.table_name
self._write_static_table(resource, table_name, items)
self._write_item(table_name, resource.name, items)
else:
if resource._table_name_hint_fun:
if isinstance(items, list):
Expand All @@ -160,7 +162,6 @@ def write_table(self, resource: DltResource, items: TDataItems, meta: Any) -> No
# write item belonging to table with static name
table_name = resource.table_name # type: ignore[assignment]
self._write_static_table(resource, table_name, items)
self._write_item(table_name, resource.name, items)

def write_empty_file(self, table_name: str) -> None:
table_name = self.schema.naming.normalize_table_identifier(table_name)
Expand All @@ -179,7 +180,8 @@ def _write_dynamic_table(self, resource: DltResource, item: TDataItem) -> None:
table_name = resource._table_name_hint_fun(item)
existing_table = self.dynamic_tables.get(table_name)
if existing_table is None:
self.dynamic_tables[table_name] = [resource.compute_table_schema(item)]
if not self._add_dynamic_table(resource, data_item=item):
return
else:
# quick check if deep table merge is required
if resource._table_has_other_dynamic_hints:
Expand All @@ -195,9 +197,40 @@ def _write_dynamic_table(self, resource: DltResource, item: TDataItem) -> None:
def _write_static_table(self, resource: DltResource, table_name: str, items: TDataItems) -> None:
existing_table = self.dynamic_tables.get(table_name)
if existing_table is None:
static_table = resource.compute_table_schema()
static_table["name"] = table_name
self.dynamic_tables[table_name] = [static_table]
if not self._add_dynamic_table(resource, table_name=table_name):
return
self._write_item(table_name, resource.name, items)

def _add_dynamic_table(self, resource: DltResource, data_item: TDataItem = None, table_name: Optional[str] = None) -> bool:
"""
Computes new table and does contract checks
"""
# TODO: We have to normalize table identifiers here
table = resource.compute_table_schema(data_item)
if table_name:
table["name"] = table_name

# fast exit if we already evaluated this
if table["name"] in self.disallowed_tables:
return False

# this is a new table so allow evolve once
# TODO: is this the correct check for a new table, should a table with only incomplete columns be new too?
is_new_table = (self.pipeline_schema == None) or (table["name"] not in self.pipeline_schema.tables) or (not self.pipeline_schema.tables[table["name"]]["columns"])
if is_new_table:
table["x-normalizer"] = {"evolve_once": True} # type: ignore[typeddict-unknown-key]

# apply schema contract and apply on pipeline schema
# here we only check that table may be created
schema_contract = resolve_contract_settings_for_table(None, table["name"], self.pipeline_schema, self.schema, table)
_, checked_table = Schema.apply_schema_contract(self.pipeline_schema, schema_contract, table["name"], None, table)

if not checked_table:
self.disallowed_tables.add(table["name"])
return False

self.dynamic_tables[checked_table["name"]] = [checked_table]
return True
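
A hedged walk-through of the gate this method adds; the names are illustrative and the calls mirror the body above.

# hypothetical illustration of the table-level gate; resource, schemas and collections are placeholders
table = resource.compute_table_schema(item)           # candidate table for the current data item
settings = resolve_contract_settings_for_table(None, table["name"], pipeline_schema, source_schema, table)
# row=None: only the "may this table be created?" part of the contract is evaluated here
_, checked_table = Schema.apply_schema_contract(pipeline_schema, settings, table["name"], None, table)
if not checked_table:
    disallowed_tables.add(table["name"])              # later items for this table hit the fast exit above
else:
    dynamic_tables[checked_table["name"]] = [checked_table]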


class JsonLExtractor(Extractor):
@@ -238,6 +271,7 @@ def extract(
storage: ExtractorStorage,
collector: Collector = NULL_COLLECTOR,
*,
pipeline_schema: Schema = None,
max_parallel_items: int = None,
workers: int = None,
futures_poll_interval: float = None
Expand All @@ -247,10 +281,10 @@ def extract(
resources_with_items: Set[str] = set()
extractors: Dict[TLoaderFileFormat, Extractor] = {
"puae-jsonl": JsonLExtractor(
extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector
extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector, pipeline_schema=pipeline_schema
),
"arrow": ArrowExtractor(
extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector
extract_id, storage, schema, resources_with_items, dynamic_tables, collector=collector, pipeline_schema=pipeline_schema
)
}
last_item_format: Optional[TLoaderFileFormat] = None
@@ -318,11 +352,10 @@ def extract_with_schema(
with contextlib.suppress(DataItemRequiredForDynamicTableHints):
if resource.write_disposition == "replace":
reset_resource_state(resource.name)

extractor = extract(extract_id, source, storage, collector, pipeline_schema, max_parallel_items=max_parallel_items, workers=workers)
extractor = extract(extract_id, source, storage, collector, max_parallel_items=max_parallel_items, workers=workers, pipeline_schema=pipeline_schema)
# iterate over all items in the pipeline and update the schema if dynamic table hints were present
for _, partials in extractor.items():
for partial in partials:
schema.update_table(schema.normalize_table_identifiers(partial))
source.schema.update_table(source.schema.normalize_table_identifiers(partial))

return extract_id
78 changes: 45 additions & 33 deletions dlt/normalize/items_normalizers.py
@@ -5,12 +5,12 @@
from dlt.common import json, logger
from dlt.common.json import custom_pua_decode
from dlt.common.runtime import signals
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.storages import NormalizeStorage, LoadStorage, NormalizeStorageConfiguration, FileStorage
from dlt.common.schema.typing import TTableSchemaColumns, TSchemaContractDict
from dlt.common.storages import NormalizeStorage, LoadStorage, FileStorage
from dlt.common.typing import TDataItem
from dlt.common.schema import TSchemaUpdate, Schema
from dlt.common.utils import TRowCount, merge_row_count, increase_row_count

from dlt.common.schema.schema import resolve_contract_settings_for_table

class ItemsNormalizer(Protocol):
def __call__(
@@ -41,45 +41,57 @@ def _normalize_chunk(
schema_name = schema.name
items_count = 0
row_counts: TRowCount = {}
schema_contract: TSchemaContractDict = None

for item in items:
for (table_name, parent_table), row in schema.normalize_data_item(
item, load_id, root_table_name
):
if not schema_contract:
schema_contract = resolve_contract_settings_for_table(parent_table, table_name, schema)
# filter row, may eliminate some or all fields
row = schema.filter_row(table_name, row)
# do not process empty rows
if row:
# decode pua types
for k, v in row.items():
row[k] = custom_pua_decode(v) # type: ignore
# coerce row of values into schema table, generating partial table with new columns if any
row, partial_table = schema.coerce_row(
table_name, parent_table, row
)
# theres a new table or new columns in existing table
if partial_table:
# update schema and save the change
schema.update_table(partial_table)
table_updates = schema_update.setdefault(table_name, [])
table_updates.append(partial_table)
# update our columns
column_schemas[table_name] = schema.get_table_columns(
table_name
)
# get current columns schema
columns = column_schemas.get(table_name)
if not columns:
columns = schema.get_table_columns(table_name)
column_schemas[table_name] = columns
# store row
# TODO: it is possible to write to single file from many processes using this: https://gitlab.com/warsaw/flufl.lock
load_storage.write_data_item(
load_id, schema_name, table_name, row, columns
if not row:
continue

# decode pua types
for k, v in row.items():
row[k] = custom_pua_decode(v) # type: ignore
# coerce row of values into schema table, generating partial table with new columns if any
row, partial_table = schema.coerce_row(
table_name, parent_table, row
)

# if we detect a migration, check schema contract
if partial_table:
row, partial_table = Schema.apply_schema_contract(schema, schema_contract, table_name, row, partial_table)
if not row:
continue

# theres a new table or new columns in existing table
if partial_table:
# update schema and save the change
schema.update_table(partial_table)
table_updates = schema_update.setdefault(table_name, [])
table_updates.append(partial_table)
# update our columns
column_schemas[table_name] = schema.get_table_columns(
table_name
)
# count total items
items_count += 1
increase_row_count(row_counts, table_name, 1)
# get current columns schema
columns = column_schemas.get(table_name)
if not columns:
columns = schema.get_table_columns(table_name)
column_schemas[table_name] = columns
# store row
# TODO: it is possible to write to single file from many processes using this: https://gitlab.com/warsaw/flufl.lock
load_storage.write_data_item(
load_id, schema_name, table_name, row, columns
)
# count total items
items_count += 1
increase_row_count(row_counts, table_name, 1)
signals.raise_if_signalled()
return schema_update, items_count, row_counts

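To make the "discard_value" branch applied in this loop concrete, a small hypothetical before/after of the row and the partial table as the normalizer would see them (plain dicts, not actual dlt objects):

# hypothetical data; the stored "users" table is assumed to already contain only "id" and "name"
schema_contract = {"tables": "evolve", "columns": "discard_value", "data_type": "evolve"}

row = {"id": 1, "name": "alice", "nickname": "al"}    # "nickname" is not in the stored table
partial_table = {"name": "users", "columns": {"nickname": {"name": "nickname", "data_type": "text"}}}

# after Schema.apply_schema_contract(schema, schema_contract, "users", row, partial_table):
#   row           == {"id": 1, "name": "alice"}       # the unknown value was popped from the data
#   partial_table == {"name": "users", "columns": {}} # the column is not added to the schema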
17 changes: 11 additions & 6 deletions dlt/pipeline/pipeline.py
@@ -863,20 +863,25 @@ def _extract_source(self, storage: ExtractorStorage, source: DltSource, max_para
source_schema = source.schema
source_schema.update_normalizers()

# discover the existing pipeline schema
pipeline_schema = self._schema_storage[source_schema.name] if source_schema.name in self._schema_storage else None

# extract into pipeline schema
extract_id = extract_with_schema(storage, source, source_schema, self.collector, max_parallel_items, workers)
source.schema.set_schema_contract(global_contract, True)
extract_id = extract_with_schema(storage, source, pipeline_schema, self.collector, max_parallel_items, workers)

# save import with fully discovered schema
self._schema_storage.save_import_schema_if_not_exists(source_schema)

# if source schema does not exist in the pipeline
if source_schema.name not in self._schema_storage:
# create new schema
# save schema if not present in store
if not pipeline_schema:
self._schema_storage.save_schema(source_schema)
pipeline_schema = source_schema

# update pipeline schema (do contract checks here)
pipeline_schema = self._schema_storage[source_schema.name]
# update pipeline schema
print(source_schema.tables)
pipeline_schema.update_schema(source_schema)
pipeline_schema.set_schema_contract(global_contract, True)

# set as default if this is first schema in pipeline
if not self.default_schema_name: