Skip to content

Commit

Permalink
expose predicate, handle error better, and add overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
ion-elgreco committed Nov 25, 2023
1 parent 381df0d commit 3ab7687
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 1 deletion.
4 changes: 3 additions & 1 deletion crates/deltalake-core/src/operations/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ impl std::future::IntoFuture for WriteBuilder {

match this.predicate {
Some(_pred) => {
todo!("Overwriting data based on predicate is not yet implemented")
return Err(DeltaTableError::Generic(
"Overwriting data based on predicate is not yet implemented".to_string(),
));
}
_ => {
let remove_actions = this
Expand Down
1 change: 1 addition & 0 deletions python/deltalake/_internal.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def write_to_deltalake(
mode: str,
max_rows_per_group: int,
overwrite_schema: bool,
predicate: Optional[str],
name: Optional[str],
description: Optional[str],
configuration: Optional[Mapping[str, Optional[str]]],
Expand Down
66 changes: 66 additions & 0 deletions python/deltalake/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Optional,
Tuple,
Union,
overload,
)
from urllib.parse import unquote

Expand Down Expand Up @@ -68,6 +69,68 @@ class AddAction:
stats: str


@overload
def write_deltalake(
    table_or_uri: Union[str, Path, DeltaTable],
    data: Union[
        "pd.DataFrame",
        ds.Dataset,
        pa.Table,
        pa.RecordBatch,
        Iterable[pa.RecordBatch],
        RecordBatchReader,
    ],
    *,
    schema: Optional[pa.Schema] = ...,
    partition_by: Optional[Union[List[str], str]] = ...,
    # NOTE: use the `...` placeholder here like every other parameter in this
    # stub — overload signatures only describe types, not runtime defaults
    # (the real default lives on the implementation below).
    filesystem: Optional[pa_fs.FileSystem] = ...,
    mode: Literal["error", "append", "overwrite", "ignore"] = ...,
    file_options: Optional[ds.ParquetFileWriteOptions] = ...,
    max_partitions: Optional[int] = ...,
    max_open_files: int = ...,
    max_rows_per_file: int = ...,
    min_rows_per_group: int = ...,
    max_rows_per_group: int = ...,
    name: Optional[str] = ...,
    description: Optional[str] = ...,
    configuration: Optional[Mapping[str, Optional[str]]] = ...,
    overwrite_schema: bool = ...,
    storage_options: Optional[Dict[str, str]] = ...,
    partition_filters: Optional[List[Tuple[str, str, Any]]] = ...,
    large_dtypes: bool = ...,
    engine: Literal["pyarrow"] = ...,
) -> None:
    ...


# Overload for the rust engine: `engine="rust"` is a required keyword argument
# (no default), which is what selects this signature during type checking.
# It accepts `predicate` (replace-where) but, unlike the pyarrow overload,
# exposes no pyarrow-writer knobs (filesystem, file_options, partition_filters,
# max_open_files, ...) because the rust writer does not take them.
@overload
def write_deltalake(
    table_or_uri: Union[str, Path, DeltaTable],
    data: Union[
        "pd.DataFrame",
        ds.Dataset,
        pa.Table,
        pa.RecordBatch,
        Iterable[pa.RecordBatch],
        RecordBatchReader,
    ],
    *,
    schema: Optional[pa.Schema] = ...,
    partition_by: Optional[Union[List[str], str]] = ...,
    mode: Literal["error", "append", "overwrite", "ignore"] = ...,
    max_rows_per_group: int = ...,
    name: Optional[str] = ...,
    description: Optional[str] = ...,
    configuration: Optional[Mapping[str, Optional[str]]] = ...,
    overwrite_schema: bool = ...,
    storage_options: Optional[Dict[str, str]] = ...,
    predicate: Optional[str] = ...,
    large_dtypes: bool = ...,
    engine: Literal["rust"],
) -> None:
    ...


def write_deltalake(
table_or_uri: Union[str, Path, DeltaTable],
data: Union[
Expand Down Expand Up @@ -95,6 +158,7 @@ def write_deltalake(
overwrite_schema: bool = False,
storage_options: Optional[Dict[str, str]] = None,
partition_filters: Optional[List[Tuple[str, str, Any]]] = None,
predicate: Optional[str] = None,
large_dtypes: bool = False,
engine: Literal["pyarrow", "rust"] = "pyarrow",
) -> None:
Expand Down Expand Up @@ -164,6 +228,7 @@ def write_deltalake(
configuration: A map containing configuration options for the metadata action.
overwrite_schema: If True, allows updating the schema of the table.
storage_options: options passed to the native delta filesystem. Unused if 'filesystem' is defined.
predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine.
partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine.
large_dtypes: If True, the data schema is kept in large_dtypes, has no effect on pandas dataframe input
"""
Expand Down Expand Up @@ -225,6 +290,7 @@ def write_deltalake(
mode=mode,
max_rows_per_group=max_rows_per_group,
overwrite_schema=overwrite_schema,
predicate=predicate,
name=name,
description=description,
configuration=configuration,
Expand Down
5 changes: 5 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,7 @@ fn write_to_deltalake(
max_rows_per_group: i64,
overwrite_schema: bool,
partition_by: Option<Vec<String>>,
predicate: Option<String>,
name: Option<String>,
description: Option<String>,
configuration: Option<HashMap<String, Option<String>>>,
Expand Down Expand Up @@ -1168,6 +1169,10 @@ fn write_to_deltalake(
builder = builder.with_description(description);
};

if let Some(predicate) = &predicate {
builder = builder.with_replace_where(predicate);
};

if let Some(config) = configuration {
builder = builder.with_configuration(config);
};
Expand Down
68 changes: 68 additions & 0 deletions python/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,74 @@ def test_partition_overwrite_unfiltered_data_fails(
)


@pytest.mark.parametrize(
    "value_1,value_2,value_type,filter_string",
    [
        (1, 2, pa.int64(), "1"),
        (False, True, pa.bool_(), "false"),
        (date(2022, 1, 1), date(2022, 1, 2), pa.date32(), "2022-01-01"),
    ],
)
def test_replace_where_overwrite(
    tmp_path: pathlib.Path,
    value_1: Any,
    value_2: Any,
    value_type: pa.DataType,
    filter_string: str,
):
    """Replace-where (predicate overwrite) is not implemented in the rust
    engine yet: the write must raise a Generic DeltaTable error and must NOT
    commit anything, leaving the table contents untouched.

    ``filter_string`` is currently unused; it is kept in the parametrization
    for when replace-where is actually implemented and the predicate can
    target ``p2`` values.
    """
    sample_data = pa.table(
        {
            "p1": pa.array(["1", "1", "2", "2"], pa.string()),
            "p2": pa.array([value_1, value_2, value_1, value_2], value_type),
            "val": pa.array([1, 1, 1, 1], pa.int64()),
        }
    )
    write_deltalake(tmp_path, sample_data, mode="overwrite", partition_by=["p1", "p2"])

    delta_table = DeltaTable(tmp_path)
    assert (
        delta_table.to_pyarrow_table().sort_by(
            [("p1", "ascending"), ("p2", "ascending")]
        )
        == sample_data
    )

    # Keep this in a NEW variable; `sample_data` is still needed below to
    # verify that the table was not modified by the failed write.
    new_data = pa.table(
        {
            "p1": pa.array(["1", "1"], pa.string()),
            "p2": pa.array([value_2, value_1], value_type),
            "val": pa.array([2, 2], pa.int64()),
        }
    )

    with pytest.raises(
        DeltaError,
        match="Generic DeltaTable error: Overwriting data based on predicate is not yet implemented",
    ):
        write_deltalake(
            tmp_path,
            new_data,
            mode="overwrite",
            predicate="`p1` = 1",
            engine="rust",
        )

    # The write raised before committing, so the table must be unchanged:
    # asserting against the original data (not a merged "expected" table)
    # is the only outcome consistent with the error above.
    delta_table.update_incremental()
    assert (
        delta_table.to_pyarrow_table().sort_by(
            [("p1", "ascending"), ("p2", "ascending")]
        )
        == sample_data
    )


def test_partition_overwrite_with_new_partition(
tmp_path: pathlib.Path, sample_data_for_partitioning: pa.Table
):
Expand Down

0 comments on commit 3ab7687

Please sign in to comment.