diff --git a/ivy/functional/backends/jax/linear_algebra.py b/ivy/functional/backends/jax/linear_algebra.py index 383a9d032abc6..69089814de7cc 100644 --- a/ivy/functional/backends/jax/linear_algebra.py +++ b/ivy/functional/backends/jax/linear_algebra.py @@ -113,7 +113,7 @@ def eigh( result_tuple = NamedTuple( "eigh", [("eigenvalues", JaxArray), ("eigenvectors", JaxArray)] ) - eigenvalues, eigenvectors = jnp.linalg.eigh(x, UPLO=UPLO) + eigenvalues, eigenvectors = jnp.linalg.eigh(x, UPLO=UPLO, symmetrize_input=False) return result_tuple(eigenvalues, eigenvectors)