diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py
index 58ddf69cea..d6ee5be4cd 100644
--- a/dlt/common/libs/pyarrow.py
+++ b/dlt/common/libs/pyarrow.py
@@ -223,9 +223,18 @@ def should_normalize_arrow_schema(
     schema: pyarrow.Schema,
     columns: TTableSchemaColumns,
     naming: NamingConvention,
-) -> Tuple[bool, Mapping[str, str], Dict[str, str], TTableSchemaColumns]:
+) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], TTableSchemaColumns]:
     rename_mapping = get_normalized_arrow_fields_mapping(schema, naming)
     rev_mapping = {v: k for k, v in rename_mapping.items()}
+    nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()}
+    # All fields from arrow schema that have nullable set to different value than in columns
+    # Key is the renamed column name
+    nullable_updates: Dict[str, bool] = {}
+    for field in schema:
+        norm_name = rename_mapping[field.name]
+        if norm_name in nullable_mapping and field.nullable != nullable_mapping[norm_name]:
+            nullable_updates[norm_name] = nullable_mapping[norm_name]
+
     dlt_tables = list(map(naming.normalize_table_identifier, ("_dlt_id", "_dlt_load_id")))

     # remove all columns that are dlt columns but are not present in arrow schema. we do not want to add such columns
@@ -239,8 +248,8 @@ def should_normalize_arrow_schema(
     # check if nothing to rename
     skip_normalize = (
         list(rename_mapping.keys()) == list(rename_mapping.values()) == list(columns.keys())
-    )
-    return not skip_normalize, rename_mapping, rev_mapping, columns
+    ) and not nullable_updates
+    return not skip_normalize, rename_mapping, rev_mapping, nullable_updates, columns


 def normalize_py_arrow_item(
@@ -254,10 +263,11 @@ def normalize_py_arrow_item(
     1. arrow schema field names will be normalized according to `naming`
     2. arrows columns will be reordered according to `columns`
     3. empty columns will be inserted if they are missing, types will be generated using `caps`
+    4. arrow columns with different nullability than corresponding schema columns will be updated
     """
     schema = item.schema
-    should_normalize, rename_mapping, rev_mapping, columns = should_normalize_arrow_schema(
-        schema, columns, naming
+    should_normalize, rename_mapping, rev_mapping, nullable_updates, columns = (
+        should_normalize_arrow_schema(schema, columns, naming)
     )
     if not should_normalize:
         return item
@@ -270,8 +280,12 @@ def normalize_py_arrow_item(
         field_name = rev_mapping.pop(column_name, column_name)
         if field_name in rename_mapping:
             idx = schema.get_field_index(field_name)
+            new_field = schema.field(idx).with_name(column_name)
+            if column_name in nullable_updates:
+                # Set field nullable to match column
+                new_field = new_field.with_nullable(nullable_updates[column_name])
             # use renamed field
-            new_fields.append(schema.field(idx).with_name(column_name))
+            new_fields.append(new_field)
             new_columns.append(item.column(idx))
         else:
             # column does not exist in pyarrow. create empty field and column
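Note on the hunks above: `should_normalize_arrow_schema` now also returns a `nullable_updates` mapping (keyed by the renamed column name), and `normalize_py_arrow_item` applies it through `Field.with_nullable` while renaming fields. A minimal sketch of the pyarrow mechanics this relies on; illustration only, not part of the patch, and the column names are made up:

```python
import pyarrow as pa

# Tables built from plain Python data mark every field as nullable by default.
table = pa.Table.from_pylist([{"col1": 1, "col2": "hello"}])
assert table.schema.field("col1").nullable is True

# Field.with_nullable returns a copy of the field with the flag replaced; the data
# is untouched, so the table can be rebuilt with the stricter schema.
fields = [
    f.with_nullable(False) if f.name == "col1" else f
    for f in table.schema
]
table = pa.Table.from_arrays(table.columns, schema=pa.schema(fields))
assert table.schema.field("col1").nullable is False
```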
diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py
index 742125850d..81220da2dd 100644
--- a/dlt/normalize/items_normalizers.py
+++ b/dlt/normalize/items_normalizers.py
@@ -295,7 +295,7 @@ def _write_with_dlt_columns(
             items_count += batch.num_rows
             # we may need to normalize
             if is_native_arrow_writer and should_normalize is None:
-                should_normalize, _, _, _ = pyarrow.should_normalize_arrow_schema(
+                should_normalize, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
                     batch.schema, columns_schema, schema.naming
                 )
                 if should_normalize:
@@ -376,7 +376,7 @@ def __call__(self, extracted_items_file: str, root_table_name: str) -> List[TSch
         )
         if not must_rewrite:
             # in rare cases normalization may be needed
-            must_rewrite, _, _, _ = pyarrow.should_normalize_arrow_schema(
+            must_rewrite, _, _, _, _ = pyarrow.should_normalize_arrow_schema(
                 arrow_schema, self.schema.get_table_columns(root_table_name), self.schema.naming
             )
         if must_rewrite:
diff --git a/tests/libs/pyarrow/test_pyarrow_normalizer.py b/tests/libs/pyarrow/test_pyarrow_normalizer.py
index 25871edd45..70c21d4d77 100644
--- a/tests/libs/pyarrow/test_pyarrow_normalizer.py
+++ b/tests/libs/pyarrow/test_pyarrow_normalizer.py
@@ -99,6 +99,38 @@ def test_default_dlt_columns_not_added() -> None:
     assert _row_at_index(result, 0) == [None, None, 1]


+def test_non_nullable_columns() -> None:
+    """Tests the case where arrow table is created with incomplete schema info,
+    such as when converting pandas dataframe to arrow. In this case normalize
+    should update not-null constraints in the arrow schema.
+    """
+    table = pa.Table.from_pylist(
+        [
+            {
+                "col1": 1,
+                "col2": "hello",
+                # Include column that will be renamed by normalize
+                # To ensure nullable flag mapping is correct
+                "Col 3": "world",
+            },
+        ]
+    )
+    columns = [
+        new_column("col1", "bigint", nullable=False),
+        new_column("col2", "text"),
+        new_column("col_3", "text", nullable=False),
+    ]
+    result = _normalize(table, columns)
+
+    # columns are normalized and in expected order
+    assert result.column_names == ["col1", "col2", "col_3"]
+    # Not-null columns are updated in arrow
+    assert result.schema.field("col1").nullable is False
+    assert result.schema.field("col_3").nullable is False
+    # col2 is still nullable
+    assert result.schema.field("col2").nullable is True
+
+
 @pytest.mark.skip(reason="Somehow this does not fail, should we add an exception??")
 def test_fails_if_adding_non_nullable_column() -> None:
     table = pa.Table.from_pylist(
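The unit test above covers the motivating case: when an arrow table is built from pandas, all not-null information is lost, so the normalizer has to restore it from the dlt schema. A standalone illustration of that pandas behaviour (column names made up, not part of the patch):

```python
import pandas as pd
import pyarrow as pa

# Converting a DataFrame to arrow carries no not-null constraints: every field
# in the resulting schema is nullable, even though the data has no missing values.
df = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
table = pa.Table.from_pandas(df)
assert all(field.nullable for field in table.schema)
```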
diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py
index c5a37ee5bb..0bddfaabee 100644
--- a/tests/load/pipeline/test_arrow_loading.py
+++ b/tests/load/pipeline/test_arrow_loading.py
@@ -217,3 +217,48 @@ def some_data():

     # Parquet schema is written with normalized column names
     assert result_tbl.schema.names == expected_column_names
+
+
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(
+        default_sql_configs=True,
+        default_staging_configs=True,
+        all_staging_configs=True,
+        default_vector_configs=True,
+    ),
+    ids=lambda x: x.name,
+)
+@pytest.mark.parametrize("item_type", ["arrow-table", "pandas", "arrow-batch"])
+def test_load_arrow_with_not_null_columns(
+    item_type: TestDataItemFormat, destination_config: DestinationTestConfiguration
+) -> None:
+    """Resource schema contains non-nullable columns. Arrow schema should be written accordingly"""
+    item, records, _ = arrow_table_all_data_types(item_type, include_json=False, include_time=False)
+
+    @dlt.resource(primary_key="string", columns=[{"name": "int", "nullable": False}])
+    def some_data():
+        yield item
+
+    pipeline = destination_config.setup_pipeline("arrow_" + uniq_id())
+
+    pipeline.extract(some_data())
+
+    norm_storage = pipeline._get_normalize_storage()
+    extract_files = [
+        fn for fn in norm_storage.list_files_to_normalize_sorted() if fn.endswith(".parquet")
+    ]
+    assert len(extract_files) == 1
+
+    # Check the extracted parquet file. It should have the respective non-nullable column in schema
+    with norm_storage.extracted_packages.storage.open_file(extract_files[0], "rb") as f:
+        result_tbl = pa.parquet.read_table(f)
+        assert result_tbl.schema.field("string").nullable is False
+        assert result_tbl.schema.field("string").type == pa.string()
+        assert result_tbl.schema.field("int").nullable is False
+        assert result_tbl.schema.field("int").type == pa.int64()
+
+    pipeline.normalize()
+    # Load is successful
+    info = pipeline.load()
+    assert_load_info(info)
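From the user side, the new load test boils down to the pattern below: declare the column hint on the resource and the not-null flag now survives into the parquet files written during extract. Rough usage sketch, assuming a local `duckdb` destination (not part of the patch):

```python
import dlt
import pandas as pd

# "nullable": False on the column hint is what the arrow normalizer now propagates
# into the arrow/parquet schema instead of dropping it.
@dlt.resource(columns=[{"name": "id", "nullable": False}])
def users():
    yield pd.DataFrame({"id": [1, 2], "name": ["alice", "bob"]})

pipeline = dlt.pipeline(pipeline_name="not_null_example", destination="duckdb")
pipeline.run(users())
```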