Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Driver handling in svdvals function in torch_frontend #23718

Merged
merged 26 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
699e80b
handling driver of svdvals in torch
AhmedHossam23 Sep 14, 2023
337870c
handling the driver and solving the formating issue
AhmedHossam23 Sep 14, 2023
8137331
Handling driver and solving formating issue
AhmedHossam23 Sep 15, 2023
da90906
added the test and simplified the implementation
AhmedHossam23 Sep 23, 2023
791b7e8
solving the formating issue
AhmedHossam23 Sep 23, 2023
3841366
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Sep 26, 2023
a4a8a41
added the test to test_linalg
AhmedHossam23 Sep 26, 2023
f8eff89
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 2, 2023
997c5c6
added the driver to the backend
AhmedHossam23 Oct 4, 2023
5df4c6e
🤖 Lint code
ivy-branch Oct 4, 2023
17a6f66
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 4, 2023
fb6bdcd
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 8, 2023
0749056
reformatted
AhmedHossam23 Oct 8, 2023
e891edd
test_array_api
AhmedHossam23 Oct 8, 2023
ec4c9c3
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 8, 2023
ea0e331
Merge branch 'main' into hossam_branch
AhmedHossam23 Oct 9, 2023
d55430d
🤖 Lint code
ivy-branch Oct 9, 2023
a9e575a
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 16, 2023
adb7cc5
added driver argument and to do comment
AhmedHossam23 Oct 16, 2023
71027ba
Merge branch 'hossam_branch' of https://github.com/AhmedHossam23/ivy …
AhmedHossam23 Oct 16, 2023
ffc36e7
Merge branch 'main' into hossam_branch
AhmedHossam23 Oct 20, 2023
df5368b
🤖 Lint code
ivy-branch Oct 20, 2023
b31bb95
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 20, 2023
468b869
added driver arg in right way
AhmedHossam23 Oct 20, 2023
b47923e
updates
AhmedHossam23 Oct 20, 2023
c0fa649
🤖 Lint code
ivy-branch Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ivy/functional/backends/torch/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,11 @@ def svd(
return results(D)


@with_unsupported_dtypes({"2.1.0 and below": ("float16", "bfloat16")}, backend_version)
def svdvals(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.linalg.svdvals(x, out=out)
@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bfloat16")}, backend_version)
def svdvals(
x: torch.Tensor, /, *, driver: Optional[str], out: Optional[torch.Tensor] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driver: ... = None required

) -> torch.Tensor:
return torch.linalg.svdvals(x, driver=driver, out=out)


svdvals.support_native_out = True
Expand Down
6 changes: 4 additions & 2 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,10 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None):
{"2.1.0 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def svdvals(A, *, driver=None, out=None):
# TODO: add handling for driver
return ivy.svdvals(A, out=out)
if driver in ["gesvd", "gesvdj", "gesvda", None]:
return ivy.svdvals(A, driver=driver, out=out)
else:
raise ValueError("Unsupported SVD driver")


@to_ivy_arrays_and_back
Expand Down
9 changes: 8 additions & 1 deletion ivy/functional/ivy/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2263,7 +2263,11 @@ def svd(
@handle_array_function
@handle_device
def svdvals(
x: Union[ivy.Array, ivy.NativeArray], /, *, out: Optional[ivy.Array] = None
x: Union[ivy.Array, ivy.NativeArray],
/,
*,
driver: Optional[str],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hey! sure, but could you please at least add this argument to every backend + a todo comment so that its usage won't break when a user use this on a backend other than torch? thanks!

PS: the driver argument should have a = None

out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Return the singular values of a matrix (or a stack of matrices) ``x``.
Expand All @@ -2273,6 +2277,9 @@ def svdvals(
x
input array having shape ``(..., M, N)`` and whose innermost two dimensions form
``MxN`` matrices.
driver
optional output array,name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs.
Available options are: None, gesvd, gesvdj, and gesvda.Default: None.
out
optional output array, for writing the result to. It must have a shape that the
inputs broadcast to.
Expand Down
3 changes: 3 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,10 +1156,12 @@ def test_torch_svd(
@handle_frontend_test(
fn_tree="torch.linalg.svdvals",
dtype_and_x=_get_dtype_and_matrix(batch=True),
driver=st.sampled_from([None, "gesvd", "gesvdj", "gesvda"]),
)
def test_torch_svdvals(
*,
dtype_and_x,
driver,
on_device,
fn_tree,
frontend,
Expand All @@ -1174,6 +1176,7 @@ def test_torch_svdvals(
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
driver=driver,
A=x[0],
)

Expand Down
Loading