From 8c172a719bb578fb7db23ec4dbb7df5ec8bc2900 Mon Sep 17 00:00:00 2001
From: Jorrit Sandbrink
Date: Mon, 4 Nov 2024 12:55:46 +0400
Subject: [PATCH 1/3] add delta arrow load id partition handling

---
 dlt/common/libs/deltalake.py                  | 22 +++++++++---
 .../load/pipeline/test_filesystem_pipeline.py | 36 +++++++++++++++++++
 2 files changed, 54 insertions(+), 4 deletions(-)

diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py
index 9caba55183..ccce79278d 100644
--- a/dlt/common/libs/deltalake.py
+++ b/dlt/common/libs/deltalake.py
@@ -24,7 +24,10 @@
 )
 
 
-def ensure_delta_compatible_arrow_schema(schema: pa.Schema) -> pa.Schema:
+def ensure_delta_compatible_arrow_schema(
+    schema: pa.Schema,
+    partition_by: Optional[Union[List[str], str]] = None,
+) -> pa.Schema:
     """Returns Arrow schema compatible with Delta table format.
 
     Casts schema to replace data types not supported by Delta.
@@ -35,12 +38,23 @@ def ensure_delta_compatible_arrow_schema(schema: pa.Schema) -> pa.Schema:
         pa.types.is_time: pa.string(),
         pa.types.is_decimal256: pa.string(),  # pyarrow does not allow downcasting to decimal128
     }
+
+    # partition fields can't be dictionary: https://github.com/delta-io/delta-rs/issues/2969
+    if isinstance(partition_by, str):
+        partition_by = [partition_by]
+    if any(pa.types.is_dictionary(schema.field(col).type) for col in partition_by):
+        # cast all dictionary fields to string — this is rogue because
+        # 1. dictionary value type is disregarded
+        # 2. any non-partition dictionary fields are cast too
+        ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP[pa.types.is_dictionary] = pa.string()
+
     # NOTE: also consider calling _convert_pa_schema_to_delta() from delta.schema which casts unsigned types
     return cast_arrow_schema_types(schema, ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP)
 
 
 def ensure_delta_compatible_arrow_data(
-    data: Union[pa.Table, pa.RecordBatchReader]
+    data: Union[pa.Table, pa.RecordBatchReader],
+    partition_by: Optional[Union[List[str], str]] = None,
 ) -> Union[pa.Table, pa.RecordBatchReader]:
     """Returns Arrow data compatible with Delta table format.
 
@@ -53,7 +67,7 @@ def ensure_delta_compatible_arrow_data(
         version="17.0.0",
         msg="`pyarrow>=17.0.0` is needed for `delta` table format on `filesystem` destination.",
     )
-    schema = ensure_delta_compatible_arrow_schema(data.schema)
+    schema = ensure_delta_compatible_arrow_schema(data.schema, partition_by)
     return data.cast(schema)
 
 
@@ -87,7 +101,7 @@ def write_delta_table(
     # is released
     write_deltalake(  # type: ignore[call-overload]
         table_or_uri=table_or_uri,
-        data=ensure_delta_compatible_arrow_data(data),
+        data=ensure_delta_compatible_arrow_data(data, partition_by),
         partition_by=partition_by,
         mode=get_delta_write_mode(write_disposition),
         schema_mode="merge",  # enable schema evolution (adding new columns)
diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py
index b8cf66608c..a36b743d25 100644
--- a/tests/load/pipeline/test_filesystem_pipeline.py
+++ b/tests/load/pipeline/test_filesystem_pipeline.py
@@ -586,6 +586,42 @@ def two_part():
     assert dt.metadata().partition_columns == []
 
 
+@pytest.mark.parametrize(
+    "destination_config",
+    destinations_configs(
+        table_format_filesystem_configs=True,
+        with_table_format="delta",
+        bucket_subset=(FILE_BUCKET),
+    ),
+    ids=lambda x: x.name,
+)
+def test_delta_table_partitioning_arrow_load_id(
+    destination_config: DestinationTestConfiguration,
+) -> None:
+    """Tests partitioning on load id column added by Arrow normalizer.
+
+    Case needs special handling because of a bug in delta-rs:
+    https://github.com/delta-io/delta-rs/issues/2969
+    """
+    from dlt.common.libs.pyarrow import pyarrow
+    from dlt.common.libs.deltalake import get_delta_tables
+
+    os.environ["NORMALIZE__PARQUET_NORMALIZER__ADD_DLT_LOAD_ID"] = "true"
+
+    pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True)
+
+    info = pipeline.run(
+        pyarrow.table({"foo": [1]}),
+        table_name="delta_table",
+        columns={"_dlt_load_id": {"partition": True}},
+        table_format="delta",
+    )
+    assert_load_info(info)
+    dt = get_delta_tables(pipeline, "delta_table")["delta_table"]
+    assert dt.metadata().partition_columns == ["_dlt_load_id"]
+    assert load_table_counts(pipeline, "delta_table")["delta_table"] == 1
+
+
 @pytest.mark.essential
 @pytest.mark.parametrize(

From 7389bbd9d2c6caa8f3a48ad55921d3261ba72a7a Mon Sep 17 00:00:00 2001
From: Jorrit Sandbrink
Date: Mon, 4 Nov 2024 15:11:33 +0400
Subject: [PATCH 2/3] handle None case

---
 dlt/common/libs/deltalake.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py
index ccce79278d..bf508d7031 100644
--- a/dlt/common/libs/deltalake.py
+++ b/dlt/common/libs/deltalake.py
@@ -40,13 +40,14 @@ def ensure_delta_compatible_arrow_schema(
     }
 
     # partition fields can't be dictionary: https://github.com/delta-io/delta-rs/issues/2969
-    if isinstance(partition_by, str):
-        partition_by = [partition_by]
-    if any(pa.types.is_dictionary(schema.field(col).type) for col in partition_by):
-        # cast all dictionary fields to string — this is rogue because
-        # 1. dictionary value type is disregarded
-        # 2. any non-partition dictionary fields are cast too
-        ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP[pa.types.is_dictionary] = pa.string()
+    if partition_by is not None:
+        if isinstance(partition_by, str):
+            partition_by = [partition_by]
+        if any(pa.types.is_dictionary(schema.field(col).type) for col in partition_by):
+            # cast all dictionary fields to string — this is rogue because
+            # 1. dictionary value type is disregarded
+            # 2. any non-partition dictionary fields are cast too
+            ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP[pa.types.is_dictionary] = pa.string()
 
     # NOTE: also consider calling _convert_pa_schema_to_delta() from delta.schema which casts unsigned types
     return cast_arrow_schema_types(schema, ARROW_TO_DELTA_COMPATIBLE_ARROW_TYPE_MAP)

From 73a8f8ec078ba86e6d043e7a5452bae53a573bd2 Mon Sep 17 00:00:00 2001
From: Jorrit Sandbrink
Date: Mon, 4 Nov 2024 15:34:53 +0400
Subject: [PATCH 3/3] handle delta arrow load id partition column merge disposition

---
 dlt/common/libs/deltalake.py                    |  3 ++-
 tests/load/pipeline/test_filesystem_pipeline.py | 15 +++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py
index bf508d7031..4047bc3a1a 100644
--- a/dlt/common/libs/deltalake.py
+++ b/dlt/common/libs/deltalake.py
@@ -131,9 +131,10 @@ def merge_delta_table(
     primary_keys = get_columns_names_with_prop(schema, "primary_key")
     predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys])
+    partition_by = get_columns_names_with_prop(schema, "partition")
 
     qry = (
         table.merge(
-            source=ensure_delta_compatible_arrow_data(data),
+            source=ensure_delta_compatible_arrow_data(data, partition_by),
             predicate=predicate,
             source_alias="source",
             target_alias="target",
diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py
index a36b743d25..2ad175c8f5 100644
--- a/tests/load/pipeline/test_filesystem_pipeline.py
+++ b/tests/load/pipeline/test_filesystem_pipeline.py
@@ -610,6 +610,7 @@ def test_delta_table_partitioning_arrow_load_id(
 
     pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True)
 
+    # append write disposition
     info = pipeline.run(
         pyarrow.table({"foo": [1]}),
         table_name="delta_table",
@@ -621,6 +622,20 @@ def test_delta_table_partitioning_arrow_load_id(
     assert dt.metadata().partition_columns == ["_dlt_load_id"]
     assert load_table_counts(pipeline, "delta_table")["delta_table"] == 1
 
+    # merge write disposition
+    info = pipeline.run(
+        pyarrow.table({"foo": [1, 2]}),
+        table_name="delta_table",
+        write_disposition={"disposition": "merge", "strategy": "upsert"},
+        columns={"_dlt_load_id": {"partition": True}},
+        primary_key="foo",
+        table_format="delta",
+    )
+    assert_load_info(info)
+    dt = get_delta_tables(pipeline, "delta_table")["delta_table"]
+    assert dt.metadata().partition_columns == ["_dlt_load_id"]
+    assert load_table_counts(pipeline, "delta_table")["delta_table"] == 2
+
 
 @pytest.mark.essential
 @pytest.mark.parametrize(