Skip to content

Commit

Permalink
feat(frontend): implement type_of_target utility function in sklearn …
Browse files Browse the repository at this point in the history
…frontend utils along with example based test. This function has repeated use in implementations
  • Loading branch information
Ishticode committed Aug 30, 2023
1 parent 7a048c1 commit 13d99f9
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ivy/functional/frontends/sklearn/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import multiclass
from .multiclass import *
17 changes: 17 additions & 0 deletions ivy/functional/frontends/sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import ivy


# reapeated utility function
def type_of_target(y, input_name='y'):
# purely utility function
# TODO: implement multilabel-indicator, ...-multioutput, unknown
if y.ndim not in (1, 2):
return "unknown"
if ivy.is_float_dtype(y) and ivy.any(ivy.not_equal(y, y.astype('int64'))):
return "continuous"
else:
vals = ivy.unique_values(y)
if len(vals) > 2:
return "multiclass"
else:
return "binary"
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import ivy
from ivy.functional.frontends.sklearn.utils.multiclass import type_of_target


# not suitable for usual frontend testing
@pytest.mark.parametrize("y, label", [([1.2], "continuous"),
([1], "binary"),
([1, 2], "binary"),
([1, 2, 3], "multiclass"),
([1, 2, 3, 4], "multiclass"),
([1, 2, 3, 4, 5], "multiclass"),
([1, 2, 2], "binary"),
([1, 2., 2, 3], "multiclass"),
([1., 2., 2.], "binary")])
def test_sklearn_type_of_target(y, label):
assert type_of_target(ivy.array(y)) == label

0 comments on commit 13d99f9

Please sign in to comment.