From 27aca1e3ebec1a3580b8d5537828657322ea6867 Mon Sep 17 00:00:00 2001 From: arshPratap Date: Fri, 1 Sep 2023 13:07:03 +0530 Subject: [PATCH] Added glu activation to Paddle frontend --- .../paddle/nn/functional/activation.py | 7 +++++ .../test_functional/test_activation.py | 28 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/ivy/functional/frontends/paddle/nn/functional/activation.py b/ivy/functional/frontends/paddle/nn/functional/activation.py index cbff12b6b5197..aae1f526a4d7e 100644 --- a/ivy/functional/frontends/paddle/nn/functional/activation.py +++ b/ivy/functional/frontends/paddle/nn/functional/activation.py @@ -40,6 +40,13 @@ def gelu(x, approximate=False, name=None): return ivy.gelu(x, approximate=approximate) +@with_supported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def glu(x, axis=-1, name=None): + a, b = ivy.split(x, num_or_size_splits=2, axis=axis) + return ivy.multiply(a, ivy.sigmoid(b)) + + @with_supported_dtypes({"2.4.2 and below": ("float32", "float64")}, "paddle") @to_ivy_arrays_and_back def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py index a1b17a2d3dbf2..38e3e86ea47de 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_nn/test_functional/test_activation.py @@ -128,6 +128,34 @@ def test_paddle_gelu( ) +# glu +@handle_frontend_test( + fn_tree="paddle.nn.functional.glu", + dtype_and_input=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_paddle_glu( + *, + dtype_and_input, + on_device, + backend_fw, + fn_tree, + frontend, + test_flags, +): + input_dtype, x = dtype_and_input + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + ) + + # gumbel_softmax @handle_frontend_test( fn_tree="paddle.nn.functional.gumbel_softmax",