Commit e707580

update schema management
sh-rp committed Sep 13, 2023
1 parent 881d79a commit e707580
Showing 7 changed files with 61 additions and 50 deletions.
12 changes: 7 additions & 5 deletions dlt/common/schema/schema.py
@@ -194,7 +194,7 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[D

return new_row, updated_table_partial

-def resolve_evolution_settings_for_table(self, parent_table: str, table_name: str, schema_contract_settings_override: TSchemaContractSettings) -> TSchemaContractModes:
+def resolve_contract_settings_for_table(self, parent_table: str, table_name: str) -> TSchemaContractModes:

def resolve_single(settings: TSchemaContractSettings) -> TSchemaContractModes:
settings = settings or {}
@@ -208,15 +208,14 @@ def resolve_single(settings: TSchemaContractSettings) -> TSchemaContractModes:
# modes
table_contract_modes = resolve_single(self.tables.get(table_with_settings, {}).get("schema_contract_settings", {}))
schema_contract_modes = resolve_single(self._settings.get("schema_contract_settings", {}))
-overide_modes = resolve_single(schema_contract_settings_override)

# resolve to correct settings dict
-settings = cast(TSchemaContractModes, {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_contract_modes, **table_contract_modes, **overide_modes})
+settings = cast(TSchemaContractModes, {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_contract_modes, **table_contract_modes})

return settings


-def check_schema_update(self, contract_modes: TSchemaContractModes, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema, schema_contract_settings_override: TSchemaContractSettings) -> Tuple[DictStrAny, TPartialTableSchema]:
+def check_schema_update(self, contract_modes: TSchemaContractModes, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema) -> Tuple[DictStrAny, TPartialTableSchema]:
"""Checks if schema update mode allows for the requested changes, filter row or reject update, depending on the mode"""

assert partial_table
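
With the override parameter removed, contract settings resolve purely as table-level over schema-level over defaults. A minimal sketch of the merge semantics, assuming each settings value is a flat dict of modes (the mode names here are illustrative, not taken from the diff):

```python
# later dicts win in the merge: defaults < schema-level < table-level
DEFAULT_SCHEMA_CONTRACT_MODE = {"table": "evolve", "column": "evolve", "data_type": "evolve"}

def resolve(schema_modes: dict, table_modes: dict) -> dict:
    return {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_modes, **table_modes}

# a schema that evolves freely, except one table freezes new columns
assert resolve({}, {"column": "freeze"})["column"] == "freeze"
assert resolve({"table": "freeze"}, {})["table"] == "freeze"
```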
@@ -450,8 +449,11 @@ def update_normalizers(self) -> None:
normalizers["json"] = normalizers["json"] or self._normalizers_config["json"]
self._configure_normalizers(normalizers)

-def set_schema_contract_settings(self, settings: TSchemaContractSettings) -> None:
+def set_schema_contract_settings(self, settings: TSchemaContractSettings, update_table_settings: bool = False) -> None:
self._settings["schema_contract_settings"] = settings
+if update_table_settings:
+    for table in self.tables.values():
+        table["schema_contract_settings"] = settings

def _infer_column(self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False) -> TColumnSchema:
column_schema = TColumnSchema(
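A hedged usage sketch of the new `update_table_settings` flag on a `Schema` instance; the settings value is illustrative, not taken from the diff:

```python
# assumes `schema` is a dlt Schema with some tables already inferred
settings = {"column": "freeze"}  # illustrative TSchemaContractSettings value
schema.set_schema_contract_settings(settings)  # schema level only
schema.set_schema_contract_settings(settings, update_table_settings=True)
# now every existing table carries the same contract settings
assert all(t["schema_contract_settings"] == settings for t in schema.tables.values())
```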
2 changes: 1 addition & 1 deletion dlt/common/typing.py
@@ -77,7 +77,7 @@ def is_final_type(t: Type[Any]) -> bool:

def extract_union_types(t: Type[Any], no_none: bool = False) -> List[Any]:
if no_none:
-return [arg for arg in get_args(t) if arg is not type(None)]
+return [arg for arg in get_args(t) if arg is not type(None)]  # noqa: E721
return list(get_args(t))

def is_literal_type(hint: Type[Any]) -> bool:
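The `# noqa: E721` silences flake8's "do not compare types" warning; the identity check against `type(None)` is deliberate, since `get_args` yields `NoneType` itself for optional members. Expected behavior, for reference:

```python
from typing import Optional, Union
from dlt.common.typing import extract_union_types

assert extract_union_types(Union[int, str]) == [int, str]
# Optional[int] is Union[int, None]; no_none=True drops the NoneType member
assert extract_union_types(Optional[int], no_none=True) == [int]
```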
2 changes: 1 addition & 1 deletion dlt/extract/decorators.py
@@ -170,13 +170,13 @@ def _wrap(*args: Any, **kwargs: Any) -> DltSource:

# prepare schema
schema = schema.clone(update_normalizers=True)
-schema.set_schema_contract_settings(schema_contract_settings)

# convert to source
s = DltSource.from_data(name, source_section, schema.clone(update_normalizers=True), rv)
# apply hints
if max_table_nesting is not None:
s.max_table_nesting = max_table_nesting
+s.schema_contract_settings = schema_contract_settings
# enable root propagation
s.root_key = root_key
return s
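The decorator now routes the settings through the `DltSource` property added below in dlt/extract/source.py, so they land on the schema instance the source actually holds (the second clone passed to `from_data`) rather than on the intermediate clone. Equivalent forms after this commit, with `s` being the `DltSource` under construction:

```python
s.schema_contract_settings = schema_contract_settings
# ...which the property setter forwards to:
s.schema.set_schema_contract_settings(schema_contract_settings)
```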
10 changes: 9 additions & 1 deletion dlt/extract/source.py
@@ -11,7 +11,7 @@
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer, RelationalNormalizerConfigPropagation
from dlt.common.schema import Schema
-from dlt.common.schema.typing import TColumnName
+from dlt.common.schema.typing import TColumnName, TSchemaContractSettings
from dlt.common.typing import AnyFun, StrAny, TDataItem, TDataItems, NoneType
from dlt.common.configuration.container import Container
from dlt.common.pipeline import PipelineContext, StateInjectableContext, SupportsPipelineRun, resource_state, source_state, pipeline_state
@@ -605,6 +605,14 @@ def max_table_nesting(self) -> int:
def max_table_nesting(self, value: int) -> None:
RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value})

+@property
+def schema_contract_settings(self) -> TSchemaContractSettings:
+    return self.schema.settings["schema_contract_settings"]
+
+@schema_contract_settings.setter
+def schema_contract_settings(self, settings: TSchemaContractSettings) -> None:
+    self.schema.set_schema_contract_settings(settings)

@property
def exhausted(self) -> bool:
"""check all selected pipes wether one of them has started. if so, the source is exhausted."""
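A hedged end-to-end sketch of the new property pair; the resource, source, and settings value are illustrative:

```python
import dlt

@dlt.resource
def numbers():
    yield [{"n": 1}, {"n": 2}]

@dlt.source
def demo_source():  # hypothetical source for illustration
    return numbers()

s = demo_source()
s.schema_contract_settings = {"column": "freeze"}  # assumed settings shape
# the getter reads back from the underlying schema's settings
assert s.schema_contract_settings == {"column": "freeze"}
```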
29 changes: 13 additions & 16 deletions dlt/normalize/normalize.py
@@ -11,7 +11,7 @@
from dlt.common.runners import TRunMetrics, Runnable
from dlt.common.runtime import signals
from dlt.common.runtime.collector import Collector, NULL_COLLECTOR
-from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns, TSchemaContractSettings, TSchemaContractModes
+from dlt.common.schema.typing import TStoredSchema, TTableSchemaColumns, TSchemaContractModes
from dlt.common.schema.utils import merge_schema_updates
from dlt.common.storages.exceptions import SchemaNotFoundError
from dlt.common.storages import NormalizeStorage, SchemaStorage, LoadStorage, LoadStorageConfiguration, NormalizeStorageConfiguration
@@ -26,23 +26,22 @@
# normalize worker wrapping function (map_parallel, map_single) return type
TMapFuncRV = Tuple[Sequence[TSchemaUpdate], TRowCount]
# normalize worker wrapping function signature
-TMapFuncType = Callable[[Schema, str, Sequence[str], TSchemaContractSettings], TMapFuncRV]  # input parameters: (schema name, load_id, list of files to process)
+TMapFuncType = Callable[[Schema, str, Sequence[str]], TMapFuncRV]  # input parameters: (schema name, load_id, list of files to process)
# tuple returned by the worker
TWorkerRV = Tuple[List[TSchemaUpdate], int, List[str], TRowCount]


class Normalize(Runnable[ProcessPool]):

@with_config(spec=NormalizeConfiguration, sections=(known_sections.NORMALIZE,))
-def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = config.value, schema_contract_settings: TSchemaContractSettings = None) -> None:
+def __init__(self, collector: Collector = NULL_COLLECTOR, schema_storage: SchemaStorage = None, config: NormalizeConfiguration = config.value) -> None:
self.config = config
self.collector = collector
self.pool: ProcessPool = None
self.normalize_storage: NormalizeStorage = None
self.load_storage: LoadStorage = None
self.schema_storage: SchemaStorage = None
self._row_counts: TRowCount = {}
-self.schema_contract_settings = schema_contract_settings

# setup storages
self.create_storages()
@@ -73,8 +72,7 @@ def w_normalize_files(
destination_caps: DestinationCapabilitiesContext,
stored_schema: TStoredSchema,
load_id: str,
-extracted_items_files: Sequence[str],
-schema_contract_settings: TSchemaContractSettings
+extracted_items_files: Sequence[str]
) -> TWorkerRV:
schema_updates: List[TSchemaUpdate] = []
total_items = 0
@@ -99,7 +97,7 @@
items_count = 0
for line_no, line in enumerate(f):
items: List[TDataItem] = json.loads(line)
-partial_update, items_count, r_counts = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items, schema_contract_settings)
+partial_update, items_count, r_counts = Normalize._w_normalize_chunk(load_storage, schema, load_id, root_table_name, items)
schema_updates.append(partial_update)
total_items += items_count
merge_row_count(row_counts, r_counts)
@@ -128,7 +126,7 @@ def w_normalize_files(
return schema_updates, total_items, load_storage.closed_files(), row_counts

@staticmethod
-def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem], schema_contract_settings: TSchemaContractSettings) -> Tuple[TSchemaUpdate, int, TRowCount]:
+def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str, root_table_name: str, items: List[TDataItem]) -> Tuple[TSchemaUpdate, int, TRowCount]:
column_schemas: Dict[str, TTableSchemaColumns] = {} # quick access to column schema for writers below
schema_update: TSchemaUpdate = {}
schema_name = schema.name
@@ -139,7 +137,7 @@ def _w_normalize_chunk(load_storage: LoadStorage, schema: Schema, load_id: str,
for item in items:
for (table_name, parent_table), row in schema.normalize_data_item(item, load_id, root_table_name):
if not schema_contract_modes:
-schema_contract_modes = schema.resolve_evolution_settings_for_table(parent_table, table_name, schema_contract_settings)
+schema_contract_modes = schema.resolve_contract_settings_for_table(parent_table, table_name)

# filter row, may eliminate some or all fields
row = schema.filter_row(table_name, row)
@@ -153,7 +151,7 @@
row, partial_table = schema.coerce_row(table_name, parent_table, row)
# if we detect a migration, the check update
if partial_table:
-row, partial_table = schema.check_schema_update(schema_contract_modes, table_name, row, partial_table, schema_contract_settings)
+row, partial_table = schema.check_schema_update(schema_contract_modes, table_name, row, partial_table)
if not row:
continue

@@ -204,12 +202,12 @@ def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[st
l_idx = idx + 1
return chunk_files

-def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str], schema_contract_settings: TSchemaContractSettings) -> TMapFuncRV:
+def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV:
workers = self.pool._processes # type: ignore
chunk_files = self.group_worker_files(files, workers)
schema_dict: TStoredSchema = schema.to_dict()
config_tuple = (self.normalize_storage.config, self.load_storage.config, self.config.destination_capabilities, schema_dict)
-param_chunk = [[*config_tuple, load_id, files, schema_contract_settings] for files in chunk_files]
+param_chunk = [[*config_tuple, load_id, files] for files in chunk_files]
tasks: List[Tuple[AsyncResult[TWorkerRV], List[Any]]] = []
row_counts: TRowCount = {}

@@ -259,15 +257,14 @@ def map_parallel(self, schema: Schema, load_id: str, files: Sequence[str], schem

return schema_updates, row_counts

-def map_single(self, schema: Schema, load_id: str, files: Sequence[str], schema_contract_settings: TSchemaContractSettings) -> TMapFuncRV:
+def map_single(self, schema: Schema, load_id: str, files: Sequence[str]) -> TMapFuncRV:
result = Normalize.w_normalize_files(
self.normalize_storage.config,
self.load_storage.config,
self.config.destination_capabilities,
schema.to_dict(),
load_id,
-files,
-schema_contract_settings
+files
)
self.update_schema(schema, result[0])
self.collector.update("Files", len(result[2]))
@@ -278,7 +275,7 @@ def spool_files(self, schema_name: str, load_id: str, map_f: TMapFuncType, files
schema = Normalize.load_or_create_schema(self.schema_storage, schema_name)

# process files in parallel or in single thread, depending on map_f
-schema_updates, row_counts = map_f(schema, load_id, files, self.schema_contract_settings)
+schema_updates, row_counts = map_f(schema, load_id, files)
# logger.metrics("Normalize metrics", extra=get_logging_extras([self.schema_version_gauge.labels(schema_name)]))
if len(schema_updates) > 0:
logger.info(f"Saving schema {schema_name} with version {schema.version}, writing manifest files")
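Pulling the hunks above together, the per-row flow in `_w_normalize_chunk` after this commit, as a reduced sketch (filtering and write-out steps omitted; names as in the diff, surrounding variables assumed in scope):

```python
schema_contract_modes = None
for item in items:
    for (table_name, parent_table), row in schema.normalize_data_item(item, load_id, root_table_name):
        if not schema_contract_modes:
            # resolved from schema/table settings alone -- no external override
            schema_contract_modes = schema.resolve_contract_settings_for_table(parent_table, table_name)
        row, partial_table = schema.coerce_row(table_name, parent_table, row)
        if partial_table:
            # may filter the row or reject the update, depending on the modes
            row, partial_table = schema.check_schema_update(schema_contract_modes, table_name, row, partial_table)
        if not row:
            continue
```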
6 changes: 5 additions & 1 deletion dlt/pipeline/pipeline.py
@@ -304,6 +304,10 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No
if not self.default_schema_name:
return None

+# update global schema contract settings, could be moved into def normalize()
+if schema_contract_settings is not None:
+    self.default_schema.set_schema_contract_settings(schema_contract_settings, True)

# make sure destination capabilities are available
self._get_destination_capabilities()
# create default normalize config
@@ -316,7 +320,7 @@ def normalize(self, workers: int = 1, loader_file_format: TLoaderFileFormat = No
# run with destination context
with self._maybe_destination_capabilities(loader_file_format=loader_file_format):
# shares schema storage with the pipeline so we do not need to install
-normalize = Normalize(collector=self.collector, config=normalize_config, schema_storage=self._schema_storage, schema_contract_settings=schema_contract_settings)
+normalize = Normalize(collector=self.collector, config=normalize_config, schema_storage=self._schema_storage)
try:
with signals.delayed_signals():
runner.run_pool(normalize.config, normalize)
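From the caller's side, the settings now enter once through `Pipeline.normalize` and are stamped onto the default schema and all of its tables before the normalize stage runs. A hedged usage sketch (destination and settings values are illustrative):

```python
import dlt

pipeline = dlt.pipeline(pipeline_name="demo", destination="duckdb")
pipeline.extract([{"id": 1}], table_name="items")
# forwarded to default_schema.set_schema_contract_settings(..., True) per the diff
pipeline.normalize(schema_contract_settings={"column": "freeze"})
```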