Skip to content

Commit

Permalink
Add scalar test
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Oct 29, 2023
1 parent 19e46a2 commit dff5441
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 15 deletions.
2 changes: 1 addition & 1 deletion cpp/src/arrow/extension/variable_shape_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ const Result<std::shared_ptr<Tensor>> VariableShapeTensorArray::GetTensor(
}

std::vector<int64_t> strides;
// TODO: optimize ComputeStrides for ragged tensors
// TODO: optimize ComputeStrides for non-uniform tensors
ARROW_CHECK_OK(internal::ComputeStrides(*value_type.get(), shape,
ext_type->permutation(), &strides));

Expand Down
19 changes: 12 additions & 7 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3637,18 +3637,23 @@ cdef class VariableShapeTensorArray(ExtensionArray):
"""
Convert variable shape tensor extension array to list of numpy arrays.
"""
cdef:
CVariableShapeTensorArray * ext_array = <CVariableShapeTensorArray *> (self.ap)
CResult[shared_ptr[CTensor]] ctensor

tensors = []
for i in range(len(self.storage)):
with nogil:
ctensor = ext_array.GetTensor(i)
tensors.append(pyarrow_wrap_tensor(GetResultValue(ctensor)).to_numpy())
tensors.append(self.get_tensor(i).to_numpy())

return tensors

def get_tensor(self, int64_t i):
"""
Get i-th tensor from variable shape tensor extension array.
"""
cdef:
CVariableShapeTensorArray* ext_array = <CVariableShapeTensorArray*>(self.ap)
CResult[shared_ptr[CTensor]] ctensor
with nogil:
ctensor = ext_array.GetTensor(i)
return pyarrow_wrap_tensor(GetResultValue(ctensor))

@staticmethod
def from_numpy_ndarray(obj):
"""
Expand Down
9 changes: 4 additions & 5 deletions python/pyarrow/scalar.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1039,11 +1039,10 @@ class VariableShapeTensorScalar(ExtensionScalar):
Note: ``permutation`` should be trivial (``None`` or ``[0, 1, ..., len(shape)-1]``).
"""

if self.type.permutation is None or self.type.permutation == list(range(len(self.type.shape))):
shape = self.get("shape")
np_flat = np.asarray(self.get("values").flatten())
numpy_tensor = np_flat.reshape(tuple(shape))
return numpy_tensor
if self.type.permutation is None or self.type.permutation == list(range(len(self.type.permutation))):
shape = self.value[0].values.to_pylist()
np_flat = np.asarray(self.value[1].values)
return np_flat.reshape(shape)
else:
raise ValueError(
'Only non-permuted tensors can be converted to numpy tensors.')
Expand Down
8 changes: 7 additions & 1 deletion python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,7 @@ def test_tensor_class_methods():


@pytest.mark.parametrize("value_type", (np.int8, np.int32, np.int64, np.float64))
def test_variable_shape_tensor_class_method(value_type):
def test_variable_shape_tensor_class_methods(value_type):
ndim = 2
shape_type = pa.list_(pa.uint32(), ndim)
arrow_type = pa.from_numpy_dtype(value_type)
Expand Down Expand Up @@ -1328,6 +1328,12 @@ def test_variable_shape_tensor_class_method(value_type):
{"data": [7, 8], "shape": [2, 1]},
]

expected_0 = np.array([[1, 2, 3], [4, 5, 6]], dtype=value_type)
expected_1 = np.array([[7], [8]], dtype=value_type)

np.testing.assert_array_equal(arr[0].to_numpy_ndarray(), expected_0)
np.testing.assert_array_equal(arr[1].to_numpy_ndarray(), expected_1)


@pytest.mark.parametrize("tensor_type", (
pa.fixed_shape_tensor(pa.int8(), [2, 2, 3]),
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1692,7 +1692,7 @@ cdef class VariableShapeTensorType(BaseExtensionType):
self.dim_names, self.permutation,
self.uniform_shape)

def __variable_ext_scalar_class__(self):
def __arrow_ext_scalar_class__(self):
return VariableShapeTensorScalar


Expand Down

0 comments on commit dff5441

Please sign in to comment.