diff --git a/ivy/functional/frontends/sklearn/model_selection/_split.py b/ivy/functional/frontends/sklearn/model_selection/_split.py index f41ebc39265e0..25b51c3451fe6 100644 --- a/ivy/functional/frontends/sklearn/model_selection/_split.py +++ b/ivy/functional/frontends/sklearn/model_selection/_split.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod import ivy from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back +from ivy.functional.frontends.sklearn.utils.validation import column_or_1d class BaseCrossValidator(metaclass=ABCMeta): @@ -75,9 +76,7 @@ def __init__( def _iter_test_indices(self, X=None, y=None, groups=None): ivy.seed(self.random_state) y = ivy.array(y) - shape = y.shape - if len(shape) == 2 or shape[1] == 1: - y = ivy.reshape(y, (-1,)) + y = column_or_1d(y) _, y_idx, y_inv, _ = ivy.unique_all(y, return_index=True, return_inverse=True) class_perm = ivy.unique_inverse(y_idx) y_encoded = class_perm[y_inv] diff --git a/ivy/functional/frontends/sklearn/utils/validation.py b/ivy/functional/frontends/sklearn/utils/validation.py index c70d52d288c89..6250aae0ca53c 100644 --- a/ivy/functional/frontends/sklearn/utils/validation.py +++ b/ivy/functional/frontends/sklearn/utils/validation.py @@ -25,4 +25,4 @@ def column_or_1d(y, *, warn=False): elif len(shape) > 2: raise ValueError( "y should be a 1d array or a column vector") - return y \ No newline at end of file + return y