Ensure arrow field's nullable flag matches the schema column
steinitzu committed May 31, 2024
1 parent 829b558 commit 1c43861
Showing 4 changed files with 99 additions and 8 deletions.
26 changes: 20 additions & 6 deletions dlt/common/libs/pyarrow.py
@@ -223,9 +223,18 @@ def should_normalize_arrow_schema(
schema: pyarrow.Schema,
columns: TTableSchemaColumns,
naming: NamingConvention,
-) -> Tuple[bool, Mapping[str, str], Dict[str, str], TTableSchemaColumns]:
+) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]:
rename_mapping = get_normalized_arrow_fields_mapping(schema, naming)
rev_mapping = {v: k for k, v in rename_mapping.items()}
+nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()}
+# All fields from the arrow schema that have nullable set to a different value than in columns
+# Key is the renamed column name
+nullable_updates: Dict[str, bool] = {}
+for field in schema:
+    norm_name = rename_mapping[field.name]
+    if norm_name in nullable_mapping and field.nullable != nullable_mapping[norm_name]:
+        nullable_updates[norm_name] = nullable_mapping[norm_name]

dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id")))

# remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns
@@ -239,8 +248,8 @@ def should_normalize_arrow_schema(
# check if nothing to rename
skip_normalize = (
list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())
-)
-return not skip_normalize, rename_mapping, rev_mapping, columns
+) and not nullable_updates
+return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns


def normalize_py_arrow_item(
@@ -254,10 +263,11 @@ def normalize_py_arrow_item(
1. arrow schema field names will be normalized according to `naming`
2. arrows columns will be reordered according to `columns`
3. empty columns will be inserted if they are missing, types will be generated using `caps`
+4. arrow columns with different nullability than corresponding schema columns will be updated
"""
schema = item.schema
-should_normalize, rename_mapping, rev_mapping, columns = should_normalize_arrow_schema(
-    schema, columns, naming
+should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = (
+    should_normalize_arrow_schema(schema, columns, naming)
)
if not should_normalize:
return item
@@ -270,8 +280,12 @@
field_name = rev_mapping.pop(column_name, column_name)
if field_name in rename_mapping:
idx = schema.get_field_index(field_name)
+new_field = schema.field(idx).with_name(column_name)
+if column_name in nullable_updates:
+    # Set field nullable to match column
+    new_field = new_field.with_nullable(nullable_updates[column_name])
# use renamed field
-new_fields.append(schema.field(idx).with_name(column_name))
+new_fields.append(new_field)
new_columns.append(item.column(idx))
else:
# column does not exist in pyarrow. create empty field and column
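A minimal sketch of the reconciliation these two functions now perform, outside of dlt's own types: the `columns` dict below is an illustrative stand-in for `TTableSchemaColumns`, the field names are hypothetical, and the name-normalization/renaming step is omitted.

```python
import pyarrow as pa

# Illustrative column hints; in dlt these come from the table schema (TTableSchemaColumns)
columns = {"col1": {"nullable": False}, "col2": {}}

# Arrow schema as it might arrive from an extracted item: every field nullable
schema = pa.schema([pa.field("col1", pa.int64()), pa.field("col2", pa.string())])

# Mirror of the new nullable_updates computation: a missing hint defaults to nullable=True
nullable_mapping = {name: hint.get("nullable", True) for name, hint in columns.items()}
nullable_updates = {
    f.name: nullable_mapping[f.name]
    for f in schema
    if f.name in nullable_mapping and f.nullable != nullable_mapping[f.name]
}

# Apply the updates the way normalize_py_arrow_item does, via Field.with_nullable
new_fields = [
    f.with_nullable(nullable_updates[f.name]) if f.name in nullable_updates else f
    for f in schema
]
print(pa.schema(new_fields).field("col1").nullable)  # False
```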
4 changes: 2 additions & 2 deletions dlt/normalize/items_normalizers.py
@@ -295,7 +295,7 @@ def _write_with_dlt_columns(
items_count += batch.num_rows
# we may need to normalize
if is_native_arrow_writer and should_normalize is None:
-should_normalize, _, _, _ = pyarrow.should_normalize_arrow_schema(
+should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
batch.schema, columns_schema, schema.naming
)
if should_normalize:
@@ -376,7 +376,7 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch
)
if not must_rewrite:
# in rare cases normalization may be needed
-must_rewrite, _, _, _ = pyarrow.should_normalize_arrow_schema(
+must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming
)
if must_rewrite:
32 changes: 32 additions & 0 deletions tests/libs/pyarrow/test_pyarrow_normalizer.py
@@ -99,6 +99,38 @@ def test_default_dlt_columns_not_added() -> None:
assert _row_at_index(result, 0) == [None, None, 1]


def test_non_nullable_columns() -> None:
"""Tests the case where arrow table is created with incomplete schema info,
such as when converting a pandas dataframe to arrow. In this case normalize
should update the not-null constraints in the arrow schema.
"""
table = pa.Table.from_pylist(
[
{
"col1": 1,
"col2": "hello",
# Include column that will be renamed by normalize
# To ensure nullable flag mapping is correct
"Col 3": "world",
},
]
)
columns = [
new_column("col1", "bigint", nullable=False),
new_column("col2", "text"),
new_column("col_3", "text", nullable=False),
]
result = _normalize(table, columns)

# new columns appear at the end
assert result.column_names == ["col1", "col2", "col_3"]
# Not-null columns are updated in arrow
assert result.schema.field("col1").nullable is False
assert result.schema.field("col_3").nullable is False
# col2 is still nullable
assert result.schema.field("col2").nullable is True


@pytest.mark.skip(reason="Somehow this does not fail, should we add an exception??")
def test_fails_if_adding_non_nullable_column() -> None:
table = pa.Table.from_pylist(
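For context on the "incomplete schema info" case the test docstring above refers to: converting a pandas dataframe (or a plain list of dicts) to arrow marks every field as nullable, regardless of what the dlt schema declares. A rough illustration, assuming pandas and pyarrow are installed; column names are illustrative only.

```python
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})
table = pa.Table.from_pandas(df)

# Arrow infers both fields as nullable, even if the dlt schema says col1 is NOT NULL.
# This is the gap the normalizer now closes by rewriting the field via with_nullable(False).
print(table.schema.field("col1").nullable)  # True
```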
45 changes: 45 additions & 0 deletions tests/load/pipeline/test_arrow_loading.py
@@ -217,3 +217,48 @@ def some_data():

# Parquet schema is written with normalized column names
assert result_tbl.schema.names == expected_column_names


@pytest.mark.parametrize(
"destination_config",
destinations_configs(
default_sql_configs=True,
default_staging_configs=True,
all_staging_configs=True,
default_vector_configs=True,
),
ids=lambda x: x.name,
)
@pytest.mark.parametrize("item_type", ["arrow-table", "pandas", "arrow-batch"])
def test_load_arrow_with_not_null_columns(
item_type: TestDataItemFormat, destination_config: DestinationTestConfiguration
) -> None:
"""Resource schema contains non-nullable columns. Arrow schema should be written accordingly"""
item, records, _ = arrow_table_all_data_types(item_type, include_json=False, include_time=False)

@dlt.resource(primary_key="string", columns=[{"name": "int", "nullable": False}])
def some_data():
yield item

pipeline = destination_config.setup_pipeline("arrow_" + uniq_id())

pipeline.extract(some_data())

norm_storage = pipeline._get_normalize_storage()
extract_files = [
fn for fn in norm_storage.list_files_to_normalize_sorted() if fn.endswith(".parquet")
]
assert len(extract_files) == 1

# Check the extracted parquet file. It should have the respective non-nullable column in schema
with norm_storage.extracted_packages.storage.open_file(extract_files[0], "rb") as f:
result_tbl = pa.parquet.read_table(f)
assert result_tbl.schema.field("string").nullable is False
assert result_tbl.schema.field("string").type == pa.string()
assert result_tbl.schema.field("int").nullable is False
assert result_tbl.schema.field("int").type == pa.int64()

pipeline.normalize()
# Load is successful
info = pipeline.load()
assert_load_info(info)
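For reference, roughly how the behavior exercised by this test looks from the user side; the resource name, sample data, and duckdb destination below are illustrative only, not part of the commit.

```python
import dlt
import pyarrow as pa

@dlt.resource(columns=[{"name": "id", "nullable": False}])
def users():
    # arrow table built from plain Python data: arrow marks every field nullable
    yield pa.Table.from_pylist([{"id": 1, "name": "alice"}])

# any destination works the same way; duckdb is just a convenient local choice
pipeline = dlt.pipeline(pipeline_name="arrow_not_null_demo", destination="duckdb")
info = pipeline.run(users())
print(info)
# with this change, the parquet written during extract/normalize marks "id" as not null
```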
