Skip to content

Commit

Permalink
feat(python): allow python objects to be passed as new values in `.up…
Browse files Browse the repository at this point in the history
…date()` (#1749)

# Description
A user can now add a new_values dictionary that contains python objects
as a value.


Some weird behaviors I noticed, probably related to DataFusion:
updating a timestamp column has to be done by providing a unix timestamp
in microseconds. I personally find this very confusing; I was expecting
to be able to pass "2012-10-01", for example, in the updates.

Another weird behavior is with list-of-string columns. I can pass
`{"list_of_string_col":"[1,2,3]"}` or
`{"list_of_string_col":"['1','2','3']"}` and both will work. I expected
the first one to raise an exception for invalid datatypes. Mixed
datatypes such as `"[1,2,'3']"` luckily do raise an error in DataFusion.



# Related Issue(s)
<!---
For example:

- closes #106
--->
- closes #1740

---------

Co-authored-by: Will Jones <[email protected]>
  • Loading branch information
ion-elgreco and wjones127 authored Nov 4, 2023
1 parent 45e7841 commit 5a5dbcd
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 12 deletions.
73 changes: 61 additions & 12 deletions python/deltalake/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,24 +496,28 @@ def vacuum(

def update(
    self,
    updates: Optional[Dict[str, str]] = None,
    new_values: Optional[
        Dict[str, Union[int, float, str, datetime, bool, List[Any]]]
    ] = None,
    predicate: Optional[str] = None,
    writer_properties: Optional[Dict[str, int]] = None,
    error_on_type_mismatch: bool = True,
) -> Dict[str, Any]:
    """`UPDATE` records in the Delta Table that matches an optional predicate.

    Exactly one of ``updates`` (SQL expressions) or ``new_values`` (python
    objects, converted to SQL literals on your behalf) must be passed.

    Args:
        updates: a mapping of column name to update SQL expression.
        new_values: a mapping of column name to python datatype
            (int, float, str, datetime, bool, or list).
        predicate: a logical expression, defaults to None
        writer_properties: Pass writer properties to the Rust parquet writer, see options https://arrow.apache.org/rust/parquet/file/properties/struct.WriterProperties.html,
            only the following fields are supported: `data_page_size_limit`, `dictionary_page_size_limit`,
            `data_page_row_count_limit`, `write_batch_size`, `max_row_group_size`.
        error_on_type_mismatch: specify if update will return error if data types are mismatching :default = True

    Raises:
        TypeError: if a value in ``new_values`` has an unsupported type, or a
            value in ``updates`` is not a SQL string.
        ValueError: if both or neither of ``updates``/``new_values`` are given.

    Returns:
        the metrics from update

    Examples:

    Update some row values with SQL predicate. This is equivalent to
    ``UPDATE table SET deleted = true WHERE id = '5'``

    ```
    from deltalake import DeltaTable
    dt = DeltaTable("tmp")
    dt.update(predicate="id = '5'", updates = {"deleted": 'True'})
    ```

    Update all row values. This is equivalent to
    ``UPDATE table SET deleted = true, id = concat(id, '_old')``.

    ```
    from deltalake import DeltaTable
    dt = DeltaTable("tmp")
    dt.update(updates = {"deleted": 'True', "id": "concat(id, '_old')"})
    ```

    To use Python objects instead of SQL strings, use the `new_values` parameter
    instead of the `updates` parameter. For example, this is equivalent to
    ``UPDATE table SET price = 150.10 WHERE id = '5'``

    ```
    from deltalake import DeltaTable
    dt = DeltaTable("tmp")
    dt.update(predicate="id = '5'", new_values = {"price": 150.10})
    ```
    """
    if updates is None and new_values is not None:
        # Translate each python object into a SQL literal string.
        updates = {}
        for key, value in new_values.items():
            if isinstance(value, (int, float, bool, list)):
                value = str(value)
            elif isinstance(value, str):
                # Escape embedded single quotes (SQL doubles them) so the
                # generated literal stays valid and cannot break out of the
                # quoted string.
                value = "'" + value.replace("'", "''") + "'"
            elif isinstance(value, datetime):
                # NOTE(review): the engine currently expects timestamps as
                # unix epoch in microseconds — hence the explicit conversion.
                value = str(
                    int(value.timestamp() * 1000 * 1000)
                )  # convert to microseconds
            else:
                raise TypeError(
                    "Invalid datatype provided in new_values, only int, float, bool, list, str or datetime or accepted."
                )
            updates[key] = value
    elif updates is not None and new_values is None:
        # `updates` values must already be SQL expression strings.
        for key, value in updates.items():
            if not isinstance(value, str) or not isinstance(key, str):
                raise TypeError(
                    f"The values of the updates parameter must all be SQL strings. Got {updates}. Did you mean to use the new_values parameter?"
                )

    elif updates is not None and new_values is not None:
        raise ValueError(
            "Passing updates and new_values at same time is not allowed, pick one."
        )
    else:
        raise ValueError(
            "Either updates or new_values need to be passed to update the table."
        )
    metrics = self._table.update(
        updates,
        predicate,
        writer_properties,
        safe_cast=not error_on_type_mismatch,
    )
    return json.loads(metrics)

Expand Down
91 changes: 91 additions & 0 deletions python/tests/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def sample_table():
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([False] * nrows),
}
)
Expand All @@ -30,6 +32,8 @@ def test_update_with_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([False, False, False, False, True]),
}
)
Expand All @@ -54,6 +58,8 @@ def test_update_wo_predicate(tmp_path: pathlib.Path, sample_table: pa.Table):
"id": pa.array(["1", "2", "3", "4", "5"]),
"price": pa.array(list(range(nrows)), pa.int64()),
"sold": pa.array(list(range(nrows)), pa.int64()),
"price_float": pa.array(list(range(nrows)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * nrows),
"deleted": pa.array([True] * 5),
}
)
Expand Down Expand Up @@ -93,6 +99,8 @@ def test_update_wo_predicate_multiple_updates(
"id": pa.array(["1_1", "2_1", "3_1", "4_1", "5_1"]),
"price": pa.array([0, 1, 2, 3, 4], pa.int64()),
"sold": pa.array([0, 1, 4, 9, 16], pa.int64()),
"price_float": pa.array(list(range(5)), pa.float64()),
"items_in_bucket": pa.array([["item1", "item2", "item3"]] * 5),
"deleted": pa.array([True] * 5),
}
)
Expand All @@ -107,3 +115,86 @@ def test_update_wo_predicate_multiple_updates(

assert last_action["operation"] == "UPDATE"
assert result == expected


def test_update_with_predicate_and_new_values(
    tmp_path: pathlib.Path, sample_table: pa.Table
):
    """Only the row matching the predicate is rewritten from `new_values`."""
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    row_count = 5
    untouched_buckets = [["item1", "item2", "item3"]] * (row_count - 1)
    expected = pa.table(
        {
            "id": pa.array(["1", "2", "3", "4", "new_id"]),
            "price": pa.array(list(range(row_count)), pa.int64()),
            "sold": pa.array([0, 1, 2, 3, 100], pa.int64()),
            "price_float": pa.array([0, 1, 2, 3, 9999], pa.float64()),
            "items_in_bucket": pa.array(
                untouched_buckets + [["item4", "item5", "item6"]]
            ),
            "deleted": pa.array([False, False, False, False, True]),
        }
    )

    # One python object per column type: str, bool, int, float, list.
    dt.update(
        new_values={
            "id": "new_id",
            "deleted": True,
            "sold": 100,
            "price_float": 9999,
            "items_in_bucket": ["item4", "item5", "item6"],
        },
        predicate="price > 3",
    )

    last_action = dt.history(1)[0]
    assert last_action["operation"] == "UPDATE"
    assert dt.to_pyarrow_table() == expected


def test_update_no_inputs(tmp_path: pathlib.Path, sample_table: pa.Table):
    """Calling update() with neither `updates` nor `new_values` must fail."""
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    expected_msg = "Either updates or new_values need to be passed to update the table."
    with pytest.raises(Exception) as excinfo:
        dt.update()

    assert str(excinfo.value) == expected_msg


# NOTE(review): name should arguably be "too_many", but renaming a public
# test would churn CI history — kept as-is.
def test_update_to_many_inputs(tmp_path: pathlib.Path, sample_table: pa.Table):
    """Passing both `updates` and `new_values` at once must fail."""
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    expected_msg = "Passing updates and new_values at same time is not allowed, pick one."
    with pytest.raises(Exception) as excinfo:
        dt.update(updates={}, new_values={})

    assert str(excinfo.value) == expected_msg


def test_update_with_incorrect_updates_input(
    tmp_path: pathlib.Path, sample_table: pa.Table
):
    """An unsupported python type inside `new_values` must raise TypeError."""
    write_deltalake(tmp_path, sample_table, mode="append")
    dt = DeltaTable(tmp_path)

    # A dict is not one of the accepted value types for new_values.
    bad_new_values = {"col": {}}
    with pytest.raises(Exception) as excinfo:
        dt.update(new_values=bad_new_values)

    assert str(excinfo.value) == (
        "Invalid datatype provided in new_values, only int, float, bool, list, str or datetime or accepted."
    )

0 comments on commit 5a5dbcd

Please sign in to comment.