Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replace method to dataclass extension array API. #37

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions arrowbic/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, overload

import immutables
import pyarrow as pa

from .base_types import NdArrayGeneric

Expand Down Expand Up @@ -49,6 +50,21 @@ def first_valid_item_in_iterable(it_items: Iterable[Optional[T]]) -> Tuple[int,
return (idx, None, consumed_values)


def get_validity_array(arr: pa.Array) -> Optional[pa.BooleanArray]:
    """Expose the validity/null bitmap of a PyArrow array as a boolean array.

    Args:
        arr: Any PyArrow array.
    Returns:
        Boolean array built on top of the validity bitmap (True where the entry is valid),
        or None when the input array has no validity bitmap allocated.
    """
    # The validity bitmap is read from the first buffer of the array.
    bitmap = arr.buffers()[0]
    if bitmap is not None:
        # Use the bitmap as the *data* buffer of a boolean array (which itself has no validity
        # buffer), keeping the same length and offset as the input array.
        return pa.BooleanArray.from_buffers(pa.bool_(), len(arr), [None, bitmap], offset=arr.offset)
    return None


@overload
def as_immutable(obj: List[T]) -> Tuple[T, ...]:
...
Expand Down
54 changes: 53 additions & 1 deletion arrowbic/extensions/dataclass_array.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Dataclass extension array implementation.
"""
from typing import List, Optional, Type, TypeVar
from typing import Any, Dict, List, Optional, Type, TypeVar

import numpy as np
import pyarrow as pa

from arrowbic.core.array_ops import get_pyitem
from arrowbic.core.base_extension_array import BaseExtensionArray
from arrowbic.core.base_extension_type import BaseExtensionType
from arrowbic.core.utils import get_validity_array

TItem = TypeVar("TItem")
TExtArray = TypeVar("TExtArray", bound="DataclassArray") # type:ignore


class DataclassArray(BaseExtensionArray[TItem]):
Expand Down Expand Up @@ -66,3 +69,52 @@ def __getattr__(self, key: str) -> pa.Array:
if key in keys:
return self.storage.field(keys.index(key))
raise KeyError(f"Unknown field '{key}' in the Arrowbic dataclass extension array. Available columns: {keys}.")

def replace(self: TExtArray, **kwargs: Any) -> TExtArray:
    """Replace columns in a dataclass array.

    NOTE: as per convention PyArrow arrays are immutable, so this method builds a new
    extension array with the updated columns.

    Args:
        kwargs: Columns to replace, keyed by dataclass field name. Every value is converted
            to an array using `asarray`.
    Returns:
        Extension array with updated columns (`self` itself when no column is provided).
    Raises:
        KeyError: if the input keys do not correspond to dataclass fields.
    """
    # Local import, as in the original implementation (presumably to avoid a circular import).
    from arrowbic.core.array_ops import asarray

    if not kwargs:
        return self

    unknown_keys = set(kwargs.keys()) - set(self.keys())
    if unknown_keys:
        raise KeyError(f"The input keys do not correspond to dataclass fields: '{unknown_keys}'.")

    # Combine input and existing column arrays, keeping the field order of the current storage.
    in_arrs: Dict[str, pa.Array] = {k: asarray(v) for k, v in kwargs.items()}
    field_arrays: Dict[str, pa.Array] = {
        f.name: in_arrs.get(f.name, self.storage.field(idx)) for idx, f in enumerate(self.storage.type)
    }
    # TODO: fix this mess between validity and mask array in PyArrow!!!
    # The validity bitmap flags valid entries, whereas `StructArray.from_arrays` expects a mask
    # flagging null entries: hence the logical negation.
    validity_array = get_validity_array(self)
    mask_array = None
    if validity_array is not None:
        mask_array = pa.array(np.logical_not(validity_array.to_numpy(zero_copy_only=False)), type=pa.bool_())

    # Re-build the storage struct array from the new values.
    aw_field_infos = [pa.field(name=k, type=arr.type, nullable=True) for k, arr in field_arrays.items()]
    storage_arr = pa.StructArray.from_arrays(list(field_arrays.values()), fields=aw_field_infos, mask=mask_array)

    # By default: try keeping the same extension type. Rebuild only if necessary. No access to registry for caching.
    ext_dc_type = self.type
    if self.storage.type != storage_arr.type:
        ext_dc_type = type(self.type)(storage_arr.type, ext_dc_type.item_pyclass, ext_dc_type.package_name)
    return DataclassArray.from_storage(ext_dc_type, storage_arr)
6 changes: 3 additions & 3 deletions arrowbic/extensions/dataclass_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,6 @@ def __arrowbic_from_item_iterator__(
storage_arr = pa.StructArray.from_arrays(list(field_arrays.values()), fields=aw_field_infos, mask=mask)

# Build the extension array, using registry extension type cache if existing.
ext_tensor_type = find_registry_extension_type(type(first_item), storage_arr.type, registry=registry)
ext_tensor_arr = DataclassArray.from_storage(ext_tensor_type, storage_arr)
return ext_tensor_arr
ext_dc_type = find_registry_extension_type(type(first_item), storage_arr.type, registry=registry)
ext_dc_arr = DataclassArray.from_storage(ext_dc_type, storage_arr)
return ext_dc_arr
24 changes: 23 additions & 1 deletion tests/core/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import immutables
import pyarrow as pa

from arrowbic.core.utils import as_immutable, first_valid_item_in_iterable
from arrowbic.core.utils import as_immutable, first_valid_item_in_iterable, get_validity_array


def test__first_valid_item_in_iterable__list__proper_result() -> None:
Expand Down Expand Up @@ -42,3 +43,24 @@ def test__as_immutable__list_input() -> None:

def test__as_immutable__dict_input() -> None:
assert as_immutable({1: "1", 2: "2"}) == immutables.Map({1: "1", 2: "2"})


def test__get_validity_array__no_validity_bitmap() -> None:
    """An array built without any null entry is expected to yield no validity array."""
    no_null_arr = pa.array([1, 2, 3])
    validity = get_validity_array(no_null_arr)
    assert validity is None


def test__get_validity_array__direct_buffer_mapping() -> None:
    """The validity array of an un-sliced array flags every null entry as False."""
    values = [1, None, 2, 3, None]
    validity: pa.BooleanArray = get_validity_array(pa.array(values))
    assert validity.type == pa.bool_()
    assert len(validity) == 5
    assert validity.to_pylist() == [True, False, True, True, False]


def test__get_validity_array__offset_buffer_mapping() -> None:
    """The validity array of a sliced array must follow the slice offset and length."""
    full_arr = pa.array([1, None, 2, 3, None])
    sliced_arr = full_arr[1:]
    validity: pa.BooleanArray = get_validity_array(sliced_arr)
    assert validity.type == pa.bool_()
    assert len(validity) == 4
    assert validity.to_pylist() == [False, True, True, False]
34 changes: 34 additions & 0 deletions tests/extensions/test_dataclass_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pyarrow as pa

from arrowbic.core.extension_type_registry import _global_registry
from arrowbic.core.utils import get_validity_array
from arrowbic.extensions import DataclassArray
from arrowbic.extensions.tensor_array import TensorArray

Expand Down Expand Up @@ -67,3 +68,36 @@ def test__dataclass_array__getattr__proper_columns(self) -> None:
assert isinstance(arr.data, TensorArray)
assert isinstance(arr.score, pa.FloatingPointArray)
assert isinstance(arr.name, pa.StringArray)

def test__dataclass_array__replace__no_inputs__return_same(self) -> None:
    """Calling `replace` without any column must hand back the very same array object."""
    arr_in = DataclassArray.from_iterator(
        [
            DummyData(DummyIntEnum.Invalid, np.array([1, 2, 3]), None, "name0"),
            DummyData(DummyIntEnum.Valid, None, 3.0, "name2"),
        ],
        registry=self.registry,
    )
    assert arr_in.replace() is arr_in

def test__dataclass_array__replace__invalid_input_keys(self) -> None:
    """Replacing a column which is not a dataclass field must raise a KeyError."""
    first_item = DummyData(DummyIntEnum.Invalid, np.array([1, 2, 3]), None, "name0")
    second_item = DummyData(DummyIntEnum.Valid, None, 3.0, "name2")
    arr_in = DataclassArray.from_iterator([first_item, second_item], registry=self.registry)
    # `test` is not a field of the dummy dataclass.
    with self.assertRaises(KeyError):
        arr_in.replace(test=np.array([1, 2]))

def test__dataclass_array__replace__proper_fields_update(self) -> None:
    """Replacing columns keeps length, top-level nulls and extension type, and updates the values."""
    arr_in = DataclassArray.from_iterator(
        [
            None,
            DummyData(DummyIntEnum.Invalid, np.array([1, 2, 3]), None, "name0"),
            DummyData(DummyIntEnum.Valid, None, 3.0, "name2"),
        ],
        registry=self.registry,
    )
    arr_out = arr_in.replace(
        score=np.array([4.0, 5.0, 6.0]),
        type=[DummyIntEnum.Invalid, DummyIntEnum.Valid, None],
    )

    assert len(arr_out) == len(arr_in)
    # The leading `None` item must remain a null entry in the output array.
    out_validity = get_validity_array(arr_out)
    assert out_validity.to_pylist() == [False, True, True]  # type:ignore
    # Same storage type expected, hence the extension type instance is re-used.
    assert arr_out.type is arr_in.type
    expected_scores = [4.0, 5.0, 6.0]
    assert arr_out.score.to_pylist() == expected_scores
    assert arr_out.storage.field(0).to_pylist() == expected_scores