diff --git a/ivy/functional/backends/jax/linear_algebra.py b/ivy/functional/backends/jax/linear_algebra.py
index 815b53a14584c..dea93a65efc0b 100644
--- a/ivy/functional/backends/jax/linear_algebra.py
+++ b/ivy/functional/backends/jax/linear_algebra.py
@@ -371,7 +371,10 @@ def svd(
     {"0.4.19 and below": ("bfloat16", "float16", "complex")},
     backend_version,
 )
-def svdvals(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
+def svdvals(
+    x: JaxArray, /, *, driver: Optional[str] = None, out: Optional[JaxArray] = None
+) -> JaxArray:
+    # TODO: handle the driver argument
     return jnp.linalg.svd(x, compute_uv=False)
 
 
diff --git a/ivy/functional/backends/mxnet/linear_algebra.py b/ivy/functional/backends/mxnet/linear_algebra.py
index 0ae1d07490e42..e7717406c1a23 100644
--- a/ivy/functional/backends/mxnet/linear_algebra.py
+++ b/ivy/functional/backends/mxnet/linear_algebra.py
@@ -228,8 +228,10 @@ def svdvals(
     x: Union[(None, mx.ndarray.NDArray)],
     /,
     *,
+    driver: Optional[str] = None,
     out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
 ) -> Union[(None, mx.ndarray.NDArray)]:
+    # TODO: handle the driver argument
     raise IvyNotImplementedException()
 
 
diff --git a/ivy/functional/backends/numpy/linear_algebra.py b/ivy/functional/backends/numpy/linear_algebra.py
index 90b7958df40fb..26257efbba19f 100644
--- a/ivy/functional/backends/numpy/linear_algebra.py
+++ b/ivy/functional/backends/numpy/linear_algebra.py
@@ -298,7 +298,10 @@ def svd(
 
 
 @with_unsupported_dtypes({"1.26.1 and below": ("float16",)}, backend_version)
-def svdvals(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
+def svdvals(
+    x: np.ndarray, /, *, driver: Optional[str] = None, out: Optional[np.ndarray] = None
+) -> np.ndarray:
+    # TODO: handle the driver argument
     return np.linalg.svd(x, compute_uv=False)
 
 
diff --git a/ivy/functional/backends/paddle/linear_algebra.py b/ivy/functional/backends/paddle/linear_algebra.py
index 78a553d32c1d2..3e7d5285849f2 100644
--- a/ivy/functional/backends/paddle/linear_algebra.py
+++ b/ivy/functional/backends/paddle/linear_algebra.py
@@ -521,8 +521,13 @@ def svd(
     backend_version,
 )
 def svdvals(
-    x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None
+    x: paddle.Tensor,
+    /,
+    *,
+    driver: Optional[str] = None,
+    out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    # TODO: handle the driver argument
     return paddle_backend.svd(x)[1]
 
 
diff --git a/ivy/functional/backends/tensorflow/linear_algebra.py b/ivy/functional/backends/tensorflow/linear_algebra.py
index 2a1828c876bc2..10bf64c766fba 100644
--- a/ivy/functional/backends/tensorflow/linear_algebra.py
+++ b/ivy/functional/backends/tensorflow/linear_algebra.py
@@ -564,8 +564,10 @@ def svdvals(
     x: Union[tf.Tensor, tf.Variable],
     /,
     *,
+    driver: Optional[str] = None,
     out: Optional[Union[tf.Tensor, tf.Variable]] = None,
 ) -> Union[tf.Tensor, tf.Variable]:
+    # TODO: handle the driver argument
     ret = tf.linalg.svd(x, compute_uv=False)
     return ret
 
diff --git a/ivy/functional/backends/torch/linear_algebra.py b/ivy/functional/backends/torch/linear_algebra.py
index c8614f3cfee49..d01aad788a9a8 100644
--- a/ivy/functional/backends/torch/linear_algebra.py
+++ b/ivy/functional/backends/torch/linear_algebra.py
@@ -414,9 +414,15 @@ def svd(
     return results(D)
 
 
 @with_unsupported_dtypes({"2.1.0 and below": ("float16", "bfloat16")}, backend_version)
-def svdvals(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return torch.linalg.svdvals(x, out=out)
("float16", "bfloat16")}, backend_version) +def svdvals( + x: torch.Tensor, + /, + *, + driver: Optional[str] = None, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.linalg.svdvals(x, driver=driver, out=out) svdvals.support_native_out = True diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 437ffb764c4e6..57cbfaae05107 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -314,8 +314,10 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None): {"2.1.0 and below": ("float32", "float64", "complex32", "complex64")}, "torch" ) def svdvals(A, *, driver=None, out=None): - # TODO: add handling for driver - return ivy.svdvals(A, out=out) + if driver in ["gesvd", "gesvdj", "gesvda", None]: + return ivy.svdvals(A, driver=driver, out=out) + else: + raise ValueError("Unsupported SVD driver") @to_ivy_arrays_and_back diff --git a/ivy/functional/ivy/linear_algebra.py b/ivy/functional/ivy/linear_algebra.py index 2387e7e8ccbab..7e580378e5df4 100644 --- a/ivy/functional/ivy/linear_algebra.py +++ b/ivy/functional/ivy/linear_algebra.py @@ -2264,7 +2264,11 @@ def svd( @handle_array_function @handle_device def svdvals( - x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None + x: Union[ivy.Array, ivy.NativeArray], + /, + *, + driver: Optional[str] = None, + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Return the singular values of a matrix (or a stack of matrices) ``x``. @@ -2274,6 +2278,9 @@ def svdvals( x input array having shape ``(..., M, N)`` and whose innermost two dimensions form ``MxN`` matrices. + driver + optional output array,name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. + Available options are: None, gesvd, gesvdj, and gesvda.Default: None. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. @@ -2387,7 +2394,7 @@ def svdvals( b: ivy.array([23.16134834, 10.35037804, 4.31025076, 1.35769391]) } """ - return current_backend(x).svdvals(x, out=out) + return current_backend(x).svdvals(x, driver=driver, out=out) @handle_exceptions diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 45c624d17c354..590ee6c770cfc 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -1156,10 +1156,12 @@ def test_torch_svd( @handle_frontend_test( fn_tree="torch.linalg.svdvals", dtype_and_x=_get_dtype_and_matrix(batch=True), + driver=st.sampled_from([None, "gesvd", "gesvdj", "gesvda"]), ) def test_torch_svdvals( *, dtype_and_x, + driver, on_device, fn_tree, frontend, @@ -1174,6 +1176,7 @@ def test_torch_svdvals( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, + driver=driver, A=x[0], )