From 8cd5cf790955ddf39b7c81d6ff50a932d24d8c7f Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Sun, 29 Oct 2023 03:26:46 +0100 Subject: [PATCH] Add scalar test --- .../arrow/extension/variable_shape_tensor.cc | 2 +- python/pyarrow/array.pxi | 28 ++++++++++++++----- python/pyarrow/scalar.pxi | 9 +++--- python/pyarrow/tests/test_extension_type.py | 8 +++++- python/pyarrow/types.pxi | 2 +- 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc b/cpp/src/arrow/extension/variable_shape_tensor.cc index 0c6810671fe04..9c303bc1033fc 100644 --- a/cpp/src/arrow/extension/variable_shape_tensor.cc +++ b/cpp/src/arrow/extension/variable_shape_tensor.cc @@ -56,7 +56,7 @@ const Result> VariableShapeTensorArray::GetTensor( } std::vector strides; - // TODO: optimize ComputeStrides for ragged tensors + // TODO: optimize ComputeStrides for non-uniform tensors ARROW_CHECK_OK(internal::ComputeStrides(*value_type.get(), shape, ext_type->permutation(), &strides)); diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index f65f4b27b9951..99d2c42a97d1a 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -3637,18 +3637,32 @@ cdef class VariableShapeTensorArray(ExtensionArray): """ Convert variable shape tensor extension array to list of numpy arrays. """ - cdef: - CVariableShapeTensorArray * ext_array = (self.ap) - CResult[shared_ptr[CTensor]] ctensor - tensors = [] for i in range(len(self.storage)): - with nogil: - ctensor = ext_array.GetTensor(i) - tensors.append(pyarrow_wrap_tensor(GetResultValue(ctensor)).to_numpy()) + tensors.append(self.get_tensor(i).to_numpy()) return tensors + def get_tensor(self, int64_t i): + """ + Get i-th tensor from variable shape tensor extension array. + + Parameters + ---------- + i : int64_t + The index of the tensor to get. + + Returns + ------- + tensor : pyarrow.Tensor + """ + cdef: + CVariableShapeTensorArray* ext_array = (self.ap) + CResult[shared_ptr[CTensor]] ctensor + with nogil: + ctensor = ext_array.GetTensor(i) + return pyarrow_wrap_tensor(GetResultValue(ctensor)) + @staticmethod def from_numpy_ndarray(obj): """ diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index dbfa68dd8e258..7e4a3bddbe300 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1039,11 +1039,10 @@ class VariableShapeTensorScalar(ExtensionScalar): Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``). """ - if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))): - shape = self.get("shape") - np_flat = np.asarray(self.get("values").flatten()) - numpy_tensor = np_flat.reshape(tuple(shape)) - return numpy_tensor + if self.type.permutation is None or self.type.permutation == list(range(len(self.type.permutation))): + shape = self.value[0].values.to_pylist() + np_flat = np.asarray(self.value[1].values) + return np_flat.reshape(shape) else: raise ValueError( 'Only non-permuted tensors can be converted to numpy tensors.') diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 355e89cf28a9e..cf8e9c4774b30 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1387,7 +1387,7 @@ def test_tensor_class_methods(): @pytest.mark.parametrize("value_type", (np.int8, np.int32, np.int64, np.float64)) -def test_variable_shape_tensor_class_method(value_type): +def test_variable_shape_tensor_class_methods(value_type): ndim = 2 shape_type = pa.list_(pa.uint32(), ndim) arrow_type = pa.from_numpy_dtype(value_type) @@ -1433,6 +1433,12 @@ def test_variable_shape_tensor_class_method(value_type): {"data": [7, 8], "shape": [2, 1]}, ] + expected_0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=value_type) + expected_1 = np.array([[7], [8]], dtype=value_type) + + np.testing.assert_array_equal(arr[0].to_numpy_ndarray(), expected_0) + np.testing.assert_array_equal(arr[1].to_numpy_ndarray(), expected_1) + @pytest.mark.parametrize("tensor_type", ( pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]), diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 2ce0d0ffec923..624610d33127f 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1695,7 +1695,7 @@ cdef class VariableShapeTensorType(BaseExtensionType): self.dim_names, self.permutation, self.uniform_shape) - def __variable_ext_scalar_class__(self): + def __arrow_ext_scalar_class__(self): return VariableShapeTensorScalar