Skip to content

Commit

Permalink
feat: updated tests for max_pool2d
Browse files Browse the repository at this point in the history
  • Loading branch information
aibenStunner committed Nov 1, 2023
1 parent f220e58 commit 80462ec
Showing 1 changed file with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -278,36 +278,54 @@ def test_paddle_avg_pool2d(

# max_pool2d
@handle_frontend_test(
fn_tree="paddle.nn.functional.max_pool2d",
fn_tree="paddle.nn.functional.pooling.max_pool2d",
dtype_x_k_s=helpers.arrays_for_pooling(
min_dims=4, max_dims=4, min_side=2, max_side=4
),
ceil_mode=st.sampled_from([True]),
data_format=st.sampled_from(["NCHW", "NHWC"]),
x_k_s_p=helpers.arrays_for_pooling(min_dims=4, max_dims=4, min_side=1, max_side=4),
stride=st.tuples(st.integers(1, 2), st.integers(1, 2)),
test_with_out=st.just(False),
)
def test_paddle_max_pool2d(
*,
x_k_s_p,
stride,
dtype_x_k_s,
ceil_mode,
data_format,
frontend,
*,
test_flags,
fn_tree,
backend_fw,
frontend,
fn_tree,
on_device,
):
input_dtype, x, kernel_size, _, padding = x_k_s_p
data_format = data_format
input_dtype, x, kernel, stride, padding = dtype_x_k_s

if data_format == "NCHW":
x[0] = x[0].reshape(
(x[0].shape[0], x[0].shape[3], x[0].shape[1], x[0].shape[2])
)
if len(stride) == 1:
stride = (stride[0], stride[0])
if padding == "SAME":
padding = test_pooling_functions.calculate_same_padding(
kernel, stride, x[0].shape[2:]
)
else:
padding = (0, 0)

if padding == "VALID" and ceil_mode:
ceil_mode = False

helpers.test_frontend_function(
input_dtypes=input_dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x=x[0],
kernel_size=kernel_size,
kernel_size=kernel,
stride=stride,
padding=padding,
ceil_mode=ceil_mode,
data_format=data_format,
)

Expand Down

0 comments on commit 80462ec

Please sign in to comment.