From 8dc7b29181ed608358845fef51da71c7f992a0ed Mon Sep 17 00:00:00 2001 From: Christian Green <christian.j.green@gmail.com> Date: Mon, 13 Jan 2025 10:25:52 -0800 Subject: [PATCH] Blocked Jacobi method for eigen decomposition (#1510) Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com> --- nx/lib/nx/binary_backend.ex | 19 -- nx/lib/nx/binary_backend/matrix.ex | 272 ---------------- nx/lib/nx/lin_alg.ex | 44 +-- nx/lib/nx/lin_alg/block_eigh.ex | 304 ++++++++++++++++++ nx/test/nx/defn/grad_test.exs | 14 +- nx/test/nx/lin_alg_test.exs | 114 +++---- torchx/mix.exs | 4 +- torchx/test/torchx/nx_linalg_doctest_test.exs | 2 +- 8 files changed, 396 insertions(+), 377 deletions(-) create mode 100644 nx/lib/nx/lin_alg/block_eigh.ex diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index e3ae4121d6f..e2795d6071c 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1240,25 +1240,6 @@ defmodule Nx.BinaryBackend do output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end) end - @impl true - def eigh( - {%{type: output_type} = eigenvals_holder, eigenvecs_holder}, - %{type: input_type, shape: input_shape} = tensor, - opts - ) do - bin = to_binary(tensor) - rank = tuple_size(input_shape) - n = elem(input_shape, rank - 1) - - {eigenvals, eigenvecs} = - bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, fn matrix, {vals_acc, vecs_acc} -> - {vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts) - {vals_acc <> vals, vecs_acc <> vecs} - end) - - {from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)} - end - @impl true def lu( {%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder}, diff --git a/nx/lib/nx/binary_backend/matrix.ex b/nx/lib/nx/binary_backend/matrix.ex index 85601295b97..afb55fb6689 100644 --- a/nx/lib/nx/binary_backend/matrix.ex +++ b/nx/lib/nx/binary_backend/matrix.ex @@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do defp do_ts([], [], _idx, acc), do: acc - defp qr_decomposition(matrix, n, _eps) when n in 0..1 do - {[[1.0]], matrix} - end - - defp qr_decomposition(matrix, n, eps) when n >= 2 do - # QR decomposition is performed by using Householder transform - # this function originally supported generic QR, but - # it is now only used by eigh. Because of this, - # we simplified the function signature to only - # support square matrices. - - {q_matrix, r_matrix} = - for i <- 0..(n - 2)//1, reduce: {nil, matrix} do - {q, r} -> - h = - r - |> slice_matrix([i, i], [n - i, 1]) - |> householder_reflector(n, eps) - - # If we haven't allocated Q yet, let Q = H1 - # TODO: Resolve inconsistent with the Householder reflector. - # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063 - q = - if is_nil(q) do - h - else - dot_matrix_real(q, h) - end - - r = dot_matrix_real(h, r) - {q, r} - end - - {approximate_zeros(q_matrix, eps), approximate_zeros(r_matrix, eps)} - end - - defp raise_not_hermitian do - raise ArgumentError, - "matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)" - end - - def eigh(input_data, input_type, {n, n} = input_shape, output_type, opts) do - eps = opts[:eps] - max_iter = opts[:max_iter] - - # Validate that the input is a Hermitian matrix using the relation A^* = A. 
- a = binary_to_matrix(input_data, input_type, input_shape) - - is_hermitian = - a - |> transpose_matrix() - |> Enum.map(fn a_row -> Enum.map(a_row, &Complex.conjugate(&1)) end) - |> is_approximately_same?(a, eps) - - unless is_hermitian do - raise_not_hermitian() - end - - # Hessenberg decomposition - {h, q_h} = hessenberg_decomposition(a, n, eps) - - # QR iteration for eigenvalues and eigenvectors - {eigenvals_diag, eigenvecs} = - Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} -> - # QR decomposition - {q_now, r_now} = qr_decomposition(a_old, n, eps) - - # Update matrix A, Q - a_new = dot_matrix_real(r_now, q_now) - q_new = dot_matrix_real(q_old, q_now) - - if is_approximately_same?(q_old, q_new, eps) do - {:halt, {a_new, q_new}} - else - {:cont, {a_new, q_new}} - end - end) - - # Obtain the eigenvalues, which are the diagonal elements - indices_diag = for idx <- 0..(n - 1), do: [idx, idx] - eigenvals = get_matrix_elements(eigenvals_diag, indices_diag) - - # In general, the eigenvalues of a Hermitian matrix are real numbers - eigenvals_real = eigenvals |> Enum.map(&Complex.real(&1)) - - # Reduce the elements smaller than eps to zero - {eigenvals_real |> approximate_zeros(eps) |> matrix_to_binary(output_type), - eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)} - end - - defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1 do - {matrix, [[1.0]]} - end - - defp hessenberg_decomposition(matrix, n, eps) do - # Hessenberg decomposition is performed by using Householder transform - {hess_matrix, q_matrix} = - for i <- 0..(n - 2)//1, reduce: {matrix, nil} do - {hess, q} -> - h = - hess - |> slice_matrix([i + 1, i], [n - i - 1, 1]) - |> householder_reflector(n, eps) - - # If we haven't allocated Q yet, let Q = H1 - # TODO: Resolve inconsistent with the Householder reflector. - # cf. 
https://github.com/elixir-nx/nx/pull/933#discussion_r982772063 - q = - if is_nil(q) do - h - else - dot_matrix_real(q, h) - end - - # Hessenberg matrix H updating - h_adj = adjoint_matrix(h) - - hess = - h - |> dot_matrix_real(hess) - |> dot_matrix_real(h_adj) - - {hess, q} - end - - {approximate_zeros(hess_matrix, eps), approximate_zeros(q_matrix, eps)} - end - - defp is_approximately_same?(a, b, eps) do - # Determine if matrices `a` and `b` are equal in the range of eps - a - |> Enum.zip(b) - |> Enum.all?(fn {a_row, b_row} -> - a_row - |> Enum.zip(b_row) - |> Enum.all?(fn - {a_elem, b_elem} -> - abs_diff = Complex.abs(a_elem - b_elem) - - abs_diff == :nan or abs_diff <= eps - end) - end) - end - def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do a = binary_to_matrix(input_data, input_type, input_shape) eps = opts[:eps] @@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do end) end - ## Householder helpers - - defp householder_reflector(a, target_k, eps) - - defp householder_reflector([], target_k, _eps) do - flat_list = - for col <- 0..(target_k - 1), row <- 0..(target_k - 1), into: [] do - if col == row, do: 1, else: 0 - end - - Enum.chunk_every(flat_list, target_k) - end - - defp householder_reflector(a, target_k, eps) do - {v, scale, is_complex} = householder_reflector_pivot(a, eps) - - prefix_threshold = target_k - length(v) - v = List.duplicate(0, prefix_threshold) ++ v - - # dot(v, v) = norm_v_squared, which can be calculated from norm_a as: - # norm_v_squared = norm_a_squared - a_0^2 + v_0^2 - - # execute I - 2 / norm_v_squared * outer(v, v) - {_, _, reflector_reversed} = - for col_factor <- v, row_factor <- v, reduce: {0, 0, []} do - {row, col, acc} -> - row_factor = if is_complex, do: Complex.conjugate(row_factor), else: row_factor - - # The current element in outer(v, v) is given by col_factor * row_factor - # and the current I element is 1 when row == col - identity_element = if row == col, do: 1, else: 0 - - result = - if row >= prefix_threshold and col >= prefix_threshold do - identity_element - - scale * col_factor * row_factor - else - identity_element - end - - acc = [result | acc] - - if col + 1 == target_k do - {row + 1, 0, acc} - else - {row, col + 1, acc} - end - end - - # This is equivalent to reflector_reversed |> Enum.reverse() |> Enum.chunk_every(target_k) - {reflector, _, _} = - for x <- reflector_reversed, reduce: {[], [], 0} do - {result_acc, row_acc, col} -> - row_acc = [x | row_acc] - - if col + 1 == target_k do - {[row_acc | result_acc], [], 0} - else - {result_acc, row_acc, col + 1} - end - end - - reflector - end - - defp householder_reflector_pivot([a_0 | tail] = a, eps) when is_number(a_0) do - # This is a trick so we can both calculate the norm of a_reverse and extract the - # head a the same time we reverse the array - # receives a_reverse as a list of numbers and returns the reflector as a - # k x k matrix - - norm_a_squared = Enum.reduce(a, 0, fn x, acc -> x * Complex.conjugate(x) + acc end) - norm_a_sq_1on = norm_a_squared - a_0 * a_0 - - if norm_a_sq_1on < eps do - {[1 | tail], 0, false} - else - v_0 = - if a_0 <= 0 do - a_0 - Complex.sqrt(norm_a_squared) - else - -norm_a_sq_1on / (a_0 + Complex.sqrt(norm_a_squared)) - end - - v_0_sq = v_0 * v_0 - scale = 2 * v_0_sq / (norm_a_sq_1on + v_0_sq) - v = [1 | Enum.map(tail, &(&1 / v_0))] - {v, scale, false} - end - end - - defp householder_reflector_pivot([a_0 | tail], _eps) do - # complex case - norm_a_sq_1on = Enum.reduce(tail, 0, &(Complex.abs_squared(&1) + &2)) 
- norm_a_sq = norm_a_sq_1on + Complex.abs_squared(a_0) - norm_a = Complex.sqrt(norm_a_sq) - - phase_a_0 = Complex.phase(a_0) - alfa = Complex.exp(Complex.new(0, phase_a_0)) * norm_a - - # u = x - alfa * e1 - u_0 = a_0 + alfa - u = [u_0 | tail] - norm_u_sq = norm_a_sq_1on + Complex.abs_squared(u_0) - norm_u = Complex.sqrt(norm_u_sq) - - v = Enum.map(u, &(&1 / norm_u)) - {v, 2, true} - end - ## Matrix (2-D array) manipulation defp dot_matrix([], _), do: 0 @@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do end) end - defp dot_matrix_real(m1, m2) do - Enum.map(m1, fn row -> - m2 - |> transpose_matrix() - |> Enum.map(fn col -> - Enum.zip_reduce(row, col, 0, fn x, y, acc -> acc + x * y end) - end) - end) - end - - defp adjoint_matrix([x | _] = m) when not is_list(x) do - Enum.map(m, &[Complex.conjugate(&1)]) - end - - defp adjoint_matrix(m) do - Enum.zip_with(m, fn cols -> Enum.map(cols, &Complex.conjugate/1) end) - end - defp transpose_matrix([x | _] = m) when not is_list(x) do Enum.map(m, &[&1]) end diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 450a48cae42..edced802468 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -1179,8 +1179,8 @@ defmodule Nx.LinAlg do #Nx.Tensor< f32[2][2] [ - [3.9924824237823486, -1.0052783489227295], - [-3.0051186084747314, 1.0071179866790771] + [4.000002861022949, -1.0000008344650269], + [-3.000002384185791, 1.0000005960464478] ] > @@ -1275,14 +1275,14 @@ defmodule Nx.LinAlg do iex> Nx.round(eigenvals) #Nx.Tensor< f32[2] - [1.0, 2.0] + [2.0, 1.0] > iex> eigenvecs #Nx.Tensor< f32[2][2] [ - [1.0, 0.0], - [0.0, 1.0] + [0.0, 1.0], + [1.0, 0.0] ] > @@ -1296,9 +1296,9 @@ defmodule Nx.LinAlg do #Nx.Tensor< f32[3][3] [ - [0.4075949788093567, 0.9131628274917603, 0.0], - [0.40837883949279785, -0.18228201568126678, 0.8944271802902222], - [0.8167576789855957, -0.36456403136253357, -0.4472135901451111] + [0.40824827551841736, -0.18257419764995575, 0.8944271802902222], + [0.40824833512306213, 0.9128708839416504, 0.0], + [0.8164965510368347, -0.3651483952999115, -0.4472135901451111] ] > @@ -1308,7 +1308,7 @@ defmodule Nx.LinAlg do f32[2][2] [ [9.0, -1.0], - [1.0, 4.0] + [4.0, 1.0] ] > iex> eigenvecs @@ -1316,12 +1316,12 @@ defmodule Nx.LinAlg do f32[2][2][2] [ [ - [0.5612090229988098, -0.8276740908622742], - [0.8276740908622742, 0.5612090229988098] + [0.5606288313865662, 0.8280671834945679], + [0.8280671834945679, -0.5606288313865662] ], [ - [1.0, 0.0], - [0.0, 1.0] + [0.0, 1.0], + [1.0, 0.0] ] ] > @@ -1334,7 +1334,7 @@ defmodule Nx.LinAlg do f32[2] [ [9.0, -1.0], - [1.0, 4.0] + [4.0, 1.0] ] > iex> eigenvecs @@ -1343,12 +1343,12 @@ defmodule Nx.LinAlg do f32[2][2] [ [ - [0.5612090229988098, -0.8276740908622742], - [0.8276740908622742, 0.5612090229988098] + [0.5606288313865662, 0.8280671834945679], + [0.8280671834945679, -0.5606288313865662] ], [ - [1.0, 0.0], - [0.0, 1.0] + [0.0, 1.0], + [1.0, 0.0] ] ] > @@ -1376,7 +1376,7 @@ defmodule Nx.LinAlg do %{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}} :eigh - |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.Eigh.eigh/2) + |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.BlockEigh.eigh/2) |> Nx.vectorize(vectorized_axes) end @@ -2161,19 +2161,19 @@ defmodule Nx.LinAlg do iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2])) #Nx.Tensor< f32[2] - [0.9977624416351318, 0.0011188983917236328] + [1.0000028610229492, -2.384185791015625e-6] > iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 
2.1]))
      #Nx.Tensor<
        f32[2]
-        [0.9966151118278503, -0.947966456413269]
+        [0.9999998211860657, -0.9500012993812561]
      >

      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2, 3], [4, 5, 6]]), Nx.tensor([1, 2]))
      #Nx.Tensor<
        f32[3]
-        [-0.05534052848815918, 0.1111316829919815, 0.27760395407676697]
+        [-0.05555540323257446, 0.1111111044883728, 0.27777770161628723]
      >

  ## Error cases
diff --git a/nx/lib/nx/lin_alg/block_eigh.ex b/nx/lib/nx/lin_alg/block_eigh.ex
new file mode 100644
index 00000000000..1c6f7986d90
--- /dev/null
+++ b/nx/lib/nx/lin_alg/block_eigh.ex
@@ -0,0 +1,304 @@
+defmodule Nx.LinAlg.BlockEigh do
+  @moduledoc """
+  Parallel Jacobi symmetric eigendecomposition.
+
+  Reference implementation taken from XLA's eigh_expander,
+  which is built on the approach in:
+  Brent, R. P., & Luk, F. T. (1985). The solution of singular-value
+  and symmetric eigenvalue problems on multiprocessor arrays.
+  SIAM Journal on Scientific and Statistical Computing, 6(1), 69-84.
+  https://doi.org/10.1137/0906007
+  """
+  require Nx
+
+  import Nx.Defn
+
+  defn eigh(matrix, opts \\ []) do
+    opts = keyword!(opts, eps: 1.0e-6, max_iter: 100)
+
+    matrix
+    |> Nx.revectorize([collapsed_axes: :auto],
+      target_shape: {Nx.axis_size(matrix, -2), Nx.axis_size(matrix, -1)}
+    )
+    |> decompose(opts)
+    |> revectorize_result(matrix)
+  end
+
+  defnp decompose(matrix, opts) do
+    {n, _} = Nx.shape(matrix)
+
+    if n > 1 do
+      m_decompose(matrix, opts)
+    else
+      {Nx.take_diagonal(Nx.real(matrix)), Nx.tensor([1], type: matrix.type)}
+    end
+  end
+
+  defnp m_decompose(matrix, opts) do
+    eps = opts[:eps]
+    max_iter = opts[:max_iter]
+
+    type = Nx.Type.to_floating(Nx.type(matrix))
+    matrix = Nx.as_type(matrix, type)
+    {n, _} = Nx.shape(matrix)
+    i_n = n - 1
+    mid = calculate_mid(i_n)
+    i_mid = mid - 1
+
+    tl = matrix[[0..i_mid, 0..i_mid]]
+    tr = matrix[[0..i_mid, mid..i_n]]
+    bl = matrix[[mid..i_n, 0..i_mid]]
+    br = matrix[[mid..i_n, mid..i_n]]
+
+    # Pad if not even
+    {tr, bl, br} =
+      if Nx.remainder(n, 2) == 1 do
+        tr = Nx.pad(tr, 0, [{0, 0, 0}, {0, 1, 0}])
+        bl = Nx.pad(bl, 0, [{0, 1, 0}, {0, 0, 0}])
+        br = Nx.pad(br, 0, [{0, 1, 0}, {0, 1, 0}])
+        {tr, bl, br}
+      else
+        {tr, bl, br}
+      end
+
+    # Initialize tensors to hold the eigenvectors
+    v_tl = v_br = Nx.eye(mid, type: type)
+    v_tr = v_bl = Nx.broadcast(Nx.tensor(0, type: type), {mid, mid})
+
+    {frob_norm, off_norm} = norms(tl, tr, bl, br)
+
+    # Nested loop
+    # The outer loop performs the "sweep" operation until the norms converge
+    # or the maximum number of iterations is hit. The Brent/Luk paper states
+    # that Log2(n) sweeps are a good estimate for convergence, but XLA chose
+    # a static iteration count, which Log2(n) would only exceed for matrices
+    # larger than roughly 20k x 20k.
+    #
+    # The inner loop performs the n - 1 rounds of each sweep, which is enough
+    # permutations to allow all sub-matrices to share the needed values.
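+    #
+    # A note on the stopping rule used below: `sq_norm/4` sums squared
+    # absolute values, so `frob_norm` and `off_norm` are both *squared*
+    # norms, which is why the tolerance enters as `eps ** 2`. The sweep
+    # loop keeps running while
+    #
+    #   off_norm > eps ** 2 * frob_norm and i < max_iter
+    #
+    # that is, until the off-diagonal mass is negligible relative to the
+    # whole matrix, or until `max_iter` sweeps have run.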
+    {{tl, br, v_tl, v_tr, v_bl, v_br}, _} =
+      while {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i = 0}},
+            off_norm > eps ** 2 * frob_norm and i < max_iter do
+        {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br} =
+          perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n)
+
+        {frob_norm, off_norm} = norms(tl, tr, bl, br)
+
+        {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i + 1}}
+      end
+
+    # Recombine
+    w = Nx.concatenate([Nx.take_diagonal(tl), Nx.take_diagonal(br)])
+
+    v =
+      Nx.concatenate([
+        Nx.concatenate([v_tl, v_tr], axis: 1),
+        Nx.concatenate([v_bl, v_br], axis: 1)
+      ])
+      |> Nx.LinAlg.adjoint()
+
+    # Trim padding
+    {w, v} =
+      if Nx.remainder(n, 2) == 1 do
+        {w[0..i_n], v[[0..i_n, 0..i_n]]}
+      else
+        {w, v}
+      end
+
+    sort_ind = Nx.argsort(Nx.abs(w), direction: :desc)
+
+    w = Nx.take(w, sort_ind) |> approximate_zeros(eps)
+    v = Nx.take(v, sort_ind, axis: 1) |> approximate_zeros(eps)
+
+    {w, v}
+  end
+
+  deftransformp calculate_mid(i_n) do
+    Range.size(0..i_n//2)
+  end
+
+  defnp calc_rot(tl, tr, br) do
+    complex? = tl |> Nx.type() |> Nx.Type.complex?()
+    br = Nx.take_diagonal(br) |> Nx.real()
+    tr = Nx.take_diagonal(tr)
+    tl = Nx.take_diagonal(tl) |> Nx.real()
+
+    {tr, w} =
+      if complex? do
+        abs_tr = Nx.abs(tr)
+        {abs_tr, Nx.select(abs_tr == 0, 1, Nx.conjugate(tr) / abs_tr)}
+      else
+        {tr, 1}
+      end
+
+    z_tr = Nx.equal(tr, 0)
+    s_tr = Nx.select(z_tr, 1, tr)
+    tau = Nx.select(z_tr, 0, (br - tl) / (2 * s_tr))
+
+    t = Nx.sqrt(1 + tau ** 2)
+
+    t = 1 / (tau + Nx.select(tau >= 0, t, -t))
+
+    pred = Nx.abs(tr) <= 1.0e-5 * Nx.min(Nx.abs(br), Nx.abs(tl))
+    t = Nx.select(pred, Nx.tensor(0, type: tl.type), t)
+
+    c = 1.0 / Nx.sqrt(1.0 + t ** 2)
+    s = if complex?, do: Nx.complex(t * c, 0) * w, else: t * c
+
+    rt1 = tl - t * tr
+    rt2 = br + t * tr
+    {rt1, rt2, c, s}
+  end
+
+  defnp sq_norm(tl, tr, bl, br) do
+    Nx.sum(Nx.abs(tl) ** 2 + Nx.abs(tr) ** 2 + Nx.abs(bl) ** 2 + Nx.abs(br) ** 2)
+  end
+
+  defnp off_norm(tl, tr, bl, br) do
+    {n, _} = Nx.shape(tl)
+    diag = Nx.broadcast(0, {n})
+    o_tl = Nx.put_diagonal(tl, diag)
+    o_br = Nx.put_diagonal(br, diag)
+
+    sq_norm(o_tl, tr, bl, o_br)
+  end
+
+  # Calculates the Frobenius norm and the norm of the off-diagonals from
+  # the sub-matrices. Used to check for convergence.
+  defnp norms(tl, tr, bl, br) do
+    frob = sq_norm(tl, tr, bl, br)
+    off = off_norm(tl, tr, bl, br)
+
+    {frob, off}
+  end
+
+  deftransformp revectorize_result({eigenvals, eigenvecs}, a) do
+    shape = Nx.shape(a)
+
+    {
+      Nx.revectorize(eigenvals, a.vectorized_axes,
+        target_shape: Tuple.delete_at(shape, tuple_size(shape) - 1)
+      ),
+      Nx.revectorize(eigenvecs, a.vectorized_axes, target_shape: shape)
+    }
+  end
+
+  defnp perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n) do
+    while {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br}, _n <- 0..i_n do
+      {rt1, rt2, c, s} = calc_rot(tl, tr, br)
+      # Build row and column vectors for parallelized rotations
+      c_v = Nx.new_axis(c, 1)
+      s_v = Nx.new_axis(s, 1)
+      c_h = Nx.new_axis(c, 0)
+      s_h = Nx.new_axis(s, 0)
+
+      s_v_conj =
+        if Nx.type(s) |> Nx.Type.complex?() do
+          Nx.conjugate(s_v)
+        else
+          s_v
+        end
+
+      s_h_conj = Nx.transpose(s_v_conj)
+
+      # Each rotation group below is performed based on the same
+      # tl, bl, tr, br values, so we must do single-expression
+      # assignments (i.e. {tl, tr, bl, br} = ...)
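+      #
+      # A sketch of the math the elementwise updates below implement: for
+      # each pivot pair, a 2x2 rotation (c is real, s may be complex)
+      #
+      #   J = | c  -conj(s) |
+      #       | s       c   |
+      #
+      # is applied to the matrix as A <- J * A * adjoint(J) and accumulated
+      # into the eigenvectors as V <- J * V, simultaneously for all pairs
+      # via the broadcasted row/column vectors built above.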
+ + # Rotate rows + {tl, tr, bl, br} = { + tl * c_v - bl * s_v_conj, + tr * c_v - br * s_v_conj, + tl * s_v + bl * c_v, + tr * s_v + br * c_v + } + + # Rotate cols + {tl, tr, bl, br} = { + tl * c_h - tr * s_h, + tl * s_h_conj + tr * c_h, + bl * c_h - br * s_h, + bl * s_h_conj + br * c_h + } + + # Store results and permute values across sub matrices + zero_diag = Nx.broadcast(0, {mid}) + tl = Nx.put_diagonal(tl, rt1) + tr = Nx.put_diagonal(tr, zero_diag) + bl = Nx.put_diagonal(bl, zero_diag) + br = Nx.put_diagonal(br, rt2) + + {tl, tr} = permute_cols_in_row(tl, tr) + {bl, br} = permute_cols_in_row(bl, br) + {tl, bl} = permute_rows_in_col(tl, bl) + {tr, br} = permute_rows_in_col(tr, br) + + # Rotate to calc vectors + {v_tl, v_tr, v_bl, v_br} = { + v_tl * c_v - v_bl * s_v_conj, + v_tr * c_v - v_br * s_v_conj, + v_tl * s_v + v_bl * c_v, + v_tr * s_v + v_br * c_v + } + + # permute for vectors + {v_tl, v_bl} = permute_rows_in_col(v_tl, v_bl) + {v_tr, v_br} = permute_rows_in_col(v_tr, v_br) + + {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br} + end + end + + defnp approximate_zeros(matrix, eps), do: Nx.select(Nx.abs(matrix) <= eps, 0, matrix) + + # https://github.com/openxla/xla/blob/main/xla/hlo/transforms/expanders/eigh_expander.cc#L200-L239 + defnp permute_rows_in_col(top, bottom) do + {k, _} = Nx.shape(top) + + {top_out, bottom_out} = + cond do + k == 2 -> + {Nx.concatenate([top[0..0], bottom[0..0]], axis: 0), + Nx.concatenate( + [ + bottom[1..-1//1], + top[(k - 1)..(k - 1)] + ], + axis: 0 + )} + + k == 1 -> + {top, bottom} + + true -> + {Nx.concatenate([top[0..0], bottom[0..0], top[1..(k - 2)]], axis: 0), + Nx.concatenate( + [ + bottom[1..-1//1], + top[(k - 1)..(k - 1)] + ], + axis: 0 + )} + end + + {top_out, bottom_out} + end + + defnp permute_cols_in_row(left, right) do + {k, _} = Nx.shape(left) + + {left_out, right_out} = + cond do + k == 2 -> + {Nx.concatenate([left[[.., 0..0]], right[[.., 0..0]]], axis: 1), + Nx.concatenate([right[[.., 1..(k - 1)]], left[[.., (k - 1)..(k - 1)]]], axis: 1)} + + k == 1 -> + {left, right} + + true -> + {Nx.concatenate([left[[.., 0..0]], right[[.., 0..0]], left[[.., 1..(k - 2)]]], axis: 1), + Nx.concatenate([right[[.., 1..(k - 1)]], left[[.., (k - 1)..(k - 1)]]], axis: 1)} + end + + {left_out, right_out} + end +end diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs index 6ab7e8ef785..81bc025fc0c 100644 --- a/nx/test/nx/defn/grad_test.exs +++ b/nx/test/nx/defn/grad_test.exs @@ -1981,8 +1981,8 @@ defmodule Nx.Defn.GradTest do assert_all_close( svd_grad(Nx.tensor([[3, 0], [1, 2]])), Nx.tensor([ - [0.07228553295135498, 0.7500489950180054], - [1.113668441772461, 1.8945982456207275] + [1.368404507637024, -0.5419228672981262], + [-0.2197188436985016, 0.6067624092102051] ]) ) end @@ -1991,8 +1991,8 @@ defmodule Nx.Defn.GradTest do assert_all_close( svd_composed_grad(Nx.tensor([[3, 0], [1, 2]])), Nx.tensor([ - [22.44730567932129, 4.334394931793213], - [10.295409202575684, 9.27196216583252] + [22.86724090576172, 3.655829906463623], + [10.035255432128906, 8.769235610961914] ]) ) end @@ -2001,9 +2001,9 @@ defmodule Nx.Defn.GradTest do assert_all_close( svd_composed_grad(Nx.tensor([[3, 0], [1, 2], [1, 1]])), Nx.tensor([ - [25.990453720092773, 6.061026096343994], - [12.646490097045898, 10.775838851928711], - [10.656349182128906, 6.384178638458252] + [25.911056518554688, 6.1099162101745605], + [12.69705581665039, 10.84456729888916], + [10.668402671813965, 6.426826477050781] ]) ) end diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs 
index 36a48159e18..d8c8fe2bb40 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -574,11 +574,11 @@ defmodule Nx.LinAlgTest do assert_all_close( eigenvecs, Nx.tensor([ - [0.112, -0.005, -0.831, -0.436, -0.328], - [0.395, 0.163, 0.530, -0.537, -0.497], - [0.427, 0.326, -0.133, 0.700, -0.452], - [0.603, -0.783, -0.007, 0.079, 0.130], - [0.534, 0.504, -0.104, -0.160, 0.651] + [0.112, 0.004, 0.828, -0.440, -0.328], + [0.395, -0.163, -0.533, -0.534, -0.497], + [0.427, -0.326, 0.137, 0.700, -0.452], + [0.603, 0.783, 0.008, 0.079, 0.130], + [0.534, -0.504, 0.103, -0.160, 0.651] ]), atol: 1.0e-3, rtol: 1.0e-3 @@ -600,28 +600,20 @@ defmodule Nx.LinAlgTest do # Eigenvalues assert eigenvals == - Nx.tensor([Complex.new(-5, 0), Complex.new(3, 0), Complex.new(1, 0)]) + Nx.tensor([ + Complex.new(-5, 0), + Complex.new(3, 0), + Complex.new(0.9999998807907104, 0) + ]) # Eigenvectors assert_all_close( eigenvecs, - Nx.tensor([ - [ - Complex.new(-0.408, 0.0), - Complex.new(-0.0, 0.707), - Complex.new(0.577, 0.0) - ], - [ - Complex.new(-0.0, -0.816), - Complex.new(0.0, 0.0), - Complex.new(0.0, -0.577) - ], - [ - Complex.new(0.408, 0.0), - Complex.new(-0.0, 0.707), - Complex.new(-0.577, 0.0) - ] - ]), + ~MAT[ + 0.0000-0.4082i 0.7071-0.0i 00.5773-0.0000i + 0.8164-0.0000i 0.0000+0.0i 00.0000-0.5773i + 0.0000+0.4082i 0.7071-0.0i -0.5773-0.0000i + ], atol: 1.0e-3, rtol: 1.0e-3 ) @@ -638,42 +630,56 @@ defmodule Nx.LinAlgTest do for type <- [f: 32, c: 64], reduce: key do key -> # Unitary matrix from a random matrix - {base, key} = Nx.Random.uniform(key, shape: {3, 3, 3}, type: type) + {base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type) {q, _} = Nx.LinAlg.qr(base) # Different eigenvalues from random values evals_test = - [{100, 30}, {4, 6}, {0.7, 0.9}] - |> Enum.map(fn {low, up} -> - if :rand.uniform() - 0.5 > 0 do - {low, up} - else - {-up, -low} - end - end) - |> Enum.map(fn {low, up} -> - rand = :rand.uniform() * (up - low) + low - Nx.tensor([rand], type: :f64) + [100, 10, 1] + |> Enum.map(fn magnitude -> + sign = + if :rand.uniform() - 0.5 > 0 do + 1 + else + -1 + end + + rand = :rand.uniform() * magnitude * 0.1 + magnitude + rand * sign end) - |> Nx.concatenate() + |> Nx.tensor(type: type) + + evals_test_diag = + evals_test + |> Nx.make_diagonal() + |> Nx.reshape({1, 3, 3}) + |> Nx.tile([2, 1, 1]) # Hermitian matrix with different eigenvalues # using A = A^* = Q^*.Λ.Q. 
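+          # (A brief justification of the construction above: with Λ real and
+          # Q unitary, Q^*.Λ.Q equals its own adjoint and has eigenvalues
+          # exactly the entries of Λ, so the assertions below can use a tight
+          # tolerance.)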
a = q |> Nx.LinAlg.adjoint() - |> Nx.multiply(evals_test) + |> Nx.dot([2], [0], evals_test_diag, [1], [0]) |> Nx.dot([2], [0], q, [1], [0]) # Eigenvalues and eigenvectors - assert {evals, evecs} = Nx.LinAlg.eigh(a, max_iter: 10_000) - assert_all_close(evals_test, evals, atol: 1.0e-1) + assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8) + + assert_all_close(evals_test, evals[0], atol: 1.0e-8) + assert_all_close(evals_test, evals[1], atol: 1.0e-8) + + evals = + evals + |> Nx.vectorize(:x) + |> Nx.make_diagonal() + |> Nx.devectorize(keep_names: false) # Eigenvalue equation - evecs_evals = Nx.multiply(evecs, evals) - a_evecs = Nx.dot(a, [2], [0], evecs, [1], [0]) + evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0]) + a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0]) - assert_all_close(evecs_evals, a_evecs, atol: 1.0e-1) + assert_all_close(a, a_evecs, atol: 1.0e-8) key end end @@ -734,10 +740,10 @@ defmodule Nx.LinAlgTest do assert_all_close( u, Nx.tensor([ - [0.141, 0.825, -0.001, 0.019], - [0.344, 0.426, 0.00200, 0.382], - [0.547, 0.028, 0.0, -0.822], - [0.75, -0.370, -0.001, 0.421] + [0.141, -0.825, -0.001, 0.019], + [0.344, -0.426, 0.00200, 0.382], + [0.547, -0.028, 0.0, -0.822], + [0.75, 0.370, -0.001, 0.421] ]), atol: 1.0e-3, rtol: 1.0e-3 @@ -747,8 +753,8 @@ defmodule Nx.LinAlgTest do assert_all_close( Nx.tensor([ - [0.505, 0.575, 0.644], - [-0.761, -0.057, 0.647], + [0.504, 0.575, 0.644], + [0.761, 0.057, -0.647], [-0.408, 0.816, -0.408] ]), v, @@ -801,9 +807,9 @@ defmodule Nx.LinAlgTest do assert_all_close( u, Nx.tensor([ - [0.336, -0.407, -0.849], - [0.037, -0.895, 0.444], - [0.941, 0.181, 0.286] + [0.335, 0.408, 0.849], + [0.036, 0.895, -0.445], + [0.941, -0.18, -0.286] ]), atol: 1.0e-3, rtol: 1.0e-3 @@ -815,9 +821,9 @@ defmodule Nx.LinAlgTest do assert_all_close( Nx.tensor([ - [0.035, 0.0869, 0.996], - [-0.091, -0.992, 0.09], - [-0.995, 0.094, 0.027] + [0.035, 0.0856, 0.996], + [0.092, 0.992, -0.089], + [0.995, -0.094, -0.027] ]), v, atol: 1.0e-3, diff --git a/torchx/mix.exs b/torchx/mix.exs index 070718c3b03..8a0f3cd3686 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do defp deps do [ - {:nx, "~> 0.9.0"}, - # {:nx, path: "../nx"}, + # {:nx, "~> 0.9.0"}, + {:nx, path: "../nx"}, {:ex_doc, "~> 0.29", only: :docs} ] end diff --git a/torchx/test/torchx/nx_linalg_doctest_test.exs b/torchx/test/torchx/nx_linalg_doctest_test.exs index 9f3c6eca521..30e75dafc55 100644 --- a/torchx/test/torchx/nx_linalg_doctest_test.exs +++ b/torchx/test/torchx/nx_linalg_doctest_test.exs @@ -18,7 +18,7 @@ defmodule Torchx.NxLinAlgDoctestTest do invert: 1, determinant: 1, pinv: 2, - least_squares: 2 + least_squares: 3 ] # Results do not match but properties are respected