Skip to content

Commit

Permalink
Add max_pool3d for MindSpore frontend (#21780)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasb0 authored Sep 1, 2023
1 parent cde6fff commit 1eb64d5
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
46 changes: 46 additions & 0 deletions ivy/functional/frontends/mindspore/ops/function/nn_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,52 @@ def log_softmax(input, axis=-1):
return ivy.log_softmax(input)


@with_supported_dtypes(
{
"2.0.0 and below": (
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
)
},
"mindspore",
)
@to_ivy_arrays_and_back
def max_pool3d(
input,
kernel_size,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False,
):
# ToDo: Add return_indices once superset in implemented

if not stride:
stride = kernel_size

data_format = "NCDHW"

return ivy.max_pool3d(
input,
kernel_size,
stride,
padding,
data_format=data_format,
dilation=dilation,
ceil_mode=ceil_mode,
)


@with_supported_dtypes(
{
"2.0 and below": (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,52 @@ def test_mindspore_log_softmax(
# )


# max_pool3d
@pytest.mark.skip("Testing pipeline not yet implemented")
@handle_frontend_test(
fn_tree="mindspore.ops.function.nn_func.max_pool3d",
x_k_s_p=helpers.arrays_for_pooling(
min_dims=5,
max_dims=5,
min_side=1,
max_side=4,
only_explicit_padding=True,
return_dilation=True,
data_format="channel_first",
),
test_with_out=st.just(False),
ceil_mode=st.sampled_from([True, False]),
)
def test_mindspore_max_pool3d(
x_k_s_p,
ceil_mode,
*,
test_flags,
frontend,
backend_fw,
fn_tree,
on_device,
):
input_dtypes, x, kernel_size, stride, padding, dilation = x_k_s_p

padding = (padding[0][0], padding[1][0], padding[2][0])

helpers.test_frontend_function(
input_dtypes=input_dtypes,
backend_to_test=backend_fw,
test_flags=test_flags,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
input=x[0],
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
ceil_mode=ceil_mode,
)


# pad
@pytest.mark.skip("Testing pipeline not yet implemented")
@handle_frontend_test(
Expand Down

0 comments on commit 1eb64d5

Please sign in to comment.