From 9a362357032b35527e68de34317606b6989d96d5 Mon Sep 17 00:00:00 2001 From: Aryan <78106056+AryanSharma21@users.noreply.github.com> Date: Sat, 16 Sep 2023 05:22:47 +0530 Subject: [PATCH] Added count_nonzero to paddle_frontend (#23282) Co-authored-by: Aryan Sharma Co-authored-by: danielmunioz <47380745+danielmunioz@users.noreply.github.com> --- ivy/functional/frontends/paddle/math.py | 9 ++++++ .../test_frontends/test_paddle/test_math.py | 32 +++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/ivy/functional/frontends/paddle/math.py b/ivy/functional/frontends/paddle/math.py index c4d3a0130f79b..249f3f3ada016 100644 --- a/ivy/functional/frontends/paddle/math.py +++ b/ivy/functional/frontends/paddle/math.py @@ -134,6 +134,15 @@ def cosh(x, name=None): return ivy.cosh(x) +@with_supported_dtypes( + {"2.5.1 and below": ("int32", "int64", "float16", "float32", "float64", "bool")}, + "paddle", +) +@to_ivy_arrays_and_back +def count_nonzero(x, axis=None, keepdim=False, name=None): + return ivy.astype(ivy.count_nonzero(x, axis=axis, keepdims=keepdim), ivy.int64) + + @with_supported_dtypes( { "2.5.1 and below": ( diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py index 926836080a9b3..7abed35d764fd 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_math.py @@ -586,6 +586,38 @@ def test_paddle_cosh( ) +# count_nonzero +@handle_frontend_test( + fn_tree="paddle.count_nonzero", + dtype_and_x=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes(kind="integer"), + valid_axis=True, + force_int_axis=True, + min_num_dims=1, + ), +) +def test_paddle_count_nonzero( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + on_device=on_device, + fn_tree=fn_tree, + test_flags=test_flags, + frontend=frontend, + x=x[0], + axis=axis, + ) + + # cumprod @handle_frontend_test( fn_tree="paddle.cumprod",