From ae0bb9859526cae3a193d177db3429a702bc3fb8 Mon Sep 17 00:00:00 2001 From: arshPratap Date: Mon, 25 Sep 2023 01:48:25 +0530 Subject: [PATCH] feat: added unstack function to paddle tensor class --- .../frontends/paddle/tensor/tensor.py | 6 +++ .../test_paddle/test_tensor/test_tensor.py | 46 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index eb33af47a81c1..ec18642d832fb 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -748,3 +748,9 @@ def mean(self, axis=None, keepdim=False, name=None): ) def less_equal(self, y, name=None): return paddle_frontend.less_equal(self._ivy_array, y) + + @with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "int32", "int64")}, "paddle" + ) + def unstack(self, axis=0, num=None): + return paddle_frontend.unstack(self._ivy_array, axis=axis) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 34bef7d1037d3..8385c7bb52e8b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -3957,6 +3957,52 @@ def test_paddle_tensor_unsqueeze_( ) +# unstack +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="unstack", + dtype_value=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + ), + axis=helpers.get_axis( + shape=st.shared(helpers.get_shape(), key="shape"), + allow_neg=True, + force_int=True, + ), +) +def test_paddle_tensor_unstack( + dtype_value, + axis, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtype, x = dtype_value + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "axis": axis, + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # var @handle_frontend_method( class_tree=CLASS_TREE,