diff --git a/ivy/functional/frontends/torch/linalg.py b/ivy/functional/frontends/torch/linalg.py index 3df44b0ae3fa7..5f37f7fdf0fcf 100644 --- a/ivy/functional/frontends/torch/linalg.py +++ b/ivy/functional/frontends/torch/linalg.py @@ -16,18 +16,23 @@ def cholesky(input, *, upper=False, out=None): @to_ivy_arrays_and_back +@with_supported_dtypes( + {"2.2 and below": ("float32", "float64", "complex32", "complex64")}, "torch" +) def cholesky_ex(input, *, upper=False, check_errors=False, out=None): try: + results = namedtuple("cholesky_ex", ['L', 'info']) matrix = ivy.cholesky(input, upper=upper, out=out) info = ivy.zeros(input.shape[:-2], dtype=ivy.int32) - return matrix, info + return results(matrix, info) except RuntimeError as e: if check_errors: raise RuntimeError(e) from e else: + results = namedtuple("cholesky_ex", ['L', 'info']) matrix = input * math.nan info = ivy.ones(input.shape[:-2], dtype=ivy.int32) - return matrix, info + return results(matrix, info) @to_ivy_arrays_and_back diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py index 3fe907273c191..3af1cc7f64318 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_linalg.py @@ -336,7 +336,7 @@ def test_torch_cholesky( @handle_frontend_test( fn_tree="torch.linalg.cholesky_ex", - dtype_and_x=_get_dtype_and_matrix(square=True, batch=True), + dtype_and_x=_get_dtype_and_matrix(square=True), upper=st.booleans(), ) def test_torch_cholesky_ex( @@ -350,8 +350,9 @@ def test_torch_cholesky_ex( backend_fw, ): dtype, x = dtype_and_x - x = np.matmul(x.T, x) + np.identity(x.shape[0]) # make symmetric positive-definite - + x = np.asarray(x[0], dtype=dtype[0]) + x = np.matmul(np.conjugate(x.T), x) + np.identity(x.shape[0], dtype=dtype[0]) + # make symmetric positive-definite helpers.test_frontend_function( input_dtypes=dtype, backend_to_test=backend_fw,