diff --git a/ivy/functional/frontends/sklearn/covariance/__init__.py b/ivy/functional/frontends/sklearn/covariance/__init__.py new file mode 100644 index 0000000000000..b8e859ea75ffa --- /dev/null +++ b/ivy/functional/frontends/sklearn/covariance/__init__.py @@ -0,0 +1,2 @@ +from . import _empirical_covariance +from ._empirical_covariance import * \ No newline at end of file diff --git a/ivy/functional/frontends/sklearn/covariance/_empirical_covariance.py b/ivy/functional/frontends/sklearn/covariance/_empirical_covariance.py new file mode 100644 index 0000000000000..fe67edde5183a --- /dev/null +++ b/ivy/functional/frontends/sklearn/covariance/_empirical_covariance.py @@ -0,0 +1,20 @@ +import ivy +from ivy.functional.frontends.numpy import to_ivy_arrays_and_back + +@to_ivy_arrays_and_back +def empirical_covariance(X, *, assume_centered = False): + if X.ndim == 1: + X = ivy.reshape(X, (1, -1)) + + if assume_centered: + covariance = ivy.dot(X.T, X) / X.shape[0] + else: + covariance = ivy.cov(X.T, bias = 1) + + if covariance.ndim == 0: + covariance = ivy.array([[covariance]]) + + if ivy.is_complex_dtype(X): + return covariance.astype(ivy.complex128) + + return covariance.astype(ivy.float64) diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_covariance/test_empirical_covariance.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_covariance/test_empirical_covariance.py new file mode 100644 index 0000000000000..b7d56862c7c51 --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_covariance/test_empirical_covariance.py @@ -0,0 +1,35 @@ +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test +from hypothesis import strategies as st + +@handle_frontend_test( + fn_tree="sklearn.covariance.empirical_covariance", + dtype_and_x=helpers.dtype_and_values( + available_dtypes= helpers.get_dtypes("valid"), + min_num_dims=1, + max_num_dims=2, + min_dim_size=1, + max_dim_size=3, + ), + assume_centered = st.booleans() +) +def test_sklearn_empirical_covariance( + dtype_and_x, + on_device, + assume_centered, + fn_tree, + frontend, + test_flags, + backend_fw, +): + dtypes, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=dtypes, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + frontend=frontend, + on_device=on_device, + X=x[0], + assume_centered=assume_centered + ) \ No newline at end of file