Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch svd #28770

Open
wants to merge 43 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
f7f499e
fixed the potentially wrong namedtuple definitions in the svd backend…
Jun 20, 2024
3db7f17
try to fix the blas_and_lapack_ops.py.svd with correct output namedtu…
Jun 20, 2024
46d180a
try to fix the blas_and_lapack_ops.py.svd with correct output namedtu…
Jun 20, 2024
dce10a6
replace the unimplemented tensor.mH used to the implemented adjoint, …
Jun 30, 2024
0c13ce6
update test of torch.blas_and_lapack_ops.svd to calculate the validit…
Jul 3, 2024
4d0851d
small fix
Jul 3, 2024
3b3670f
small fix
Jul 3, 2024
71fed6b
updated the test for torch.linalg.svd
Jul 3, 2024
3fdf4dd
find that jax.lax.linalg.svd has a argument "subset_by_index" missing
Jul 3, 2024
c5b5904
fixed the skipping torch svd tests according to suggestion, no longer…
Jul 3, 2024
efe9a5a
tests are partially passing, though for torch backend, "RuntimeError:…
Jul 4, 2024
26aeba0
fix test of numpy.linalg.decomposition.svd as it returns a svd object…
Jul 7, 2024
b9cf1cd
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Jul 14, 2024
3bb5b66
now only torch backend of jax.numpy.linalg.svd is failing due to "Run…
Jul 15, 2024
8b29eb7
all tests for tesnorflow.linalg.svd are passing
Jul 15, 2024
bc30d7d
try to fix the two svd function in torch frontend, now the only probl…
Jul 16, 2024
bed8f77
applied the suggested fix to torch svd tests, they are all passing now
Jul 16, 2024
316986e
make namedtuple definition more simple as suggested
Jul 16, 2024
e0268c6
tried to fix jax.lac.linalg.svd. p.s. there is no implementation of s…
Jul 16, 2024
dcab2c1
fixed jax.numpy.linalg.svd, all tests are passing, but jax.lax.linalg…
Jul 16, 2024
da4a78b
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Jul 18, 2024
ac7a60a
fixing numpy.linalg.decompositions.svd
Jul 18, 2024
9b4c161
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Jul 31, 2024
6c1c39c
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Aug 12, 2024
8e927a4
fixed ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py::te…
Aug 16, 2024
37272a8
fixed ivy_tests/test_ivy/test_frontends/test_numpy/test_linalg/test_d…
Aug 16, 2024
dc90073
Fixing ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops…
Aug 18, 2024
65c902f
fixed ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_lina…
Aug 20, 2024
c92b6be
fixing ivy_tests/test_ivy/test_frontends/test_tensorflow/test_raw_ops…
Aug 20, 2024
c1a6632
changed ivy_tests/test_ivy/test_frontends/test_torch/test_blas_and_la…
Aug 22, 2024
9b75cd1
changed all the torch's test_torch_svd so that complex number inputs …
Aug 22, 2024
5c9de15
try to update torch and tensorflow's svd functions as they somehow re…
Aug 22, 2024
2f76f66
seems like should not use svdvals as it always return a not complex v…
Aug 22, 2024
412e60c
fixed jax's svd to teat for complex input. though only jax.lax.linalg…
Aug 23, 2024
ee58d83
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Aug 28, 2024
9433e91
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Sep 3, 2024
c7d9ddc
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Sep 6, 2024
c89223c
Merge branch 'ivy-llc:main' into fix-torch-frontend-blas_and_lapack_o…
Daniel4078 Sep 28, 2024
99b7266
small update on test_torch.test_tensor.test_torch_svd
Daniel4078 Sep 28, 2024
00f7754
Update test_blas_and_lapack_ops.py
Daniel4078 Sep 28, 2024
49db616
Update test_linalg.py
Daniel4078 Sep 28, 2024
637652e
Update test_linalg.py
Daniel4078 Sep 30, 2024
390b00b
Update test_tensor.py
Daniel4078 Sep 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
try to fix the blas_and_lapack_ops.py.svd with correct output namedtu…
…ple definition and behavior when compute_uv is false
Jin Wang committed Jun 20, 2024
commit 3db7f17dfb0ca3c33bae9b982ee150af56f93bbb
14 changes: 10 additions & 4 deletions ivy/functional/frontends/torch/blas_and_lapack_ops.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
import ivy.functional.frontends.torch as torch_frontend
from collections import namedtuple
from ivy.functional.frontends.torch.func_wrapper import to_ivy_arrays_and_back


@@ -191,11 +192,16 @@ def slogdet(A, *, out=None):

@to_ivy_arrays_and_back
def svd(input, some=True, compute_uv=True, *, out=None):
# TODO: add compute_uv
if some:
ret = ivy.svd(input, full_matrices=False)
# TODO: add handling for driver
ret = ivy.svd(input, full_matrices=not some, compute_uv=compute_uv)
results = namedtuple("svd", ['U', 'S', 'V'])
if compute_uv:
ret = results(ret.U, ret.S, ret.Vh.mH)
else:
ret = ivy.svd(input, full_matrices=True)
shape = input.shape
m = shape[-2]
n = shape[-1]
ret = results(ivy.zeros((m,m)), ret.S, ivy.zeros((n,n))) # TODO: keep the zeros on same device as input
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret
7 changes: 5 additions & 2 deletions ivy/functional/frontends/torch/linalg.py
Original file line number Diff line number Diff line change
@@ -326,8 +326,11 @@ def solve_ex(A, B, *, left=True, check_errors=False, out=None):
{"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch"
)
def svd(A, /, *, full_matrices=True, driver=None, out=None):
# TODO: add handling for driver and out
return ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
# TODO: add handling for driver
ret = ivy.svd(A, compute_uv=True, full_matrices=full_matrices)
if ivy.exists(out):
return ivy.inplace_update(out, ret)
return ret


@to_ivy_arrays_and_back