Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fast array extraction #7227

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
test cast array xd to features fix
alex-hh committed Oct 15, 2024

Verified

This commit was signed with the committer’s verified signature.
bonjourmauko Mauko Quiroga-Alvarado
commit abbb59a14c23899c577bae94af45b5d706a74e40
19 changes: 17 additions & 2 deletions tests/test_table.py
Original file line number Diff line number Diff line change
@@ -1332,18 +1332,33 @@ def test_cast_array_to_feature_with_list_array_and_large_list_feature(from_list_
assert cast_array.type == expected_array_type


def all_arrays_equal(arr1, arr2):
if len(arr1) != len(arr2):
return False
for a1, a2 in zip(arr1, arr2):
if isinstance(a1, list) and isinstance(a2, list):
if not all_arrays_equal(a1, a2):
return False
elif isinstance(a1, np.ndarray) and isinstance(a2, np.ndarray):
if not (a1 == a2).all():
return False
elif a1 != a2:
return False
return True


def test_cast_array_xd_to_features_sequence():
arr = np.random.randint(0, 10, size=(8, 2, 3)).tolist()
arr = Array2DExtensionType(shape=(2, 3), dtype="int64").wrap_array(pa.array(arr, pa.list_(pa.list_(pa.int64()))))
arr = pa.ListArray.from_arrays([0, None, 4, 8], arr)
# Variable size list
casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32")))
assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32")))
assert (casted_array.to_pylist() == arr.to_pylist()).all()
assert all_arrays_equal(casted_array.to_pylist(), arr.to_pylist())
# Fixed size list
casted_array = cast_array_to_feature(arr, Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4))
assert casted_array.type == get_nested_type(Sequence(Array2D(shape=(2, 3), dtype="int32"), length=4))
assert (casted_array.to_pylist() == arr.to_pylist()).all()
assert all_arrays_equal(casted_array.to_pylist(), arr.to_pylist())


def test_embed_array_storage(image_file):