From 8dc7b29181ed608358845fef51da71c7f992a0ed Mon Sep 17 00:00:00 2001
From: Christian Green <christian.j.green@gmail.com>
Date: Mon, 13 Jan 2025 10:25:52 -0800
Subject: [PATCH] Blocked Jacobi method for eigen decomposition (#1510)

Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
---
 nx/lib/nx/binary_backend.ex                   |  19 --
 nx/lib/nx/binary_backend/matrix.ex            | 272 ----------------
 nx/lib/nx/lin_alg.ex                          |  44 +--
 nx/lib/nx/lin_alg/block_eigh.ex               | 304 ++++++++++++++++++
 nx/test/nx/defn/grad_test.exs                 |  14 +-
 nx/test/nx/lin_alg_test.exs                   | 114 +++----
 torchx/mix.exs                                |   4 +-
 torchx/test/torchx/nx_linalg_doctest_test.exs |   2 +-
 8 files changed, 396 insertions(+), 377 deletions(-)
 create mode 100644 nx/lib/nx/lin_alg/block_eigh.ex

diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex
index e3ae4121d6f..e2795d6071c 100644
--- a/nx/lib/nx/binary_backend.ex
+++ b/nx/lib/nx/binary_backend.ex
@@ -1240,25 +1240,6 @@ defmodule Nx.BinaryBackend do
     output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
   end
 
-  @impl true
-  def eigh(
-        {%{type: output_type} = eigenvals_holder, eigenvecs_holder},
-        %{type: input_type, shape: input_shape} = tensor,
-        opts
-      ) do
-    bin = to_binary(tensor)
-    rank = tuple_size(input_shape)
-    n = elem(input_shape, rank - 1)
-
-    {eigenvals, eigenvecs} =
-      bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, fn matrix, {vals_acc, vecs_acc} ->
-        {vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
-        {vals_acc <> vals, vecs_acc <> vecs}
-      end)
-
-    {from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
-  end
-
   @impl true
   def lu(
         {%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
diff --git a/nx/lib/nx/binary_backend/matrix.ex b/nx/lib/nx/binary_backend/matrix.ex
index 85601295b97..afb55fb6689 100644
--- a/nx/lib/nx/binary_backend/matrix.ex
+++ b/nx/lib/nx/binary_backend/matrix.ex
@@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do
 
   defp do_ts([], [], _idx, acc), do: acc
 
-  defp qr_decomposition(matrix, n, _eps) when n in 0..1 do
-    {[[1.0]], matrix}
-  end
-
-  defp qr_decomposition(matrix, n, eps) when n >= 2 do
-    # QR decomposition is performed by using Householder transform
-    # this function originally supported generic QR, but
-    # it is now only used by eigh. Because of this,
-    # we simplified the function signature to only
-    # support square matrices.
-
-    {q_matrix, r_matrix} =
-      for i <- 0..(n - 2)//1, reduce: {nil, matrix} do
-        {q, r} ->
-          h =
-            r
-            |> slice_matrix([i, i], [n - i, 1])
-            |> householder_reflector(n, eps)
-
-          # If we haven't allocated Q yet, let Q = H1
-          # TODO: Resolve inconsistent with the Householder reflector.
-          # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
-          q =
-            if is_nil(q) do
-              h
-            else
-              dot_matrix_real(q, h)
-            end
-
-          r = dot_matrix_real(h, r)
-          {q, r}
-      end
-
-    {approximate_zeros(q_matrix, eps), approximate_zeros(r_matrix, eps)}
-  end
-
-  defp raise_not_hermitian do
-    raise ArgumentError,
-          "matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
-  end
-
-  def eigh(input_data, input_type, {n, n} = input_shape, output_type, opts) do
-    eps = opts[:eps]
-    max_iter = opts[:max_iter]
-
-    # Validate that the input is a Hermitian matrix using the relation A^* = A.
-    a = binary_to_matrix(input_data, input_type, input_shape)
-
-    is_hermitian =
-      a
-      |> transpose_matrix()
-      |> Enum.map(fn a_row -> Enum.map(a_row, &Complex.conjugate(&1)) end)
-      |> is_approximately_same?(a, eps)
-
-    unless is_hermitian do
-      raise_not_hermitian()
-    end
-
-    # Hessenberg decomposition
-    {h, q_h} = hessenberg_decomposition(a, n, eps)
-
-    # QR iteration for eigenvalues and eigenvectors
-    {eigenvals_diag, eigenvecs} =
-      Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} ->
-        # QR decomposition
-        {q_now, r_now} = qr_decomposition(a_old, n, eps)
-
-        # Update matrix A, Q
-        a_new = dot_matrix_real(r_now, q_now)
-        q_new = dot_matrix_real(q_old, q_now)
-
-        if is_approximately_same?(q_old, q_new, eps) do
-          {:halt, {a_new, q_new}}
-        else
-          {:cont, {a_new, q_new}}
-        end
-      end)
-
-    # Obtain the eigenvalues, which are the diagonal elements
-    indices_diag = for idx <- 0..(n - 1), do: [idx, idx]
-    eigenvals = get_matrix_elements(eigenvals_diag, indices_diag)
-
-    # In general, the eigenvalues of a Hermitian matrix are real numbers
-    eigenvals_real = eigenvals |> Enum.map(&Complex.real(&1))
-
-    # Reduce the elements smaller than eps to zero
-    {eigenvals_real |> approximate_zeros(eps) |> matrix_to_binary(output_type),
-     eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)}
-  end
-
-  defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1 do
-    {matrix, [[1.0]]}
-  end
-
-  defp hessenberg_decomposition(matrix, n, eps) do
-    # Hessenberg decomposition is performed by using Householder transform
-    {hess_matrix, q_matrix} =
-      for i <- 0..(n - 2)//1, reduce: {matrix, nil} do
-        {hess, q} ->
-          h =
-            hess
-            |> slice_matrix([i + 1, i], [n - i - 1, 1])
-            |> householder_reflector(n, eps)
-
-          # If we haven't allocated Q yet, let Q = H1
-          # TODO: Resolve inconsistent with the Householder reflector.
-          # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
-          q =
-            if is_nil(q) do
-              h
-            else
-              dot_matrix_real(q, h)
-            end
-
-          # Hessenberg matrix H updating
-          h_adj = adjoint_matrix(h)
-
-          hess =
-            h
-            |> dot_matrix_real(hess)
-            |> dot_matrix_real(h_adj)
-
-          {hess, q}
-      end
-
-    {approximate_zeros(hess_matrix, eps), approximate_zeros(q_matrix, eps)}
-  end
-
-  defp is_approximately_same?(a, b, eps) do
-    # Determine if matrices `a` and `b` are equal in the range of eps
-    a
-    |> Enum.zip(b)
-    |> Enum.all?(fn {a_row, b_row} ->
-      a_row
-      |> Enum.zip(b_row)
-      |> Enum.all?(fn
-        {a_elem, b_elem} ->
-          abs_diff = Complex.abs(a_elem - b_elem)
-
-          abs_diff == :nan or abs_diff <= eps
-      end)
-    end)
-  end
-
   def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do
     a = binary_to_matrix(input_data, input_type, input_shape)
     eps = opts[:eps]
@@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do
     end)
   end
 
-  ## Householder helpers
-
-  defp householder_reflector(a, target_k, eps)
-
-  defp householder_reflector([], target_k, _eps) do
-    flat_list =
-      for col <- 0..(target_k - 1), row <- 0..(target_k - 1), into: [] do
-        if col == row, do: 1, else: 0
-      end
-
-    Enum.chunk_every(flat_list, target_k)
-  end
-
-  defp householder_reflector(a, target_k, eps) do
-    {v, scale, is_complex} = householder_reflector_pivot(a, eps)
-
-    prefix_threshold = target_k - length(v)
-    v = List.duplicate(0, prefix_threshold) ++ v
-
-    # dot(v, v) = norm_v_squared, which can be calculated from norm_a as:
-    # norm_v_squared = norm_a_squared - a_0^2 + v_0^2
-
-    # execute I - 2 / norm_v_squared * outer(v, v)
-    {_, _, reflector_reversed} =
-      for col_factor <- v, row_factor <- v, reduce: {0, 0, []} do
-        {row, col, acc} ->
-          row_factor = if is_complex, do: Complex.conjugate(row_factor), else: row_factor
-
-          # The current element in outer(v, v) is given by col_factor * row_factor
-          # and the current I element is 1 when row == col
-          identity_element = if row == col, do: 1, else: 0
-
-          result =
-            if row >= prefix_threshold and col >= prefix_threshold do
-              identity_element -
-                scale * col_factor * row_factor
-            else
-              identity_element
-            end
-
-          acc = [result | acc]
-
-          if col + 1 == target_k do
-            {row + 1, 0, acc}
-          else
-            {row, col + 1, acc}
-          end
-      end
-
-    # This is equivalent to reflector_reversed |> Enum.reverse() |> Enum.chunk_every(target_k)
-    {reflector, _, _} =
-      for x <- reflector_reversed, reduce: {[], [], 0} do
-        {result_acc, row_acc, col} ->
-          row_acc = [x | row_acc]
-
-          if col + 1 == target_k do
-            {[row_acc | result_acc], [], 0}
-          else
-            {result_acc, row_acc, col + 1}
-          end
-      end
-
-    reflector
-  end
-
-  defp householder_reflector_pivot([a_0 | tail] = a, eps) when is_number(a_0) do
-    # This is a trick so we can both calculate the norm of a_reverse and extract the
-    # head a the same time we reverse the array
-    # receives a_reverse as a list of numbers and returns the reflector as a
-    # k x k matrix
-
-    norm_a_squared = Enum.reduce(a, 0, fn x, acc -> x * Complex.conjugate(x) + acc end)
-    norm_a_sq_1on = norm_a_squared - a_0 * a_0
-
-    if norm_a_sq_1on < eps do
-      {[1 | tail], 0, false}
-    else
-      v_0 =
-        if a_0 <= 0 do
-          a_0 - Complex.sqrt(norm_a_squared)
-        else
-          -norm_a_sq_1on / (a_0 + Complex.sqrt(norm_a_squared))
-        end
-
-      v_0_sq = v_0 * v_0
-      scale = 2 * v_0_sq / (norm_a_sq_1on + v_0_sq)
-      v = [1 | Enum.map(tail, &(&1 / v_0))]
-      {v, scale, false}
-    end
-  end
-
-  defp householder_reflector_pivot([a_0 | tail], _eps) do
-    # complex case
-    norm_a_sq_1on = Enum.reduce(tail, 0, &(Complex.abs_squared(&1) + &2))
-    norm_a_sq = norm_a_sq_1on + Complex.abs_squared(a_0)
-    norm_a = Complex.sqrt(norm_a_sq)
-
-    phase_a_0 = Complex.phase(a_0)
-    alfa = Complex.exp(Complex.new(0, phase_a_0)) * norm_a
-
-    # u = x - alfa * e1
-    u_0 = a_0 + alfa
-    u = [u_0 | tail]
-    norm_u_sq = norm_a_sq_1on + Complex.abs_squared(u_0)
-    norm_u = Complex.sqrt(norm_u_sq)
-
-    v = Enum.map(u, &(&1 / norm_u))
-    {v, 2, true}
-  end
-
   ## Matrix (2-D array) manipulation
 
   defp dot_matrix([], _), do: 0
@@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do
     end)
   end
 
-  defp dot_matrix_real(m1, m2) do
-    Enum.map(m1, fn row ->
-      m2
-      |> transpose_matrix()
-      |> Enum.map(fn col ->
-        Enum.zip_reduce(row, col, 0, fn x, y, acc -> acc + x * y end)
-      end)
-    end)
-  end
-
-  defp adjoint_matrix([x | _] = m) when not is_list(x) do
-    Enum.map(m, &[Complex.conjugate(&1)])
-  end
-
-  defp adjoint_matrix(m) do
-    Enum.zip_with(m, fn cols -> Enum.map(cols, &Complex.conjugate/1) end)
-  end
-
   defp transpose_matrix([x | _] = m) when not is_list(x) do
     Enum.map(m, &[&1])
   end
diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex
index 450a48cae42..edced802468 100644
--- a/nx/lib/nx/lin_alg.ex
+++ b/nx/lib/nx/lin_alg.ex
@@ -1179,8 +1179,8 @@ defmodule Nx.LinAlg do
       #Nx.Tensor<
         f32[2][2]
         [
-          [3.9924824237823486, -1.0052783489227295],
-          [-3.0051186084747314, 1.0071179866790771]
+          [4.000002861022949, -1.0000008344650269],
+          [-3.000002384185791, 1.0000005960464478]
         ]
       >
 
@@ -1275,14 +1275,14 @@ defmodule Nx.LinAlg do
       iex> Nx.round(eigenvals)
       #Nx.Tensor<
         f32[2]
-        [1.0, 2.0]
+        [2.0, 1.0]
       >
       iex> eigenvecs
       #Nx.Tensor<
         f32[2][2]
         [
-          [1.0, 0.0],
-          [0.0, 1.0]
+          [0.0, 1.0],
+          [1.0, 0.0]
         ]
       >
 
@@ -1296,9 +1296,9 @@ defmodule Nx.LinAlg do
       #Nx.Tensor<
         f32[3][3]
         [
-          [0.4075949788093567, 0.9131628274917603, 0.0],
-          [0.40837883949279785, -0.18228201568126678, 0.8944271802902222],
-          [0.8167576789855957, -0.36456403136253357, -0.4472135901451111]
+          [0.40824827551841736, -0.18257419764995575, 0.8944271802902222],
+          [0.40824833512306213, 0.9128708839416504, 0.0],
+          [0.8164965510368347, -0.3651483952999115, -0.4472135901451111]
         ]
       >
 
@@ -1308,7 +1308,7 @@ defmodule Nx.LinAlg do
         f32[2][2]
         [
           [9.0, -1.0],
-          [1.0, 4.0]
+          [4.0, 1.0]
         ]
       >
       iex> eigenvecs
@@ -1316,12 +1316,12 @@ defmodule Nx.LinAlg do
         f32[2][2][2]
         [
           [
-            [0.5612090229988098, -0.8276740908622742],
-            [0.8276740908622742, 0.5612090229988098]
+            [0.5606288313865662, 0.8280671834945679],
+            [0.8280671834945679, -0.5606288313865662]
           ],
           [
-            [1.0, 0.0],
-            [0.0, 1.0]
+            [0.0, 1.0],
+            [1.0, 0.0]
           ]
         ]
       >
@@ -1334,7 +1334,7 @@ defmodule Nx.LinAlg do
         f32[2]
         [
           [9.0, -1.0],
-          [1.0, 4.0]
+          [4.0, 1.0]
         ]
       >
       iex> eigenvecs
@@ -1343,12 +1343,12 @@ defmodule Nx.LinAlg do
         f32[2][2]
         [
           [
-            [0.5612090229988098, -0.8276740908622742],
-            [0.8276740908622742, 0.5612090229988098]
+            [0.5606288313865662, 0.8280671834945679],
+            [0.8280671834945679, -0.5606288313865662]
           ],
           [
-            [1.0, 0.0],
-            [0.0, 1.0]
+            [0.0, 1.0],
+            [1.0, 0.0]
           ]
         ]
       >
@@ -1376,7 +1376,7 @@ defmodule Nx.LinAlg do
        %{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}}
 
     :eigh
-    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.Eigh.eigh/2)
+    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.BlockEigh.eigh/2)
     |> Nx.vectorize(vectorized_axes)
   end
 
@@ -2161,19 +2161,19 @@ defmodule Nx.LinAlg do
       iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2]))
       #Nx.Tensor<
         f32[2]
-        [0.9977624416351318, 0.0011188983917236328]
+        [1.0000028610229492, -2.384185791015625e-6]
       >
 
       iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1]))
       #Nx.Tensor<
         f32[2]
-        [0.9966151118278503, -0.947966456413269]
+        [0.9999998211860657, -0.9500012993812561]
       >
 
       iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2, 3], [4, 5, 6]]), Nx.tensor([1, 2]))
       #Nx.Tensor<
         f32[3]
-        [-0.05534052848815918, 0.1111316829919815, 0.27760395407676697]
+        [-0.05555540323257446, 0.1111111044883728, 0.27777770161628723]
       >
 
   ## Error cases
diff --git a/nx/lib/nx/lin_alg/block_eigh.ex b/nx/lib/nx/lin_alg/block_eigh.ex
new file mode 100644
index 00000000000..1c6f7986d90
--- /dev/null
+++ b/nx/lib/nx/lin_alg/block_eigh.ex
@@ -0,0 +1,304 @@
+defmodule Nx.LinAlg.BlockEigh do
+  @moduledoc """
+  Parallel Jacobi symmetric eigendecomposition.
+
+  Reference implementation taken from XLA's eigh_expander
+  which is built on the approach in:
+  Brent, R. P., & Luk, F. T. (1985). The solution of singular-value
+  and symmetric eigenvalue problems on multiprocessor arrays.
+  SIAM Journal on Computing, 6(1), 69-84. https://doi.org/10.1137/0906007
+  """
+  require Nx
+
+  import Nx.Defn
+
+  defn eigh(matrix, opts \\ []) do
+    opts = keyword!(opts, eps: 1.0e-6, max_iter: 100)
+
+    matrix
+    |> Nx.revectorize([collapsed_axes: :auto],
+      target_shape: {Nx.axis_size(matrix, -2), Nx.axis_size(matrix, -1)}
+    )
+    |> decompose(opts)
+    |> revectorize_result(matrix)
+  end
+
+  defnp decompose(matrix, opts) do
+    {n, _} = Nx.shape(matrix)
+
+    if n > 1 do
+      m_decompose(matrix, opts)
+    else
+      {Nx.take_diagonal(Nx.real(matrix)), Nx.tensor([1], type: matrix.type)}
+    end
+  end
+
+  defnp m_decompose(matrix, opts) do
+    eps = opts[:eps]
+    max_iter = opts[:max_iter]
+
+    type = Nx.Type.to_floating(Nx.type(matrix))
+    matrix = Nx.as_type(matrix, type)
+    {n, _} = Nx.shape(matrix)
+    i_n = n - 1
+    mid = calculate_mid(i_n)
+    i_mid = mid - 1
+
+    tl = matrix[[0..i_mid, 0..i_mid]]
+    tr = matrix[[0..i_mid, mid..i_n]]
+    bl = matrix[[mid..i_n, 0..i_mid]]
+    br = matrix[[mid..i_n, mid..i_n]]
+
+    # Pad if not even
+    {tr, bl, br} =
+      if Nx.remainder(n, 2) == 1 do
+        tr = Nx.pad(tr, 0, [{0, 0, 0}, {0, 1, 0}])
+        bl = Nx.pad(bl, 0, [{0, 1, 0}, {0, 0, 0}])
+        br = Nx.pad(br, 0, [{0, 1, 0}, {0, 1, 0}])
+        {tr, bl, br}
+      else
+        {tr, bl, br}
+      end
+
+    # Initialize tensors to hold eigenvectors
+    v_tl = v_br = Nx.eye(mid, type: type)
+    v_tr = v_bl = Nx.broadcast(Nx.tensor(0, type: type), {mid, mid})
+
+    {frob_norm, off_norm} = norms(tl, tr, bl, br)
+
+    # Nested loop
+    # The outer loop performs the "sweep" operation until the norms converge
+    # or max iterations are hit. The Brent/Luk paper states that Log2(n) is
+    # a good estimate for convergence, but XLA chose a static number which wouldn't
+    # be reached until a matrix roughly greater than 20kx20k.
+    #
+    # The inner loop performs "sweep" rounds of n - 1, which is enough permutations to allow
+    # all sub matrices to share the needed values.
+    {{tl, br, v_tl, v_tr, v_bl, v_br}, _} =
+      while {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i = 0}},
+            off_norm > eps ** 2 * frob_norm and i < max_iter do
+        {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br} =
+          perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n)
+
+        {frob_norm, off_norm} = norms(tl, tr, bl, br)
+
+        {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i + 1}}
+      end
+
+    # Recombine
+    w = Nx.concatenate([Nx.take_diagonal(tl), Nx.take_diagonal(br)])
+
+    v =
+      Nx.concatenate([
+        Nx.concatenate([v_tl, v_tr], axis: 1),
+        Nx.concatenate([v_bl, v_br], axis: 1)
+      ])
+      |> Nx.LinAlg.adjoint()
+
+    # trim padding
+    {w, v} =
+      if Nx.remainder(n, 2) == 1 do
+        {w[0..i_n], v[[0..i_n, 0..i_n]]}
+      else
+        {w, v}
+      end
+
+    sort_ind = Nx.argsort(Nx.abs(w), direction: :desc)
+
+    w = Nx.take(w, sort_ind) |> approximate_zeros(eps)
+    v = Nx.take(v, sort_ind, axis: 1) |> approximate_zeros(eps)
+
+    {w, v}
+  end
+
+  deftransformp calculate_mid(i_n) do
+    Range.size(0..i_n//2)
+  end
+
+  defnp calc_rot(tl, tr, br) do
+    complex? = tl |> Nx.type() |> Nx.Type.complex?()
+    br = Nx.take_diagonal(br) |> Nx.real()
+    tr = Nx.take_diagonal(tr)
+    tl = Nx.take_diagonal(tl) |> Nx.real()
+
+    {tr, w} =
+      if complex? do
+        abs_tr = Nx.abs(tr)
+        {abs_tr, Nx.select(abs_tr == 0, 1, Nx.conjugate(tr) / abs_tr)}
+      else
+        {tr, 1}
+      end
+
+    z_tr = Nx.equal(tr, 0)
+    s_tr = Nx.select(z_tr, 1, tr)
+    tau = Nx.select(z_tr, 0, (br - tl) / (2 * s_tr))
+
+    t = Nx.sqrt(1 + tau ** 2)
+
+    t = 1 / (tau + Nx.select(tau >= 0, t, -t))
+
+    pred = Nx.abs(tr) <= 1.0e-5 * Nx.min(Nx.abs(br), Nx.abs(tl))
+    t = Nx.select(pred, Nx.tensor(0, type: tl.type), t)
+
+    c = 1.0 / Nx.sqrt(1.0 + t ** 2)
+    s = if complex?, do: Nx.complex(t * c, 0) * w, else: t * c
+
+    rt1 = tl - t * tr
+    rt2 = br + t * tr
+    {rt1, rt2, c, s}
+  end
+
+  defnp sq_norm(tl, tr, bl, br) do
+    Nx.sum(Nx.abs(tl) ** 2 + Nx.abs(tr) ** 2 + Nx.abs(bl) ** 2 + Nx.abs(br) ** 2)
+  end
+
+  defnp off_norm(tl, tr, bl, br) do
+    {n, _} = Nx.shape(tl)
+    diag = Nx.broadcast(0, {n})
+    o_tl = Nx.put_diagonal(tl, diag)
+    o_br = Nx.put_diagonal(br, diag)
+
+    sq_norm(o_tl, tr, bl, o_br)
+  end
+
+  # Calculates the Frobenius norm and the norm of the off-diagonals from
+  # the submatrices. Used to calculate convergence.
+  defnp norms(tl, tr, bl, br) do
+    frob = sq_norm(tl, tr, bl, br)
+    off = off_norm(tl, tr, bl, br)
+
+    {frob, off}
+  end
+
+  deftransformp revectorize_result({eigenvals, eigenvecs}, a) do
+    shape = Nx.shape(a)
+
+    {
+      Nx.revectorize(eigenvals, a.vectorized_axes,
+        target_shape: Tuple.delete_at(shape, tuple_size(shape) - 1)
+      ),
+      Nx.revectorize(eigenvecs, a.vectorized_axes, target_shape: shape)
+    }
+  end
+
+  defnp perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n) do
+    while {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br}, _n <- 0..i_n do
+      {rt1, rt2, c, s} = calc_rot(tl, tr, br)
+      # build row and column vectors for parallelized rotations
+      c_v = Nx.new_axis(c, 1)
+      s_v = Nx.new_axis(s, 1)
+      c_h = Nx.new_axis(c, 0)
+      s_h = Nx.new_axis(s, 0)
+
+      s_v_conj =
+        if Nx.type(s) |> Nx.Type.complex?() do
+          Nx.conjugate(s_v)
+        else
+          s_v
+        end
+
+      s_h_conj = Nx.transpose(s_v_conj)
+
+      # Each rotation group below is performed based on the same
+      # tl, bl, tr, br values, so we must do single-expr
+      # assignments (i.e. {tl, tr, bl, br} = ...)
+
+      # Rotate rows
+      {tl, tr, bl, br} = {
+        tl * c_v - bl * s_v_conj,
+        tr * c_v - br * s_v_conj,
+        tl * s_v + bl * c_v,
+        tr * s_v + br * c_v
+      }
+
+      # Rotate cols
+      {tl, tr, bl, br} = {
+        tl * c_h - tr * s_h,
+        tl * s_h_conj + tr * c_h,
+        bl * c_h - br * s_h,
+        bl * s_h_conj + br * c_h
+      }
+
+      # Store results and permute values across sub matrices
+      zero_diag = Nx.broadcast(0, {mid})
+      tl = Nx.put_diagonal(tl, rt1)
+      tr = Nx.put_diagonal(tr, zero_diag)
+      bl = Nx.put_diagonal(bl, zero_diag)
+      br = Nx.put_diagonal(br, rt2)
+
+      {tl, tr} = permute_cols_in_row(tl, tr)
+      {bl, br} = permute_cols_in_row(bl, br)
+      {tl, bl} = permute_rows_in_col(tl, bl)
+      {tr, br} = permute_rows_in_col(tr, br)
+
+      # Rotate to calc vectors
+      {v_tl, v_tr, v_bl, v_br} = {
+        v_tl * c_v - v_bl * s_v_conj,
+        v_tr * c_v - v_br * s_v_conj,
+        v_tl * s_v + v_bl * c_v,
+        v_tr * s_v + v_br * c_v
+      }
+
+      # permute for vectors
+      {v_tl, v_bl} = permute_rows_in_col(v_tl, v_bl)
+      {v_tr, v_br} = permute_rows_in_col(v_tr, v_br)
+
+      {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br}
+    end
+  end
+
+  defnp approximate_zeros(matrix, eps), do: Nx.select(Nx.abs(matrix) <= eps, 0, matrix)
+
+  # https://github.com/openxla/xla/blob/main/xla/hlo/transforms/expanders/eigh_expander.cc#L200-L239
+  defnp permute_rows_in_col(top, bottom) do
+    {k, _} = Nx.shape(top)
+
+    {top_out, bottom_out} =
+      cond do
+        k == 2 ->
+          {Nx.concatenate([top[0..0], bottom[0..0]], axis: 0),
+           Nx.concatenate(
+             [
+               bottom[1..-1//1],
+               top[(k - 1)..(k - 1)]
+             ],
+             axis: 0
+           )}
+
+        k == 1 ->
+          {top, bottom}
+
+        true ->
+          {Nx.concatenate([top[0..0], bottom[0..0], top[1..(k - 2)]], axis: 0),
+           Nx.concatenate(
+             [
+               bottom[1..-1//1],
+               top[(k - 1)..(k - 1)]
+             ],
+             axis: 0
+           )}
+      end
+
+    {top_out, bottom_out}
+  end
+
+  defnp permute_cols_in_row(left, right) do
+    {k, _} = Nx.shape(left)
+
+    {left_out, right_out} =
+      cond do
+        k == 2 ->
+          {Nx.concatenate([left[[.., 0..0]], right[[.., 0..0]]], axis: 1),
+           Nx.concatenate([right[[.., 1..(k - 1)]], left[[.., (k - 1)..(k - 1)]]], axis: 1)}
+
+        k == 1 ->
+          {left, right}
+
+        true ->
+          {Nx.concatenate([left[[.., 0..0]], right[[.., 0..0]], left[[.., 1..(k - 2)]]], axis: 1),
+           Nx.concatenate([right[[.., 1..(k - 1)]], left[[.., (k - 1)..(k - 1)]]], axis: 1)}
+      end
+
+    {left_out, right_out}
+  end
+end
diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs
index 6ab7e8ef785..81bc025fc0c 100644
--- a/nx/test/nx/defn/grad_test.exs
+++ b/nx/test/nx/defn/grad_test.exs
@@ -1981,8 +1981,8 @@ defmodule Nx.Defn.GradTest do
       assert_all_close(
         svd_grad(Nx.tensor([[3, 0], [1, 2]])),
         Nx.tensor([
-          [0.07228553295135498, 0.7500489950180054],
-          [1.113668441772461, 1.8945982456207275]
+          [1.368404507637024, -0.5419228672981262],
+          [-0.2197188436985016, 0.6067624092102051]
         ])
       )
     end
@@ -1991,8 +1991,8 @@ defmodule Nx.Defn.GradTest do
       assert_all_close(
         svd_composed_grad(Nx.tensor([[3, 0], [1, 2]])),
         Nx.tensor([
-          [22.44730567932129, 4.334394931793213],
-          [10.295409202575684, 9.27196216583252]
+          [22.86724090576172, 3.655829906463623],
+          [10.035255432128906, 8.769235610961914]
         ])
       )
     end
@@ -2001,9 +2001,9 @@ defmodule Nx.Defn.GradTest do
       assert_all_close(
         svd_composed_grad(Nx.tensor([[3, 0], [1, 2], [1, 1]])),
         Nx.tensor([
-          [25.990453720092773, 6.061026096343994],
-          [12.646490097045898, 10.775838851928711],
-          [10.656349182128906, 6.384178638458252]
+          [25.911056518554688, 6.1099162101745605],
+          [12.69705581665039, 10.84456729888916],
+          [10.668402671813965, 6.426826477050781]
         ])
       )
     end
diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs
index 36a48159e18..d8c8fe2bb40 100644
--- a/nx/test/nx/lin_alg_test.exs
+++ b/nx/test/nx/lin_alg_test.exs
@@ -574,11 +574,11 @@ defmodule Nx.LinAlgTest do
       assert_all_close(
         eigenvecs,
         Nx.tensor([
-          [0.112, -0.005, -0.831, -0.436, -0.328],
-          [0.395, 0.163, 0.530, -0.537, -0.497],
-          [0.427, 0.326, -0.133, 0.700, -0.452],
-          [0.603, -0.783, -0.007, 0.079, 0.130],
-          [0.534, 0.504, -0.104, -0.160, 0.651]
+          [0.112, 0.004, 0.828, -0.440, -0.328],
+          [0.395, -0.163, -0.533, -0.534, -0.497],
+          [0.427, -0.326, 0.137, 0.700, -0.452],
+          [0.603, 0.783, 0.008, 0.079, 0.130],
+          [0.534, -0.504, 0.103, -0.160, 0.651]
         ]),
         atol: 1.0e-3,
         rtol: 1.0e-3
@@ -600,28 +600,20 @@ defmodule Nx.LinAlgTest do
 
         # Eigenvalues
         assert eigenvals ==
-                 Nx.tensor([Complex.new(-5, 0), Complex.new(3, 0), Complex.new(1, 0)])
+                 Nx.tensor([
+                   Complex.new(-5, 0),
+                   Complex.new(3, 0),
+                   Complex.new(0.9999998807907104, 0)
+                 ])
 
         # Eigenvectors
         assert_all_close(
           eigenvecs,
-          Nx.tensor([
-            [
-              Complex.new(-0.408, 0.0),
-              Complex.new(-0.0, 0.707),
-              Complex.new(0.577, 0.0)
-            ],
-            [
-              Complex.new(-0.0, -0.816),
-              Complex.new(0.0, 0.0),
-              Complex.new(0.0, -0.577)
-            ],
-            [
-              Complex.new(0.408, 0.0),
-              Complex.new(-0.0, 0.707),
-              Complex.new(-0.577, 0.0)
-            ]
-          ]),
+          ~MAT[
+            0.0000-0.4082i 0.7071-0.0i 00.5773-0.0000i
+            0.8164-0.0000i 0.0000+0.0i 00.0000-0.5773i
+            0.0000+0.4082i 0.7071-0.0i -0.5773-0.0000i
+          ],
           atol: 1.0e-3,
           rtol: 1.0e-3
         )
@@ -638,42 +630,56 @@ defmodule Nx.LinAlgTest do
       for type <- [f: 32, c: 64], reduce: key do
         key ->
           # Unitary matrix from a random matrix
-          {base, key} = Nx.Random.uniform(key, shape: {3, 3, 3}, type: type)
+          {base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type)
           {q, _} = Nx.LinAlg.qr(base)
 
           # Different eigenvalues from random values
           evals_test =
-            [{100, 30}, {4, 6}, {0.7, 0.9}]
-            |> Enum.map(fn {low, up} ->
-              if :rand.uniform() - 0.5 > 0 do
-                {low, up}
-              else
-                {-up, -low}
-              end
-            end)
-            |> Enum.map(fn {low, up} ->
-              rand = :rand.uniform() * (up - low) + low
-              Nx.tensor([rand], type: :f64)
+            [100, 10, 1]
+            |> Enum.map(fn magnitude ->
+              sign =
+                if :rand.uniform() - 0.5 > 0 do
+                  1
+                else
+                  -1
+                end
+
+              rand = :rand.uniform() * magnitude * 0.1 + magnitude
+              rand * sign
             end)
-            |> Nx.concatenate()
+            |> Nx.tensor(type: type)
+
+          evals_test_diag =
+            evals_test
+            |> Nx.make_diagonal()
+            |> Nx.reshape({1, 3, 3})
+            |> Nx.tile([2, 1, 1])
 
           # Hermitian matrix with different eigenvalues
           # using A = A^* = Q^*.Λ.Q.
           a =
             q
             |> Nx.LinAlg.adjoint()
-            |> Nx.multiply(evals_test)
+            |> Nx.dot([2], [0], evals_test_diag, [1], [0])
             |> Nx.dot([2], [0], q, [1], [0])
 
           # Eigenvalues and eigenvectors
-          assert {evals, evecs} = Nx.LinAlg.eigh(a, max_iter: 10_000)
-          assert_all_close(evals_test, evals, atol: 1.0e-1)
+          assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8)
+
+          assert_all_close(evals_test, evals[0], atol: 1.0e-8)
+          assert_all_close(evals_test, evals[1], atol: 1.0e-8)
+
+          evals =
+            evals
+            |> Nx.vectorize(:x)
+            |> Nx.make_diagonal()
+            |> Nx.devectorize(keep_names: false)
 
           # Eigenvalue equation
-          evecs_evals = Nx.multiply(evecs, evals)
-          a_evecs = Nx.dot(a, [2], [0], evecs, [1], [0])
+          evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0])
+          a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0])
 
-          assert_all_close(evecs_evals, a_evecs, atol: 1.0e-1)
+          assert_all_close(a, a_evecs, atol: 1.0e-8)
           key
       end
     end
@@ -734,10 +740,10 @@ defmodule Nx.LinAlgTest do
       assert_all_close(
         u,
         Nx.tensor([
-          [0.141, 0.825, -0.001, 0.019],
-          [0.344, 0.426, 0.00200, 0.382],
-          [0.547, 0.028, 0.0, -0.822],
-          [0.75, -0.370, -0.001, 0.421]
+          [0.141, -0.825, -0.001, 0.019],
+          [0.344, -0.426, 0.00200, 0.382],
+          [0.547, -0.028, 0.0, -0.822],
+          [0.75, 0.370, -0.001, 0.421]
         ]),
         atol: 1.0e-3,
         rtol: 1.0e-3
@@ -747,8 +753,8 @@ defmodule Nx.LinAlgTest do
 
       assert_all_close(
         Nx.tensor([
-          [0.505, 0.575, 0.644],
-          [-0.761, -0.057, 0.647],
+          [0.504, 0.575, 0.644],
+          [0.761, 0.057, -0.647],
           [-0.408, 0.816, -0.408]
         ]),
         v,
@@ -801,9 +807,9 @@ defmodule Nx.LinAlgTest do
       assert_all_close(
         u,
         Nx.tensor([
-          [0.336, -0.407, -0.849],
-          [0.037, -0.895, 0.444],
-          [0.941, 0.181, 0.286]
+          [0.335, 0.408, 0.849],
+          [0.036, 0.895, -0.445],
+          [0.941, -0.18, -0.286]
         ]),
         atol: 1.0e-3,
         rtol: 1.0e-3
@@ -815,9 +821,9 @@ defmodule Nx.LinAlgTest do
 
       assert_all_close(
         Nx.tensor([
-          [0.035, 0.0869, 0.996],
-          [-0.091, -0.992, 0.09],
-          [-0.995, 0.094, 0.027]
+          [0.035, 0.0856, 0.996],
+          [0.092, 0.992, -0.089],
+          [0.995, -0.094, -0.027]
         ]),
         v,
         atol: 1.0e-3,
diff --git a/torchx/mix.exs b/torchx/mix.exs
index 070718c3b03..8a0f3cd3686 100644
--- a/torchx/mix.exs
+++ b/torchx/mix.exs
@@ -41,8 +41,8 @@ defmodule Torchx.MixProject do
 
   defp deps do
     [
-      {:nx, "~> 0.9.0"},
-      # {:nx, path: "../nx"},
+      # {:nx, "~> 0.9.0"},
+      {:nx, path: "../nx"},
       {:ex_doc, "~> 0.29", only: :docs}
     ]
   end
diff --git a/torchx/test/torchx/nx_linalg_doctest_test.exs b/torchx/test/torchx/nx_linalg_doctest_test.exs
index 9f3c6eca521..30e75dafc55 100644
--- a/torchx/test/torchx/nx_linalg_doctest_test.exs
+++ b/torchx/test/torchx/nx_linalg_doctest_test.exs
@@ -18,7 +18,7 @@ defmodule Torchx.NxLinAlgDoctestTest do
     invert: 1,
     determinant: 1,
     pinv: 2,
-    least_squares: 2
+    least_squares: 3
   ]
 
   # Results do not match but properties are respected