From 2fa80953e9e3bcf19eafede00e7490dd0f1cd349 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Thu, 31 Aug 2023 23:25:27 +0100 Subject: [PATCH] test(frontend): add test for get_n_split method for KFold in sklearn frontend --- .../test_model_selection/test_split.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py index d009e7cd5c2e6..0496ee4a18238 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_model_selection/test_split.py @@ -47,6 +47,40 @@ def test_sklearn_kfold_split( ) +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="sklearn.model_selection.KFold", + method_name="get_n_splits", + dtype_x=helpers.dtype_and_values(), +) +def test_sklearn_kfold_get_n_split( + dtype_x, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + init_all_as_kwargs_np={ + "n_splits": 2, # this arg only for compatibility + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={ + "X": x[0], + }, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + backend_to_test=backend_fw, + ) + + @handle_frontend_test( fn_tree="sklearn.model_selection.train_test_split", arrays_and_dtypes=helpers.dtype_and_values(