Skip to content

Commit

Permalink
refactor(frontend): use column_or_1d sklearn front function in Stratf…
Browse files Browse the repository at this point in the history
…iedKFold split method
  • Loading branch information
Ishticode committed Sep 3, 2023
1 parent b8ca4a1 commit 5a46b49
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
5 changes: 2 additions & 3 deletions ivy/functional/frontends/sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
import ivy
from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back
from ivy.functional.frontends.sklearn.utils.validation import column_or_1d


class BaseCrossValidator(metaclass=ABCMeta):
Expand Down Expand Up @@ -75,9 +76,7 @@ def __init__(
def _iter_test_indices(self, X=None, y=None, groups=None):
ivy.seed(self.random_state)
y = ivy.array(y)
shape = y.shape
if len(shape) == 2 or shape[1] == 1:
y = ivy.reshape(y, (-1,))
y = column_or_1d(y)
_, y_idx, y_inv, _ = ivy.unique_all(y, return_index=True, return_inverse=True)
class_perm = ivy.unique_inverse(y_idx)
y_encoded = class_perm[y_inv]
Expand Down
2 changes: 1 addition & 1 deletion ivy/functional/frontends/sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def column_or_1d(y, *, warn=False):
elif len(shape) > 2:
raise ValueError(
"y should be a 1d array or a column vector")
return y
return y

0 comments on commit 5a46b49

Please sign in to comment.