diff --git a/ivy/functional/frontends/sklearn/utils/__init__.py b/ivy/functional/frontends/sklearn/utils/__init__.py index 340d5c6aca479..8dbc0d55e215c 100644 --- a/ivy/functional/frontends/sklearn/utils/__init__.py +++ b/ivy/functional/frontends/sklearn/utils/__init__.py @@ -1,2 +1,4 @@ from . import multiclass from .multiclass import * +from . import validation +from .validation import * diff --git a/ivy/functional/frontends/sklearn/utils/validation.py b/ivy/functional/frontends/sklearn/utils/validation.py new file mode 100644 index 0000000000000..5c8e1f93d12e8 --- /dev/null +++ b/ivy/functional/frontends/sklearn/utils/validation.py @@ -0,0 +1,13 @@ +import ivy +from ivy.functional.frontends.numpy.func_wrapper import to_ivy_arrays_and_back + + +@to_ivy_arrays_and_back +def as_float_array(X, *, copy=True, force_all_finite=True): + if X.dtype in [ivy.float32, ivy.float64]: + return X.copy_array() if copy else X + if ("bool" in X.dtype or "int" in X.dtype or "uint" in X.dtype) and ivy.itemsize(X) <= 4: + return_dtype = ivy.float32 + else: + return_dtype = ivy.float64 + return ivy.asarray(X, dtype=return_dtype) diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_utils/test_validation.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_utils/test_validation.py new file mode 100644 index 0000000000000..272c866bc732f --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_utils/test_validation.py @@ -0,0 +1,28 @@ +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +@handle_frontend_test( + fn_tree="sklearn.utils.as_float_array", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_sklearn_as_float_array( + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + X=x[0], + ) \ No newline at end of file