From 30b1b2acd5e83c92110b61b1982e458896561ffb Mon Sep 17 00:00:00 2001 From: arshPratap Date: Thu, 28 Sep 2023 00:26:35 +0530 Subject: [PATCH] feat: added unbind function to paddle tensor class --- .../frontends/paddle/tensor/tensor.py | 16 +++++++ .../test_paddle/test_tensor/test_tensor.py | 43 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/ivy/functional/frontends/paddle/tensor/tensor.py b/ivy/functional/frontends/paddle/tensor/tensor.py index 3661cdaf67d6d..203ff6f4c467c 100644 --- a/ivy/functional/frontends/paddle/tensor/tensor.py +++ b/ivy/functional/frontends/paddle/tensor/tensor.py @@ -798,3 +798,19 @@ def real(self, name=None): ) def cast(self, dtype): return paddle_frontend.cast(self, dtype) + + @with_supported_dtypes( + { + "2.5.1 and below": ( + "bool", + "int32", + "int64", + "float16", + "float32", + "float64", + ) + }, + "paddle", + ) + def unbind(self, axis=0): + return paddle_frontend.unbind(self._ivy_array, axis=axis) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py index 255de088250a0..169efe6dd2993 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_tensor.py @@ -4160,6 +4160,49 @@ def test_paddle_tensor_trunc( ) +# unbind +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="paddle.to_tensor", + method_name="unbind", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + force_int_axis=True, + min_axis=-1, + max_axis=0, + ), +) +def test_paddle_tensor_unbind( + dtype_x_axis, + frontend_method_data, + init_flags, + method_flags, + frontend, + on_device, + backend_fw, +): + input_dtypes, x, axis = dtype_x_axis + helpers.test_frontend_method( + init_input_dtypes=input_dtypes, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtypes, + method_all_as_kwargs_np={ + "axis": axis, + }, + frontend=frontend, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + on_device=on_device, + ) + + # unsqueeze @handle_frontend_method( class_tree=CLASS_TREE,