diff --git a/ivy/functional/frontends/mindspore/ops/function/nn_func.py b/ivy/functional/frontends/mindspore/ops/function/nn_func.py index 1b730bf0b6e86..329c1808e1fa3 100644 --- a/ivy/functional/frontends/mindspore/ops/function/nn_func.py +++ b/ivy/functional/frontends/mindspore/ops/function/nn_func.py @@ -403,6 +403,52 @@ def log_softmax(input, axis=-1): return ivy.log_softmax(input) +@with_supported_dtypes( + { + "2.0.0 and below": ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + ) + }, + "mindspore", +) +@to_ivy_arrays_and_back +def max_pool3d( + input, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, + return_indices=False, +): + # ToDo: Add return_indices once superset in implemented + + if not stride: + stride = kernel_size + + data_format = "NCDHW" + + return ivy.max_pool3d( + input, + kernel_size, + stride, + padding, + data_format=data_format, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + @with_supported_dtypes( { "2.0 and below": ( diff --git a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py index fe825eda7098f..17919064ef663 100644 --- a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py +++ b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py @@ -665,6 +665,52 @@ def test_mindspore_log_softmax( # ) +# max_pool3d +@pytest.mark.skip("Testing pipeline not yet implemented") +@handle_frontend_test( + fn_tree="mindspore.ops.function.nn_func.max_pool3d", + x_k_s_p=helpers.arrays_for_pooling( + min_dims=5, + max_dims=5, + min_side=1, + max_side=4, + only_explicit_padding=True, + return_dilation=True, + data_format="channel_first", + ), + test_with_out=st.just(False), + ceil_mode=st.sampled_from([True, False]), +) +def test_mindspore_max_pool3d( + x_k_s_p, + ceil_mode, + *, + test_flags, + frontend, + backend_fw, + fn_tree, + on_device, +): + input_dtypes, x, kernel_size, stride, padding, dilation = x_k_s_p + + padding = (padding[0][0], padding[1][0], padding[2][0]) + + helpers.test_frontend_function( + input_dtypes=input_dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + frontend=frontend, + fn_tree=fn_tree, + on_device=on_device, + input=x[0], + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + # pad @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test(