From b7b51dbcdc9927b550b1a8998917f4f98a4c739d Mon Sep 17 00:00:00 2001 From: arshPratap Date: Mon, 11 Sep 2023 09:15:49 +0530 Subject: [PATCH] feat: added unbind function to paddle frontend --- .../frontends/paddle/manipulation.py | 17 ++++++--- .../test_paddle/test_manipulation.py | 37 +++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/ivy/functional/frontends/paddle/manipulation.py b/ivy/functional/frontends/paddle/manipulation.py index 358858480d6c2..4799ff69bae11 100644 --- a/ivy/functional/frontends/paddle/manipulation.py +++ b/ivy/functional/frontends/paddle/manipulation.py @@ -86,11 +86,6 @@ def repeat_interleave(x, repeats, axis=None, name=None): return ivy.repeat(x, repeats, axis=axis) -@to_ivy_arrays_and_back -def repeat_interleave(x, repeats, axis=None, name=None): - return ivy.repeat(x, repeats, axis=axis) - - @to_ivy_arrays_and_back def reshape(x, shape): return ivy.reshape(x, shape) @@ -170,6 +165,18 @@ def tile(x, repeat_times, name=None): return ivy.tile(x, repeats=repeat_times) +@with_supported_dtypes( + {"2.5.1 and below": ("bool", "int32", "int64", "float16", "float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def unbind(input, axis=0): + shape = list(input.shape) + num_splits = shape[axis] + shape.pop(axis) + return tuple([x.reshape(tuple(shape)) for x in split(input, num_splits, axis=axis)]) + + @with_supported_dtypes( { "2.5.1 and below": ( diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py index b42f5286b70fc..84e747400de97 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py @@ -721,6 +721,43 @@ def test_paddle_tile( ) +# unbind +@handle_frontend_test( + fn_tree="paddle.unbind", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + ), + number_positional_args=st.just(1), + axis=st.integers(-1, 0), + test_with_out=st.just(False), +) +def test_paddle_unbind( + *, + dtypes_values, + axis, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x_dtype, x = dtypes_values + axis = axis + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + ) + + # unstack @handle_frontend_test( fn_tree="paddle.unstack",