diff --git a/ivy/functional/frontends/sklearn/datasets/_samples_generator.py b/ivy/functional/frontends/sklearn/datasets/_samples_generator.py index 865f097588ef1..14dc07f70dd39 100644 --- a/ivy/functional/frontends/sklearn/datasets/_samples_generator.py +++ b/ivy/functional/frontends/sklearn/datasets/_samples_generator.py @@ -39,3 +39,31 @@ def make_circles( axis=0, ) return X, y + + +def make_moons(n_samples=100, *, shuffle=True, noise=None, random_state=None): + if isinstance(n_samples, numbers.Integral): + n_samples_out = n_samples // 2 + n_samples_in = n_samples - n_samples_out + elif isinstance(n_samples, tuple): + n_samples_out, n_samples_in = n_samples + + + outer_circ_x = ivy.cos(ivy.linspace(0, ivy.pi, n_samples_out)) + outer_circ_y = ivy.sin(ivy.linspace(0, ivy.pi, n_samples_out)) + inner_circ_x = 1 - ivy.cos(ivy.linspace(0, ivy.pi, n_samples_in)) + inner_circ_y = 1 - ivy.sin(ivy.linspace(0, ivy.pi, n_samples_in)) - 0.5 + + + X = ivy.concat([ + ivy.stack([outer_circ_x, outer_circ_y], axis=1), + ivy.stack([inner_circ_x, inner_circ_y], axis=1), + ], axis=0, + ) + y = ivy.concat([ + ivy.zeros(n_samples_out, dtype=ivy.int32), + ivy.ones(n_samples_in, dtype=ivy.int32), + ], axis=0, + ) + + return X, y \ No newline at end of file diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_datasets/test_samples_generators.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_datasets/test_samples_generators.py index d19b4fe13090b..ee8624b5d9be5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_datasets/test_samples_generators.py +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_datasets/test_samples_generators.py @@ -24,3 +24,27 @@ def test_sklearn_make_circles( on_device=on_device, test_values=False, ) + + +@handle_frontend_test( + fn_tree="sklearn.datasets.make_moons", + n_samples=helpers.ints(min_value=1, max_value=5), +) +def test_sklearn_make_moons( + n_samples, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + helpers.test_frontend_function( + n_samples=n_samples, + input_dtypes=["int32"], + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + test_values=False, + )