Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Driver handling in svdvals function in torch_frontend #23718

Merged
merged 26 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
699e80b
handling driver of svdvals in torch
AhmedHossam23 Sep 14, 2023
337870c
handling the driver and solving the formating issue
AhmedHossam23 Sep 14, 2023
8137331
Handling driver and solving formating issue
AhmedHossam23 Sep 15, 2023
da90906
added the test and simplified the implementation
AhmedHossam23 Sep 23, 2023
791b7e8
solving the formating issue
AhmedHossam23 Sep 23, 2023
3841366
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Sep 26, 2023
a4a8a41
added the test to test_linalg
AhmedHossam23 Sep 26, 2023
f8eff89
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 2, 2023
997c5c6
added the driver to the backend
AhmedHossam23 Oct 4, 2023
5df4c6e
🤖 Lint code
ivy-branch Oct 4, 2023
17a6f66
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 4, 2023
fb6bdcd
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 8, 2023
0749056
reformatted
AhmedHossam23 Oct 8, 2023
e891edd
test_array_api
AhmedHossam23 Oct 8, 2023
ec4c9c3
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 8, 2023
ea0e331
Merge branch 'main' into hossam_branch
AhmedHossam23 Oct 9, 2023
d55430d
🤖 Lint code
ivy-branch Oct 9, 2023
a9e575a
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 16, 2023
adb7cc5
added driver argument and to do comment
AhmedHossam23 Oct 16, 2023
71027ba
Merge branch 'hossam_branch' of https://github.com/AhmedHossam23/ivy …
AhmedHossam23 Oct 16, 2023
ffc36e7
Merge branch 'main' into hossam_branch
AhmedHossam23 Oct 20, 2023
df5368b
🤖 Lint code
ivy-branch Oct 20, 2023
b31bb95
Merge branch 'unifyai:main' into hossam_branch
AhmedHossam23 Oct 20, 2023
468b869
added driver arg in right way
AhmedHossam23 Oct 20, 2023
b47923e
updates
AhmedHossam23 Oct 20, 2023
c0fa649
🤖 Lint code
ivy-branch Oct 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# local
import math
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey! The native torch shouldn't be imported in the frontends, if ivy.svdvals doesn't support this argument yet, you should first implement the new argument in the backend, add the related argument in the tests, and then use the ivy function in this frontend, thanks!

import ivy
import ivy.functional.frontends.torch as torch_frontend
from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back
Expand Down Expand Up @@ -315,7 +316,10 @@ def svd(A, /, *, full_matrices=True, driver=None, out=None):
)
def svdvals(A, *, driver=None, out=None):
# TODO: add handling for driver
return ivy.svdvals(A, out=out)
if driver in ["gesvd", "gesvdj", "gesvda", None]:
return torch.linalg.svdvals(A, driver=driver, out=out)
else:
raise ValueError("Unsupported SVD driver")


@to_ivy_arrays_and_back
Expand Down
2 changes: 1 addition & 1 deletion ivy_tests/array_api_testing/test_array_api
3 changes: 3 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,10 +1156,12 @@ def test_torch_svd(
@handle_frontend_test(
fn_tree="torch.linalg.svdvals",
dtype_and_x=_get_dtype_and_matrix(batch=True),
driver=st.sampled_from([None, "gesvd", "gesvdj", "gesvda"]),
)
def test_torch_svdvals(
*,
dtype_and_x,
driver,
on_device,
fn_tree,
frontend,
Expand All @@ -1174,6 +1176,7 @@ def test_torch_svdvals(
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
driver=driver,
A=x[0],
)

Expand Down
Loading