Skip to content

Commit

Permalink
some PR work
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Sep 12, 2023
1 parent 6dcaa7d commit 881d79a
Show file tree
Hide file tree
Showing 18 changed files with 237 additions and 187 deletions.
4 changes: 2 additions & 2 deletions dlt/common/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dlt.common.destination import DestinationReference, TDestinationReferenceArg
from dlt.common.exceptions import DestinationHasFailedJobs, PipelineStateNotAvailable, ResourceNameNotAvailable, SourceSectionNotAvailable
from dlt.common.schema import Schema
from dlt.common.schema.typing import TColumnNames, TColumnSchema, TWriteDisposition, TSchemaEvolutionSettings
from dlt.common.schema.typing import TColumnNames, TColumnSchema, TWriteDisposition, TSchemaContractSettings
from dlt.common.source import get_current_pipe_name
from dlt.common.storages.load_storage import LoadPackageInfo
from dlt.common.typing import DictStrAny, REPattern
Expand Down Expand Up @@ -210,7 +210,7 @@ def run(
primary_key: TColumnNames = None,
schema: Schema = None,
loader_file_format: TLoaderFileFormat = None,
schema_evolution_settings: TSchemaEvolutionSettings = None,
schema_contract_settings: TSchemaContractSettings = None,
) -> LoadInfo:
...

Expand Down
6 changes: 4 additions & 2 deletions dlt/common/schema/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,7 @@ def __init__(self, schema_name: str, init_engine: int, from_engine: int, to_engi


class SchemaFrozenException(SchemaException):
def __init__(self, msg: str) -> None:
super().__init__(msg)
def __init__(self, schema_name: str, table_name: str, msg: str) -> None:
super().__init__(msg)
self.schema_name = schema_name
self.table_name = table_name
58 changes: 27 additions & 31 deletions dlt/common/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
from dlt.common.schema import utils
from dlt.common.data_types import py_type_to_sc_type, coerce_value, TDataType
from dlt.common.schema.typing import (COLUMN_HINTS, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, VERSION_TABLE_NAME, STATE_TABLE_NAME, TPartialTableSchema, TSchemaSettings, TSimpleRegex, TStoredSchema,
TSchemaTables, TTableSchema, TTableSchemaColumns, TColumnSchema, TColumnProp, TColumnHint, TTypeDetections, TSchemaEvolutionModes, TSchemaEvolutionSettings)
TSchemaTables, TTableSchema, TTableSchemaColumns, TColumnSchema, TColumnProp, TColumnHint, TTypeDetections, TSchemaContractModes, TSchemaContractSettings)
from dlt.common.schema.exceptions import (CannotCoerceColumnException, CannotCoerceNullException, InvalidSchemaName,
ParentTableNotFoundException, SchemaCorruptedException)
from dlt.common.validation import validate_dict
from dlt.common.schema.exceptions import SchemaFrozenException


DEFAULT_SCHEMA_EVOLUTION_MODES: TSchemaEvolutionModes = {
DEFAULT_SCHEMA_CONTRACT_MODE: TSchemaContractModes = {
"table": "evolve",
"column": "evolve",
"column_variant": "evolve"
"data_type": "evolve"
}

class Schema:
Expand Down Expand Up @@ -194,68 +194,64 @@ def coerce_row(self, table_name: str, parent_table: str, row: StrAny) -> Tuple[D

return new_row, updated_table_partial

def resolve_evolution_settings_for_table(self, parent_table: str, table_name: str, schema_evolution_settings_override: TSchemaEvolutionSettings) -> TSchemaEvolutionModes:
def resolve_evolution_settings_for_table(self, parent_table: str, table_name: str, schema_contract_settings_override: TSchemaContractSettings) -> TSchemaContractModes:

def resolve_single(settings: TSchemaEvolutionSettings) -> TSchemaEvolutionModes:
def resolve_single(settings: TSchemaContractSettings) -> TSchemaContractModes:
settings = settings or {}
if isinstance(settings, str):
return TSchemaEvolutionModes(table=settings, column=settings, column_variant=settings)
return TSchemaContractModes(table=settings, column=settings, data_type=settings)
return settings

# find table settings
table_with_settings = parent_table or table_name

# modes
table_evolution_modes = resolve_single(self.tables.get(table_with_settings, {}).get("schema_evolution_settings", {}))
schema_evolution_modes = resolve_single(self._settings.get("schema_evolution_settings", {}))
overide_modes = resolve_single(schema_evolution_settings_override)
table_contract_modes = resolve_single(self.tables.get(table_with_settings, {}).get("schema_contract_settings", {}))
schema_contract_modes = resolve_single(self._settings.get("schema_contract_settings", {}))
overide_modes = resolve_single(schema_contract_settings_override)

# resolve to correct settings dict
settings = cast(TSchemaEvolutionModes, {**DEFAULT_SCHEMA_EVOLUTION_MODES, **schema_evolution_modes, **table_evolution_modes, **overide_modes})
settings = cast(TSchemaContractModes, {**DEFAULT_SCHEMA_CONTRACT_MODE, **schema_contract_modes, **table_contract_modes, **overide_modes})

return settings


def check_schema_update(self, parent_table: str, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema, schema_evolution_settings_override: TSchemaEvolutionSettings) -> Tuple[DictStrAny, TPartialTableSchema]:
def check_schema_update(self, contract_modes: TSchemaContractModes, table_name: str, row: DictStrAny, partial_table: TPartialTableSchema, schema_contract_settings_override: TSchemaContractSettings) -> Tuple[DictStrAny, TPartialTableSchema]:
"""Checks if schema update mode allows for the requested changes, filter row or reject update, depending on the mode"""

assert partial_table

# for now we defined the schema as new if there are no data columns defined
has_columns = self.has_data_columns
if not has_columns:
return row, partial_table

evolution_modes = self.resolve_evolution_settings_for_table(parent_table, table_name, schema_evolution_settings_override)

# default settings allow all evolutions, skipp all else
if evolution_modes == DEFAULT_SCHEMA_EVOLUTION_MODES:
if contract_modes == DEFAULT_SCHEMA_CONTRACT_MODE:
return row, partial_table

table_exists = table_name in self.tables and len(self.tables[table_name].get("columns", {}))
table_exists = table_name in self.tables and self.get_table_columns(table_name, include_incomplete=False)

# check case where we have a new table
if not table_exists:
if evolution_modes == "freeze-and-trim":
if contract_modes == "discard-value":
return None, None
if evolution_modes["table"] in ["freeze-and-discard", "freeze-and-trim"]:
if contract_modes["table"] in ["discard-row", "discard-value"]:
return None, None
if evolution_modes["table"] == "freeze-and-raise":
raise SchemaFrozenException(f"Trying to add table {table_name} but new tables are frozen.")
if contract_modes["table"] == "freeze":
raise SchemaFrozenException(self.name, table_name, f"Trying to add table {table_name} but new tables are frozen.")

# check columns
for item in list(row.keys()):
for item in list(row.keys()):
# if this is a new column for an existing table...
if table_exists and item not in self.tables[table_name]["columns"]:
if table_exists and (item not in self.tables[table_name]["columns"] or not utils.is_complete_column(self.tables[table_name]["columns"][item])):
is_variant = item in partial_table["columns"] and partial_table["columns"][item].get("variant")
if evolution_modes["column"] == "freeze-and-trim" or (is_variant and evolution_modes["column_variant"] == "freeze-and-trim"):
if contract_modes["column"] == "discard-value" or (is_variant and contract_modes["data_type"] == "discard-value"):
row.pop(item)
partial_table["columns"].pop(item)
if evolution_modes["column"] == "freeze-and-discard" or (is_variant and evolution_modes["column_variant"] == "freeze-and-discard"):
elif contract_modes["column"] == "discard-row" or (is_variant and contract_modes["data_type"] == "discard-row"):
return None, None
if evolution_modes["column"] == "freeze-and-raise" or (is_variant and evolution_modes["column_variant"] == "freeze-and-raise"):
raise SchemaFrozenException(f"Trying to add column {item} to table {table_name}  but columns are frozen.")
elif contract_modes["column"] == "freeze":
raise SchemaFrozenException(self.name, table_name, f"Trying to add column {item} to table {table_name}  but columns are frozen.")
elif is_variant and contract_modes["data_type"] == "freeze":
raise SchemaFrozenException(self.name, table_name, f"Trying to create new variant column {item} to table {table_name}  data_types are frozen.")


return row, partial_table

Expand Down Expand Up @@ -454,8 +450,8 @@ def update_normalizers(self) -> None:
normalizers["json"] = normalizers["json"] or self._normalizers_config["json"]
self._configure_normalizers(normalizers)

def set_schema_evolution_settings(self, settings: TSchemaEvolutionSettings) -> None:
self._settings["schema_evolution_settings"] = settings
def set_schema_contract_settings(self, settings: TSchemaContractSettings) -> None:
self._settings["schema_contract_settings"] = settings

def _infer_column(self, k: str, v: Any, data_type: TDataType = None, is_variant: bool = False) -> TColumnSchema:
column_schema = TColumnSchema(
Expand Down
16 changes: 8 additions & 8 deletions dlt/common/schema/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ class TColumnSchema(TColumnSchemaBase, total=False):
TColumnName = NewType("TColumnName", str)
SIMPLE_REGEX_PREFIX = "re:"

TSchemaEvolutionMode = Literal["evolve", "freeze-and-trim", "freeze-and-raise", "freeze-and-discard"]
TSchemaEvolutionMode = Literal["evolve", "discard-value", "freeze", "discard-row"]

class TSchemaEvolutionModes(TypedDict, total=False):
class TSchemaContractModes(TypedDict, total=False):
"""TypedDict defining the schema update settings"""
table: TSchemaEvolutionMode
column: TSchemaEvolutionMode
column_variant: TSchemaEvolutionMode
table: Optional[TSchemaEvolutionMode]
column: Optional[TSchemaEvolutionMode]
data_type: Optional[TSchemaEvolutionMode]

TSchemaEvolutionSettings = Union[TSchemaEvolutionMode, TSchemaEvolutionModes]
TSchemaContractSettings = Union[TSchemaEvolutionMode, TSchemaContractModes]

class TRowFilters(TypedDict, total=True):
excludes: Optional[List[TSimpleRegex]]
Expand All @@ -85,7 +85,7 @@ class TTableSchema(TypedDict, total=False):
name: Optional[str]
description: Optional[str]
write_disposition: Optional[TWriteDisposition]
schema_evolution_settings: Optional[TSchemaEvolutionSettings]
schema_contract_settings: Optional[TSchemaContractSettings]
parent: Optional[str]
filters: Optional[TRowFilters]
columns: TTableSchemaColumns
Expand All @@ -101,7 +101,7 @@ class TPartialTableSchema(TTableSchema):


class TSchemaSettings(TypedDict, total=False):
schema_evolution_settings: Optional[TSchemaEvolutionSettings]
schema_contract_settings: Optional[TSchemaContractSettings]
detections: Optional[List[TTypeDetections]]
default_hints: Optional[Dict[TColumnHint, List[TSimpleRegex]]]
preferred_types: Optional[Dict[TSimpleRegex, TDataType]]
Expand Down
14 changes: 7 additions & 7 deletions dlt/common/schema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dlt.common.schema import detections
from dlt.common.schema.typing import (COLUMN_HINTS, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, TColumnName, TPartialTableSchema, TSchemaTables, TSchemaUpdate,
TSimpleRegex, TStoredSchema, TTableSchema, TTableSchemaColumns, TColumnSchemaBase, TColumnSchema, TColumnProp,
TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition, TSchemaEvolutionSettings, TSchemaEvolutionModes)
TColumnHint, TTypeDetectionFunc, TTypeDetections, TWriteDisposition, TSchemaContractSettings, TSchemaContractModes)
from dlt.common.schema.exceptions import (CannotCoerceColumnException, ParentTableNotFoundException, SchemaEngineNoUpgradePathException, SchemaException,
TablePropertiesConflictException, InvalidSchemaName)

Expand Down Expand Up @@ -343,10 +343,10 @@ def migrate_filters(group: str, filters: List[str]) -> None:
if from_engine == 6 and to_engine > 6:
# migrate from sealed properties to schema evolution settings
schema_dict["settings"].pop("schema_sealed", None)
schema_dict["settings"]["schema_evolution_settings"] = None
schema_dict["settings"]["schema_contract_settings"] = None
for table in schema_dict["tables"].values():
table.pop("table_sealed", None)
table["schema_evolution_settings"] = None
table["schema_contract_settings"] = None
from_engine = 7

schema_dict["engine_version"] = from_engine
Expand Down Expand Up @@ -476,7 +476,7 @@ def merge_tables(table: TTableSchema, partial_table: TPartialTableSchema) -> TPa
table["columns"] = updated_columns

# always update evolution settings
table["schema_evolution_settings"] = partial_table.get("schema_evolution_settings")
table["schema_contract_settings"] = partial_table.get("schema_contract_settings")

return diff_table

Expand Down Expand Up @@ -644,7 +644,7 @@ def new_table(
columns: Sequence[TColumnSchema] = None,
validate_schema: bool = False,
resource: str = None,
schema_evolution_settings: TSchemaEvolutionSettings = None,
schema_contract_settings: TSchemaContractSettings = None,
) -> TTableSchema:

table: TTableSchema = {
Expand All @@ -655,12 +655,12 @@ def new_table(
table["parent"] = parent_table_name
assert write_disposition is None
assert resource is None
assert schema_evolution_settings is None
assert schema_contract_settings is None
else:
# set write disposition only for root tables
table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION
table["resource"] = resource or table_name
table["schema_evolution_settings"] = schema_evolution_settings
table["schema_contract_settings"] = schema_contract_settings
if validate_schema:
validate_dict_ignoring_xkeys(
spec=TColumnSchema,
Expand Down
17 changes: 7 additions & 10 deletions dlt/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,38 @@ def asstr(self, verbosity: int = 0) -> str:
...


def is_union_type(t: Type[Any]) -> bool:
return get_origin(t) is Union

def is_optional_type(t: Type[Any]) -> bool:
return get_origin(t) is Union and type(None) in get_args(t)


def is_final_type(t: Type[Any]) -> bool:
return get_origin(t) is Final


def extract_optional_type(t: Type[Any]) -> Any:
return get_args(t)[0]

def extract_union_types(t: Type[Any], no_none: bool = False) -> List[Any]:
if no_none:
return [arg for arg in get_args(t) if arg is not type(None)]
return list(get_args(t))

def is_literal_type(hint: Type[Any]) -> bool:
return get_origin(hint) is Literal


def is_union(hint: Type[Any]) -> bool:
return get_origin(hint) is Union


def is_newtype_type(t: Type[Any]) -> bool:
return hasattr(t, "__supertype__")


def is_typeddict(t: Type[Any]) -> bool:
return isinstance(t, _TypedDict)


def is_list_generic_type(t: Type[Any]) -> bool:
try:
return issubclass(get_origin(t), C_Sequence)
except TypeError:
return False


def is_dict_generic_type(t: Type[Any]) -> bool:
try:
return issubclass(get_origin(t), C_Mapping)
Expand Down
29 changes: 20 additions & 9 deletions dlt/common/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable, Any, Type, get_type_hints, get_args

from dlt.common.exceptions import DictValidationException
from dlt.common.typing import StrAny, extract_optional_type, is_literal_type, is_optional_type, is_typeddict, is_list_generic_type, is_dict_generic_type, _TypedDict
from dlt.common.typing import StrAny, is_literal_type, is_optional_type, extract_union_types, is_union_type, is_typeddict, is_list_generic_type, is_dict_generic_type, _TypedDict, is_union


TFilterFunc = Callable[[str], bool]
Expand Down Expand Up @@ -49,15 +49,26 @@ def validate_dict(spec: Type[_TypedDict], doc: StrAny, path: str, filter_f: TFil
raise DictValidationException(f"In {path}: following fields are unexpected {unexpected}", path)

def verify_prop(pk: str, pv: Any, t: Any) -> None:
if is_optional_type(t):
# pass if value actually is none
if pv is None:
return
t = extract_optional_type(t)

# TODO: support for union types?
if pk == "schema_evolution_settings":
# covers none in optional and union types
if is_optional_type(t) and pv is None:
pass
elif is_union_type(t):
# pass if value actually is none
union_types = extract_union_types(t, no_none=True)
# this is the case for optional fields
if len(union_types) == 1:
verify_prop(pk, pv, union_types[0])
else:
has_passed = False
for ut in union_types:
try:
verify_prop(pk, pv, ut)
has_passed = True
except DictValidationException:
pass
if not has_passed:
type_names = [ut.__name__ for ut in union_types]
raise DictValidationException(f"In {path}: field {pk} value {pv} has invalid type {type(pv).__name__}. One of these types expected: {', '.join(type_names)}.", path, pk, pv)
elif is_literal_type(t):
a_l = get_args(t)
if pv not in a_l:
Expand Down
Loading

0 comments on commit 881d79a

Please sign in to comment.