Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Driver handling in svdvals function in torch_frontend #23718

Merged
merged 26 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
699e80b
handling driver of svdvals in torch
AhmedHossam23 Sep 14, 2023
337870c
handling the driver and solving the formating issue
AhmedHossam23 Sep 14, 2023
8137331
Handling driver and solving formating issue
AhmedHossam23 Sep 15, 2023
da90906
added the test and simplified the implementation
AhmedHossam23 Sep 23, 2023
791b7e8
solving the formating issue
AhmedHossam23 Sep 23, 2023
3841366
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Sep 26, 2023
a4a8a41
added the test to test_linalg
AhmedHossam23 Sep 26, 2023
f8eff89
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 2, 2023
997c5c6
added the driver to the backend
AhmedHossam23 Oct 4, 2023
5df4c6e
🤖 Lint code
ivy-branch Oct 4, 2023
17a6f66
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 4, 2023
fb6bdcd
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 8, 2023
0749056
reformatted
AhmedHossam23 Oct 8, 2023
e891edd
test_array_api
AhmedHossam23 Oct 8, 2023
ec4c9c3
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 8, 2023
ea0e331
Merge branch 'main' into hossam_branch
AhmedHossam23 Oct 9, 2023
d55430d
🤖 Lint code
ivy-branch Oct 9, 2023
a9e575a
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 16, 2023
adb7cc5
added driver argument and to do comment
AhmedHossam23 Oct 16, 2023
71027ba
Merge branch 'hossam_branch' of https://github.com/AhmedHossam23/ivy …
AhmedHossam23 Oct 16, 2023
ffc36e7
Merge branch 'main' into hossam_branch
AhmedHossam23 Oct 20, 2023
df5368b
🤖 Lint code
ivy-branch Oct 20, 2023
b31bb95
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 20, 2023
468b869
added driver arg in right way
AhmedHossam23 Oct 20, 2023
b47923e
updates
AhmedHossam23 Oct 20, 2023
c0fa649
🤖 Lint code
ivy-branch Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ivy/functional/backends/jax/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ def svd(
{"0.4.19 and below": ("bfloat16", "float16", "complex")},
backend_version,
)
def svdvals(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
def svdvals(
    x: JaxArray, /, *, driver: Optional[str] = None, out: Optional[JaxArray] = None
) -> JaxArray:
    """Return the singular values of ``x`` (no U/V factors computed).

    ``driver`` is accepted for cross-backend API parity; jax.numpy has no
    cuSOLVER driver selector, so it is currently ignored.
    ``out`` is ignored by this backend implementation.
    """
    # TODO: handle the driver argument
    return jnp.linalg.svd(x, compute_uv=False)


Expand Down
2 changes: 2 additions & 0 deletions ivy/functional/backends/mxnet/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,10 @@ def svdvals(
x: Union[(None, mx.ndarray.NDArray)],
/,
*,
driver: Optional[str] = None,
out: Optional[Union[(None, mx.ndarray.NDArray)]] = None,
) -> Union[(None, mx.ndarray.NDArray)]:
# TODO: handle the driver argument
raise IvyNotImplementedException()


Expand Down
5 changes: 4 additions & 1 deletion ivy/functional/backends/numpy/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,10 @@ def svd(


@with_unsupported_dtypes({"1.26.1 and below": ("float16",)}, backend_version)
def svdvals(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray:
def svdvals(
    x: np.ndarray, /, *, driver: Optional[str] = None, out: Optional[np.ndarray] = None
) -> np.ndarray:
    """Return the singular values of ``x`` in descending order.

    ``driver`` exists only for signature parity with other backends and is
    currently unused; ``out`` is likewise ignored here.
    """
    # TODO: handle the driver argument
    singular_values = np.linalg.svd(x, compute_uv=False)
    return singular_values


Expand Down
7 changes: 6 additions & 1 deletion ivy/functional/backends/paddle/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,8 +521,13 @@ def svd(
backend_version,
)
def svdvals(
    x: paddle.Tensor,
    /,
    *,
    driver: Optional[str] = None,
    out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
    """Return the singular values of ``x``.

    ``driver`` is accepted for cross-backend API parity and currently unused.
    ``out`` is ignored by this backend implementation.
    """
    # TODO: handle the driver argument
    # paddle_backend.svd returns a (U, S, Vh)-style result; index 1 holds
    # the singular values.
    return paddle_backend.svd(x)[1]


Expand Down
2 changes: 2 additions & 0 deletions ivy/functional/backends/tensorflow/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,10 @@ def svdvals(
x: Union[tf.Tensor, tf.Variable],
/,
*,
driver: Optional[str] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
# TODO: handle the driver argument
ret = tf.linalg.svd(x, compute_uv=False)
return ret

Expand Down
12 changes: 9 additions & 3 deletions ivy/functional/backends/torch/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,15 @@ def svd(
return results(D)


@with_unsupported_dtypes({"2.1.0 and below": ("float16", "bfloat16")}, backend_version)
def svdvals(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.linalg.svdvals(x, out=out)
@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
def svdvals(
    x: torch.Tensor,
    /,
    *,
    driver: Optional[str] = None,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Singular values of ``x`` via ``torch.linalg.svdvals``.

    ``driver`` names the cuSOLVER method to use (CUDA inputs only) and is
    forwarded unchanged; ``out`` uses torch's native ``out=`` support.
    """
    sigma = torch.linalg.svdvals(x, driver=driver, out=out)
    return sigma


svdvals.support_native_out = True
Expand Down
6 changes: 4 additions & 2 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,10 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None):
{"2.1.0 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def svdvals(A, *, driver=None, out=None):
    """Torch-frontend ``linalg.svdvals``: singular values of ``A``.

    ``driver`` must be one of ``None``, ``"gesvd"``, ``"gesvdj"`` or
    ``"gesvda"`` (the cuSOLVER methods torch supports); any other value
    raises ``ValueError`` before reaching the backend.
    """
    # Guard clause: reject unknown drivers up front.
    if driver not in (None, "gesvd", "gesvdj", "gesvda"):
        raise ValueError("Unsupported SVD driver")
    return ivy.svdvals(A, driver=driver, out=out)


@to_ivy_arrays_and_back
Expand Down
11 changes: 9 additions & 2 deletions ivy/functional/ivy/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2264,7 +2264,11 @@ def svd(
@handle_array_function
@handle_device
def svdvals(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
driver: Optional[str] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Return the singular values of a matrix (or a stack of matrices) ``x``.
Expand All @@ -2274,6 +2278,9 @@ def svdvals(
x
input array having shape ``(..., M, N)`` and whose innermost two dimensions form
``MxN`` matrices.
driver
name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs.
Available options are: None, gesvd, gesvdj, and gesvda. Default: None.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Expand Down Expand Up @@ -2387,7 +2394,7 @@ def svdvals(
b: ivy.array([23.16134834, 10.35037804, 4.31025076, 1.35769391])
}
"""
return current_backend(x).svdvals(x, out=out)
return current_backend(x).svdvals(x, driver=driver, out=out)


@handle_exceptions
Expand Down
3 changes: 3 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,10 +1156,12 @@ def test_torch_svd(
@handle_frontend_test(
fn_tree="torch.linalg.svdvals",
dtype_and_x=_get_dtype_and_matrix(batch=True),
driver=st.sampled_from([None, "gesvd", "gesvdj", "gesvda"]),
)
def test_torch_svdvals(
*,
dtype_and_x,
driver,
on_device,
fn_tree,
frontend,
Expand All @@ -1174,6 +1176,7 @@ def test_torch_svdvals(
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
driver=driver,
A=x[0],
)

Expand Down
Loading