diff --git a/ivy/functional/frontends/sklearn/utils/__init__.py b/ivy/functional/frontends/sklearn/utils/__init__.py index 340d5c6aca479..ee900dd63a467 100644 --- a/ivy/functional/frontends/sklearn/utils/__init__.py +++ b/ivy/functional/frontends/sklearn/utils/__init__.py @@ -1,2 +1,4 @@ from . import multiclass from .multiclass import * +from . import validation +from .validation import * \ No newline at end of file diff --git a/ivy/functional/frontends/sklearn/utils/validation.py b/ivy/functional/frontends/sklearn/utils/validation.py new file mode 100644 index 0000000000000..0285d6b667d18 --- /dev/null +++ b/ivy/functional/frontends/sklearn/utils/validation.py @@ -0,0 +1,8 @@ +import ivy + +def as_float_array(X, *, copy=True, force_all_finite=True): + if ("bool" in str(X.dtype) or "int" in str(X.dtype) or "uint" in str(X.dtype)) and ivy.itemsize(X) <= 4: + return_dtype = ivy.float32 + else: + return_dtype = ivy.float64 + return ivy.asarray(X, dtype=return_dtype) \ No newline at end of file diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_utils/test_validation.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_utils/test_validation.py new file mode 100644 index 0000000000000..272c866bc732f --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_utils/test_validation.py @@ -0,0 +1,28 @@ +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +@handle_frontend_test( + fn_tree="sklearn.utils.as_float_array", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + ), +) +def test_sklearn_as_float_array( + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + X=x[0], + ) \ No newline at end of file