From 2422948a888cca4985c3740e6e4e521b082a9d44 Mon Sep 17 00:00:00 2001
From: arshPratap
Date: Fri, 1 Sep 2023 13:07:03 +0530
Subject: [PATCH] Added glu activation to Paddle frontend

---
 .../paddle/nn/functional/activation.py       | 11 +++++
 .../test_functional/test_activation.py       | 41 +++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/ivy/functional/frontends/paddle/nn/functional/activation.py b/ivy/functional/frontends/paddle/nn/functional/activation.py
index cbff12b6b519..be73a7b4b129 100644
--- a/ivy/functional/frontends/paddle/nn/functional/activation.py
+++ b/ivy/functional/frontends/paddle/nn/functional/activation.py
@@ -40,6 +40,17 @@ def gelu(x, approximate=False, name=None):
     return ivy.gelu(x, approximate=approximate)
 
 
+@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle")
+@to_ivy_arrays_and_back
+def glu(x, axis=-1, name=None):
+    size = x.shape[axis]
+    ivy.utils.assertions.check_equal(
+        size % 2, 0, message="axis size must be divisible by 2", as_array=False
+    )
+    a, b = ivy.split(x, num_or_size_splits=2, axis=axis)
+    return ivy.multiply(a, ivy.sigmoid(b))
+
+
 @with_supported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle")
 @to_ivy_arrays_and_back
 def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py
index a1b17a2d3dbf..8af3244a62df 100644
--- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py
+++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py
@@ -128,6 +128,47 @@ def test_paddle_gelu(
     )
 
 
+# glu
+@handle_frontend_test(
+    fn_tree="paddle.nn.functional.glu",
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        large_abs_safety_factor=2,
+        small_abs_safety_factor=2,
+        safety_factor_scale="linear",
+        min_value=-2,
+        min_num_dims=1,
+        min_dim_size=4,
+        max_dim_size=4,
+    ),
+    axis=helpers.ints(min_value=-1, max_value=0),
+    test_with_out=st.just(False),
+)
+def test_paddle_glu(
+    *,
+    dtype_and_x,
+    axis,
+    on_device,
+    backend_fw,
+    fn_tree,
+    frontend,
+    test_flags,
+):
+    input_dtype, x = dtype_and_x
+    helpers.test_frontend_function(
+        input_dtypes=input_dtype,
+        backend_to_test=backend_fw,
+        frontend=frontend,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        on_device=on_device,
+        rtol=1e-01,
+        atol=1e-01,
+        x=x[0],
+        axis=axis,
+    )
+
+
 # gumbel_softmax
 @handle_frontend_test(
     fn_tree="paddle.nn.functional.gumbel_softmax",
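
For reviewers, a minimal usage sketch of the behavior this patch adds: GLU splits the input in half along `axis` and gates the first half with the sigmoid of the second, i.e. `a * sigmoid(b)`. The backend choice and sample values below are illustrative assumptions, not part of the patch:

    import ivy
    # Import path taken from the file this patch modifies.
    from ivy.functional.frontends.paddle.nn.functional.activation import glu

    ivy.set_backend("numpy")  # assumed backend; any Ivy backend with float64 support should work

    # The size-4 last axis satisfies the patch's "divisible by 2" assertion.
    x = ivy.array([[1.0, 2.0, 3.0, 4.0]])

    # Split along axis -1: a = [[1.0, 2.0]], b = [[3.0, 4.0]]
    # Expected: a * sigmoid(b) ≈ [[0.9526, 1.9640]]
    out = glu(x, axis=-1)
    print(out)

The same inputs fed to `paddle.nn.functional.glu` should agree within the rtol/atol used by the test above.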