From 564d998dcb316a338f6cb414a251540ffba417d4 Mon Sep 17 00:00:00 2001 From: arshPratap Date: Mon, 11 Sep 2023 09:15:49 +0530 Subject: [PATCH] feat: added unbind function to paddle frontend --- .../frontends/paddle/manipulation.py | 11 ++++++ .../test_paddle/test_manipulation.py | 37 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/ivy/functional/frontends/paddle/manipulation.py b/ivy/functional/frontends/paddle/manipulation.py index 358858480d6c2..af7ab64a9c8cf 100644 --- a/ivy/functional/frontends/paddle/manipulation.py +++ b/ivy/functional/frontends/paddle/manipulation.py @@ -170,6 +170,17 @@ def tile(x, repeat_times, name=None): return ivy.tile(x, repeats=repeat_times) +@with_supported_dtypes( + {"2.5.1 and below": ("bool", "int32", "int64", "float16", "float32", "float64")}, + "paddle", +) +@to_ivy_arrays_and_back +def unbind(input, axis=0): + shape = list(input.shape) + shape.pop(axis) + return tuple([x.reshape(tuple(shape)) for x in split(input, 1, axis=axis)]) + + @with_supported_dtypes( { "2.5.1 and below": ( diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py index b42f5286b70fc..84e747400de97 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_manipulation.py @@ -721,6 +721,43 @@ def test_paddle_tile( ) +# unbind +@handle_frontend_test( + fn_tree="paddle.unbind", + dtypes_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=2, + max_dim_size=1, + ), + number_positional_args=st.just(1), + axis=st.integers(-1, 0), + test_with_out=st.just(False), +) +def test_paddle_unbind( + *, + dtypes_values, + axis, + on_device, + fn_tree, + backend_fw, + frontend, + test_flags, +): + x_dtype, x = dtypes_values + axis = axis + helpers.test_frontend_function( + input_dtypes=x_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + axis=axis, + ) + + # unstack @handle_frontend_test( fn_tree="paddle.unstack",