diff --git a/exla/CHANGELOG.md b/exla/CHANGELOG.md
index 466505a84e..a06ca4b41e 100644
--- a/exla/CHANGELOG.md
+++ b/exla/CHANGELOG.md
@@ -1,5 +1,12 @@
 # Changelog
 
+## v0.9.2 (2024-11-16)
+
+### Enhancements
+
+  * Support cross-compilation for use with Nerves
+  * Optimize LU with a custom call
+
 ## v0.9.1 (2024-10-08)
 
 ### Enhancements
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index d2f2fd7357..fbf392862f 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -787,8 +787,8 @@ defmodule EXLA.Defn do
     transform = Keyword.fetch!(opts, :transform_a)
 
     case Value.get_typespec(b).shape do
-      {_} = b_shape ->
-        b_shape = Tuple.append(b_shape, 1)
+      {dim} ->
+        b_shape = {dim, 1}
 
         b =
           b
diff --git a/exla/mix.lock b/exla/mix.lock
index cb55ce51fc..51edca3542 100644
--- a/exla/mix.lock
+++ b/exla/mix.lock
@@ -1,6 +1,6 @@
 %{
   "benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"},
-  "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
+  "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
  "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
  "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
  "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
diff --git a/nx/CHANGELOG.md b/nx/CHANGELOG.md
index 69f8c81c16..23aaa25268 100644
--- a/nx/CHANGELOG.md
+++ b/nx/CHANGELOG.md
@@ -1,5 +1,13 @@
 # Changelog
 
+## v0.9.2 (2024-11-16)
+
+### Bug fixes
+
+  * [Nx] Fix deprecation warnings on latest Elixir
+  * [Nx.LinAlg] Fix `least_squares` implementation
+  * [Nx.Random] Fix `Nx.Random.shuffle` repeating a single value in certain cases on GPU
+
 ## v0.9.1 (2024-10-08)
 
 ### Deprecations
diff --git a/nx/guides/advanced/aggregation.livemd b/nx/guides/advanced/aggregation.livemd
index 2e8cbc7262..b9d5678c2e 100644
--- a/nx/guides/advanced/aggregation.livemd
+++ b/nx/guides/advanced/aggregation.livemd
@@ -64,7 +64,7 @@ m = ~MAT[
 ]
 ```
 
-First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
+First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
 
 ```elixir
 w = ~MAT[
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex
index ce5cb99ceb..8206f4106b 100644
--- a/nx/lib/nx.ex
+++ b/nx/lib/nx.ex
@@ -417,7 +417,7 @@ defmodule Nx do
         config: [nx: [default_backend: EXLA.Backend]]
       )
 
-  Or by calling `Nx.global_default_backend/1` (less preferrable):
+  Or by calling `Nx.global_default_backend/1` (less preferable):
 
       Nx.global_default_backend(EXLA.Backend)
 
@@ -614,7 +614,7 @@ defmodule Nx do
 
   > Certain backends and compilers support 8-bit floats. The precision
-  > iomplementation of 8-bit floats may change per backend, so you must
+  > implementation of 8-bit floats may change per backend, so you must
   > be careful when transferring data across. The binary backend implements
   > F8E5M2:
 
@@ -943,7 +943,7 @@ defmodule Nx do
 
   for t <-
         [:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
-          [:f8, :bf16, :f16, :f32, :f64] do
+          [:f8, :bf16, :f16, :f32, :f64, :c64, :c128] do
     @doc """
     Short-hand function for creating tensor of type `#{t}`.
 
@@ -1305,7 +1305,7 @@ defmodule Nx do
     out =
       case shape do
         {n} ->
-          intermediate_shape = Tuple.duplicate(1, tuple_size(out_shape) - 1) |> Tuple.append(n)
+          intermediate_shape = Tuple.duplicate(1, tuple_size(out_shape) - 1) |> tuple_append(n)
 
           backend.eye(
             %T{type: type, shape: intermediate_shape, names: names},
@@ -1609,7 +1609,7 @@ defmodule Nx do
       t
     else
       diag_length = div(Nx.size(t), Tuple.product(batch_shape))
-      Nx.reshape(t, Tuple.append(batch_shape, diag_length))
+      Nx.reshape(t, tuple_append(batch_shape, diag_length))
     end
   end
 
@@ -7684,7 +7684,7 @@ defmodule Nx do
   number of indices, while `updates` must have a compatible `{n, ...j}` shape, such that
   `i + j = rank(tensor)`.
 
-  In case of repeating indices, the result is non-determinstic, since the operation happens
+  In case of repeating indices, the result is non-deterministic, since the operation happens
   in parallel when running on devices such as the GPU.
 
   See also: `indexed_add/3`, `put_slice/3`.
@@ -10365,9 +10365,9 @@ defmodule Nx do
       if opts[:keep_axis] do
         new_shape
         |> Tuple.delete_at(tuple_size(new_shape) - 1)
-        |> Tuple.append(:auto)
+        |> tuple_append(:auto)
       else
-        Tuple.append(new_shape, :auto)
+        tuple_append(new_shape, :auto)
       end
 
     reshaped_tensor = reshape(tensor, flattened_shape)
@@ -12919,6 +12919,11 @@ defmodule Nx do
   of summing the element-wise products in the window across
   each input channel.
 
+  > #### Kernel Reflection {: .info}
+  >
+  > See the note at the end of this section for more details
+  > on the convention for kernel reflection and conjugation.
+
   The ranks of both `input` and `kernel` must match. By
   default, both `input` and `kernel` are expected to have shapes
   of the following form:
@@ -13000,6 +13005,45 @@ defmodule Nx do
   in the same way as with `:feature_group_size`, however, the input
   tensor will be split into groups along the batch dimension.
 
+  > #### Convolution vs Correlation {: .tip}
+  >
+  > `conv/3` does not perform reversal of the kernel.
+  > This means that if you come from a Signal Processing background,
+  > you might treat it as a cross-correlation operation instead of a convolution.
+  >
+  > This function is not exactly a cross-correlation function, as it does not
+  > perform conjugation of the kernel, as is done in [scipy.signal.correlate](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html).
+  > This can be remedied via `Nx.conjugate/1`, as seen below:
+  >
+  > ```elixir
+  > kernel =
+  >   if Nx.Type.complex?(Nx.type(kernel)) do
+  >     Nx.conjugate(kernel)
+  >   else
+  >     kernel
+  >   end
+  >
+  > Nx.conv(tensor, kernel)
+  > ```
+  >
+  > If you need the proper Signal Processing convolution, such as the one in
+  > [scipy.signal.convolve](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html),
+  > you can use `reverse/2`, like in the example:
+  >
+  > ```elixir
+  > reversal_axes =
+  >   case Nx.rank(kernel) do
+  >     0 -> []
+  >     1 -> [1]
+  >     2 -> [0, 1]
+  >     _ -> Enum.drop(Nx.axes(kernel), 2)
+  >   end
+  >
+  > kernel = Nx.reverse(kernel, axes: reversal_axes)
+  >
+  > Nx.conv(tensor, kernel)
+  > ```
+
   ## Examples
 
       iex> left = Nx.iota({1, 1, 3, 3})
@@ -13554,7 +13598,7 @@ defmodule Nx do
       end)
       |> Nx.stack()
       |> Nx.revectorize(vectorized_axes,
-        target_shape: Tuple.append(List.to_tuple(lengths), :auto)
+        target_shape: tuple_append(List.to_tuple(lengths), :auto)
       )
 
     Nx.gather(tensor, idx)
@@ -14288,7 +14332,7 @@ defmodule Nx do
     Nx.Shared.optional(:take_along_axis, [tensor, indices, [axis: axis]], out, fn
       tensor, indices, _opts ->
         axes_range = axes(indices)
 
-        new_axis_shape = Tuple.append(shape(indices), 1)
+        new_axis_shape = tuple_append(shape(indices), 1)
 
         full_indices =
           axes_range
@@ -14471,7 +14515,7 @@ defmodule Nx do
     indices = devectorize(indices, keep_names: false)
 
     iota_shape =
-      indices.shape |> Tuple.delete_at(tuple_size(indices.shape) - 1) |> Tuple.append(1)
+      indices.shape |> Tuple.delete_at(tuple_size(indices.shape) - 1) |> tuple_append(1)
 
     offset_axes = (offset - 1)..0//-1
diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex
index 3954028e1f..1193a0b8c4 100644
--- a/nx/lib/nx/backend.ex
+++ b/nx/lib/nx/backend.ex
@@ -138,7 +138,7 @@ defmodule Nx.Backend do
  First we will attempt to call the optional callback itself
  (one of the many callbacks defined below), then we attempt
  to call this callback (which is also optional), then we
- fallback to the default iomplementation.
+ fall back to the default implementation.
  """
 
  @callback optional(atom, [term], fun) :: tensor
diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex
index 583b964bb6..10c3b58e33 100644
--- a/nx/lib/nx/binary_backend.ex
+++ b/nx/lib/nx/binary_backend.ex
@@ -755,11 +755,29 @@ defmodule Nx.BinaryBackend do
 
   defp element_equal(_, :nan, _), do: 0
   defp element_equal(_, _, :nan), do: 0
-  defp element_equal(_, a, b), do: boolean_as_number(a == b)
+
+  defp element_equal(_, a, b) do
+    bool =
+      case {a, b} do
+        {%Complex{re: re_a, im: im_a}, b} when is_number(b) ->
+          re_a == b and im_a == 0
+
+        {a, %Complex{re: re_b, im: im_b}} when is_number(a) ->
+          a == re_b and im_b == 0
+
+        {a, b} ->
+          a == b
+      end
+
+    boolean_as_number(bool)
+  end
 
   defp element_not_equal(_, :nan, _), do: 1
   defp element_not_equal(_, _, :nan), do: 1
-  defp element_not_equal(_, a, b), do: boolean_as_number(a != b)
+
+  defp element_not_equal(out, a, b) do
+    1 - element_equal(out, a, b)
+  end
 
   defp element_logical_and(_, a, b), do: boolean_as_number(as_boolean(a) and as_boolean(b))
   defp element_logical_or(_, a, b), do: boolean_as_number(as_boolean(a) or as_boolean(b))
@@ -1240,25 +1258,6 @@ defmodule Nx.BinaryBackend do
     output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
   end
 
-  @impl true
-  def eigh(
-        {%{type: output_type} = eigenvals_holder, eigenvecs_holder},
-        %{type: input_type, shape: input_shape} = tensor,
-        opts
-      ) do
-    bin = to_binary(tensor)
-    rank = tuple_size(input_shape)
-    n = elem(input_shape, rank - 1)
-
-    {eigenvals, eigenvecs} =
-      bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, fn matrix, {vals_acc, vecs_acc} ->
-        {vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
-        {vals_acc <> vals, vecs_acc <> vecs}
-      end)
-
-    {from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
-  end
-
   @impl true
   def lu(
         {%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
@@ -1492,7 +1491,7 @@ defmodule Nx.BinaryBackend do
     dilations = opts[:window_dilations]
 
     %T{shape: padded_shape, type: {_, size} = type} =
-      tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &Tuple.append(&1, 0)))
+      tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &tuple_append(&1, 0)))
 
     acc = scalar_to_number(acc)
 
@@ -1608,7 +1607,7 @@ defmodule Nx.BinaryBackend do
     init_value = scalar_to_number(init_value)
 
     %T{shape: padded_shape, type: {_, size} = type} =
-      tensor = Nx.pad(t, init_value, Enum.map(padding, &Tuple.append(&1, 0)))
+      tensor = Nx.pad(t, init_value, Enum.map(padding, &tuple_append(&1, 0)))
 
     input_data = to_binary(tensor)
     input_weighted_shape = weighted_shape(padded_shape, size, window_dimensions)
@@ -2077,7 +2076,7 @@ defmodule Nx.BinaryBackend do
     for <<match!(x, 0) <- data>>, into: <<>> do
       x = read!(x, 0)
 
-      case x do
+      generated_case x do
         %Complex{re: re} when float_output? and real_output? ->
           number_to_binary(re, output_type)
 
@@ -2253,17 +2252,16 @@ defmodule Nx.BinaryBackend do
       end
     end
 
-    output_data =
-      match_types [out.type] do
-        for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do
-          re = if abs(re) <= eps, do: 0, else: re
-          im = if abs(im) <= eps, do: 0, else: im
+    %{type: {_, output_size}} = out
 
-          <<write!(re, 0), write!(im, 0)>>
-        end
+    output_data =
+      for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do
+        re = if abs(re) <= eps, do: 0, else: re
+        im = if abs(im) <= eps, do: 0, else: im
+        <<re::float-size(output_size)-native, im::float-size(output_size)-native>>
       end
 
-    intermediate_shape = out.shape |> Tuple.delete_at(axis) |> Tuple.append(n)
+    intermediate_shape = out.shape |> Tuple.delete_at(axis) |> tuple_append(n)
 
     permuted_output = from_binary(%{out | shape: intermediate_shape}, output_data)
 
@@ -2391,20 +2389,6 @@ defmodule Nx.BinaryBackend do
     end
   end
 
-  defp bin_zip_reduce(t1, [], t2, [], type, acc, fun) do
-    %{type: {_, s1}} = t1
-    %{type: {_, s2}} = t2
-    b1 = to_binary(t1)
-    b2 = to_binary(t2)
-
-    match_types [t1.type, t2.type] do
-      for <<match!(d1, 0) <- b1>>, <<match!(d2, 1) <- b2>>, into: <<>> do
-        {result, _} = fun.(d1, d2, acc)
-        scalar_to_binary!(result, type)
-      end
-    end
-  end
-
   defp bin_zip_reduce(t1, [_ | _] = axes1, t2, [_ | _] = axes2, type, acc, fun) do
     {_, s1} = t1.type
     {_, s2} = t2.type
diff --git a/nx/lib/nx/binary_backend/matrix.ex b/nx/lib/nx/binary_backend/matrix.ex
index 85601295b9..afb55fb668 100644
--- a/nx/lib/nx/binary_backend/matrix.ex
+++ b/nx/lib/nx/binary_backend/matrix.ex
@@ -116,150 +116,6 @@ defmodule Nx.BinaryBackend.Matrix do
 
   defp do_ts([], [], _idx, acc), do: acc
 
-  defp qr_decomposition(matrix, n, _eps) when n in 0..1 do
-    {[[1.0]], matrix}
-  end
-
-  defp qr_decomposition(matrix, n, eps) when n >= 2 do
-    # QR decomposition is performed by using Householder transform
-    # this function originally supported generic QR, but
-    # it is now only used by eigh. Because of this,
-    # we simplified the function signature to only
-    # support square matrices.
-
-    {q_matrix, r_matrix} =
-      for i <- 0..(n - 2)//1, reduce: {nil, matrix} do
-        {q, r} ->
-          h =
-            r
-            |> slice_matrix([i, i], [n - i, 1])
-            |> householder_reflector(n, eps)
-
-          # If we haven't allocated Q yet, let Q = H1
-          # TODO: Resolve inconsistent with the Householder reflector.
-          # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
-          q =
-            if is_nil(q) do
-              h
-            else
-              dot_matrix_real(q, h)
-            end
-
-          r = dot_matrix_real(h, r)
-          {q, r}
-      end
-
-    {approximate_zeros(q_matrix, eps), approximate_zeros(r_matrix, eps)}
-  end
-
-  defp raise_not_hermitian do
-    raise ArgumentError,
-          "matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
-  end
-
-  def eigh(input_data, input_type, {n, n} = input_shape, output_type, opts) do
-    eps = opts[:eps]
-    max_iter = opts[:max_iter]
-
-    # Validate that the input is a Hermitian matrix using the relation A^* = A.
-    a = binary_to_matrix(input_data, input_type, input_shape)
-
-    is_hermitian =
-      a
-      |> transpose_matrix()
-      |> Enum.map(fn a_row -> Enum.map(a_row, &Complex.conjugate(&1)) end)
-      |> is_approximately_same?(a, eps)
-
-    unless is_hermitian do
-      raise_not_hermitian()
-    end
-
-    # Hessenberg decomposition
-    {h, q_h} = hessenberg_decomposition(a, n, eps)
-
-    # QR iteration for eigenvalues and eigenvectors
-    {eigenvals_diag, eigenvecs} =
-      Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} ->
-        # QR decomposition
-        {q_now, r_now} = qr_decomposition(a_old, n, eps)
-
-        # Update matrix A, Q
-        a_new = dot_matrix_real(r_now, q_now)
-        q_new = dot_matrix_real(q_old, q_now)
-
-        if is_approximately_same?(q_old, q_new, eps) do
-          {:halt, {a_new, q_new}}
-        else
-          {:cont, {a_new, q_new}}
-        end
-      end)
-
-    # Obtain the eigenvalues, which are the diagonal elements
-    indices_diag = for idx <- 0..(n - 1), do: [idx, idx]
-    eigenvals = get_matrix_elements(eigenvals_diag, indices_diag)
-
-    # In general, the eigenvalues of a Hermitian matrix are real numbers
-    eigenvals_real = eigenvals |> Enum.map(&Complex.real(&1))
-
-    # Reduce the elements smaller than eps to zero
-    {eigenvals_real |> approximate_zeros(eps) |> matrix_to_binary(output_type),
-     eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)}
-  end
-
-  defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1 do
-    {matrix, [[1.0]]}
-  end
-
-  defp hessenberg_decomposition(matrix, n, eps) do
-    # Hessenberg decomposition is performed by using Householder transform
-    {hess_matrix, q_matrix} =
-      for i <- 0..(n - 2)//1, reduce: {matrix, nil} do
-        {hess, q} ->
-          h =
-            hess
-            |> slice_matrix([i + 1, i], [n - i - 1, 1])
-            |> householder_reflector(n, eps)
-
-          # If we haven't allocated Q yet, let Q = H1
-          # TODO: Resolve inconsistent with the Householder reflector.
-          # cf. https://github.com/elixir-nx/nx/pull/933#discussion_r982772063
-          q =
-            if is_nil(q) do
-              h
-            else
-              dot_matrix_real(q, h)
-            end
-
-          # Hessenberg matrix H updating
-          h_adj = adjoint_matrix(h)
-
-          hess =
-            h
-            |> dot_matrix_real(hess)
-            |> dot_matrix_real(h_adj)
-
-          {hess, q}
-      end
-
-    {approximate_zeros(hess_matrix, eps), approximate_zeros(q_matrix, eps)}
-  end
-
-  defp is_approximately_same?(a, b, eps) do
-    # Determine if matrices `a` and `b` are equal in the range of eps
-    a
-    |> Enum.zip(b)
-    |> Enum.all?(fn {a_row, b_row} ->
-      a_row
-      |> Enum.zip(b_row)
-      |> Enum.all?(fn
-        {a_elem, b_elem} ->
-          abs_diff = Complex.abs(a_elem - b_elem)
-
-          abs_diff == :nan or abs_diff <= eps
-      end)
-    end)
-  end
-
   def lu(input_data, input_type, {n, n} = input_shape, p_type, l_type, u_type, opts) do
     a = binary_to_matrix(input_data, input_type, input_shape)
     eps = opts[:eps]
@@ -361,116 +217,6 @@ defmodule Nx.BinaryBackend.Matrix do
     end)
   end
 
-  ## Householder helpers
-
-  defp householder_reflector(a, target_k, eps)
-
-  defp householder_reflector([], target_k, _eps) do
-    flat_list =
-      for col <- 0..(target_k - 1), row <- 0..(target_k - 1), into: [] do
-        if col == row, do: 1, else: 0
-      end
-
-    Enum.chunk_every(flat_list, target_k)
-  end
-
-  defp householder_reflector(a, target_k, eps) do
-    {v, scale, is_complex} = householder_reflector_pivot(a, eps)
-
-    prefix_threshold = target_k - length(v)
-    v = List.duplicate(0, prefix_threshold) ++ v
-
-    # dot(v, v) = norm_v_squared, which can be calculated from norm_a as:
-    # norm_v_squared = norm_a_squared - a_0^2 + v_0^2
-
-    # execute I - 2 / norm_v_squared * outer(v, v)
-    {_, _, reflector_reversed} =
-      for col_factor <- v, row_factor <- v, reduce: {0, 0, []} do
-        {row, col, acc} ->
-          row_factor = if is_complex, do: Complex.conjugate(row_factor), else: row_factor
-
-          # The current element in outer(v, v) is given by col_factor * row_factor
-          # and the current I element is 1 when row == col
-          identity_element = if row == col, do: 1, else: 0
-
-          result =
-            if row >= prefix_threshold and col >= prefix_threshold do
-              identity_element -
-                scale * col_factor * row_factor
-            else
-              identity_element
-            end
-
-          acc = [result | acc]
-
-          if col + 1 == target_k do
-            {row + 1, 0, acc}
-          else
-            {row, col + 1, acc}
-          end
-      end
-
-    # This is equivalent to reflector_reversed |> Enum.reverse() |> Enum.chunk_every(target_k)
-    {reflector, _, _} =
-      for x <- reflector_reversed, reduce: {[], [], 0} do
-        {result_acc, row_acc, col} ->
-          row_acc = [x | row_acc]
-
-          if col + 1 == target_k do
-            {[row_acc | result_acc], [], 0}
-          else
-            {result_acc, row_acc, col + 1}
-          end
-      end
-
-    reflector
-  end
-
-  defp householder_reflector_pivot([a_0 | tail] = a, eps) when is_number(a_0) do
-    # This is a trick so we can both calculate the norm of a_reverse and extract the
-    # head a the same time we reverse the array
-    # receives a_reverse as a list of numbers and returns the reflector as a
-    # k x k matrix
-
-    norm_a_squared = Enum.reduce(a, 0, fn x, acc -> x * Complex.conjugate(x) + acc end)
-    norm_a_sq_1on = norm_a_squared - a_0 * a_0
-
-    if norm_a_sq_1on < eps do
-      {[1 | tail], 0, false}
-    else
-      v_0 =
-        if a_0 <= 0 do
-          a_0 - Complex.sqrt(norm_a_squared)
-        else
-          -norm_a_sq_1on / (a_0 + Complex.sqrt(norm_a_squared))
-        end
-
-      v_0_sq = v_0 * v_0
-      scale = 2 * v_0_sq / (norm_a_sq_1on + v_0_sq)
-      v = [1 | Enum.map(tail, &(&1 / v_0))]
-      {v, scale, false}
-    end
-  end
-
-  defp householder_reflector_pivot([a_0 | tail], _eps) do
-    # complex case
-    norm_a_sq_1on = Enum.reduce(tail, 0, &(Complex.abs_squared(&1) + &2))
-    norm_a_sq = norm_a_sq_1on + Complex.abs_squared(a_0)
-    norm_a = Complex.sqrt(norm_a_sq)
-
-    phase_a_0 = Complex.phase(a_0)
-    alfa = Complex.exp(Complex.new(0, phase_a_0)) * norm_a
-
-    # u = x - alfa * e1
-    u_0 = a_0 + alfa
-    u = [u_0 | tail]
-    norm_u_sq = norm_a_sq_1on + Complex.abs_squared(u_0)
-    norm_u = Complex.sqrt(norm_u_sq)
-
-    v = Enum.map(u, &(&1 / norm_u))
-    {v, 2, true}
-  end
-
   ## Matrix (2-D array) manipulation
 
   defp dot_matrix([], _), do: 0
@@ -491,24 +237,6 @@ defmodule Nx.BinaryBackend.Matrix do
     end)
   end
 
-  defp dot_matrix_real(m1, m2) do
-    Enum.map(m1, fn row ->
-      m2
-      |> transpose_matrix()
-      |> Enum.map(fn col ->
-        Enum.zip_reduce(row, col, 0, fn x, y, acc -> acc + x * y end)
-      end)
-    end)
-  end
-
-  defp adjoint_matrix([x | _] = m) when not is_list(x) do
-    Enum.map(m, &[Complex.conjugate(&1)])
-  end
-
-  defp adjoint_matrix(m) do
-    Enum.zip_with(m, fn cols -> Enum.map(cols, &Complex.conjugate/1) end)
-  end
-
   defp transpose_matrix([x | _] = m) when not is_list(x) do
     Enum.map(m, &[&1])
   end
diff --git a/nx/lib/nx/defn.ex b/nx/lib/nx/defn.ex
index e0d1cc1a30..e798fb31ab 100644
--- a/nx/lib/nx/defn.ex
+++ b/nx/lib/nx/defn.ex
@@ -66,7 +66,7 @@ defmodule Nx.Defn do
   ## JIT compilers
 
   The power of `Nx.Defn` is given by its compilers. The default
-  compiler is `Nx.Defn.Evaluator`, which evalutes the code.
+  compiler is `Nx.Defn.Evaluator`, which evaluates the code.
   You can use `jit/3` to compile a function on the fly using a
   different compiler, such as `EXLA`:
 
diff --git a/nx/lib/nx/defn/compiler.ex b/nx/lib/nx/defn/compiler.ex
index b529eae365..809076c748 100644
--- a/nx/lib/nx/defn/compiler.ex
+++ b/nx/lib/nx/defn/compiler.ex
@@ -585,6 +585,13 @@ defmodule Nx.Defn.Compiler do
     {{{:., dot_meta, [Nx, name]}, meta, args}, state}
   end
 
+  # We also allow specifically Complex.new so that literal complex numbers
+  # can be written in defn.
+  defp normalize({{:., dot_meta, [Complex, :new]}, meta, args}, state) do
+    {args, state} = normalize_list(args, state)
+    {{{:., dot_meta, [Complex, :new]}, meta, args}, state}
+  end
+
   defp normalize({{:., dot_meta, [mod, name]}, meta, args}, state)
        when mod in @allowed_modules do
     {args, state} = normalize_list(args, state)
     {{{:., dot_meta, [mod, name]}, meta, args}, state}
diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex
index 997950a1f3..a51d72e677 100644
--- a/nx/lib/nx/defn/expr.ex
+++ b/nx/lib/nx/defn/expr.ex
@@ -1,5 +1,5 @@
 defmodule Nx.Defn.Expr do
-  @doc """
+  @moduledoc """
   The expression used by `Nx.Defn.Compiler`.
 
   `Nx.Defn.Compiler` changes `Nx` default backend from `Nx.BinaryBackend`
@@ -249,9 +249,18 @@ defmodule Nx.Defn.Expr do
     result =
       for expr <- [last | exprs] do
-        expr
-        |> Nx.as_type(type)
-        |> Nx.broadcast(shape, names: names)
+        typed_expr =
+          case expr do
+            %T{data: %Expr{op: :constant}} ->
+              expr
+              |> maybe_upcast_float_constant(type)
+              |> Nx.as_type(type)
+
+            expr ->
+              Nx.as_type(expr, type)
+          end
+
+        Nx.broadcast(typed_expr, shape, names: names)
       end
 
     {result, vectorized_axes}
@@ -1271,7 +1280,7 @@ defmodule Nx.Defn.Expr do
             "value and inline it inside the defn expression. Got: #{inspect(t)}"
   end
 
-  defp to_expr(number) when is_number(number),
+  defp to_expr(number) when is_number(number) or is_struct(number, Complex),
     do: constant(%T{shape: {}, names: [], type: Nx.Type.infer(number)}, number)
 
   defp to_expr(other) do
@@ -1401,6 +1410,10 @@ defmodule Nx.Defn.Expr do
   defp constant(%{shape: shape, type: type} = out, number) do
     number =
       cond do
+        Nx.Type.complex?(type) and
+            (is_number(number) or number in [:infinity, :neg_infinity, :nan]) ->
+          Complex.new(number, 0.0)
+
         is_integer(number) and Nx.Type.float?(type) ->
           Complex.multiply(1.0, number)
 
@@ -1468,16 +1481,42 @@ defmodule Nx.Defn.Expr do
     c1 = maybe_constant(arg1)
     c2 = maybe_constant(arg2)
 
-    if c1 && c2 do
-      apply(Nx.BinaryBackend, op, [
-        %{out | shape: {}, names: []},
-        constant_binary(arg1, c1),
-        constant_binary(arg2, c2)
-      ])
-      |> Nx.to_number()
-      |> then(&constant(out, &1))
+    cond do
+      c1 && c2 ->
+        apply(Nx.BinaryBackend, op, [
+          %{out | shape: {}, names: []},
+          constant_binary(arg1, c1),
+          constant_binary(arg2, c2)
+        ])
+        |> Nx.to_number()
+        |> then(&constant(out, &1))
+
+      c1 ->
+        expr(out, context, op, [maybe_upcast_float_constant(arg1, out.type), arg2])
+
+      c2 ->
+        expr(out, context, op, [arg1, maybe_upcast_float_constant(arg2, out.type)])
+
+      true ->
+        expr(out, context, op, [arg1, arg2])
+    end
+  end
+
+  defp maybe_upcast_float_constant(
+         %T{type: type, data: %Expr{op: :constant, args: [number]}} = t,
+         out_type
+       ) do
+    # By default, Elixir floats are 64 bits, so we're not really upcasting
+    # if out_type is higher precision than what's annotated.
+    # This is just so that downstream code that relies on this type annotation
+    # properly interprets the f64 value as the higher precision type.
+    # This also means that if out_type is lower precision, `number` will be
+    # downcast to the lower precision type.
+
+    if Nx.Type.float?(type) and Nx.Type.float?(out_type) do
+      constant(%{t | type: out_type}, number)
     else
-      expr(out, context, op, [arg1, arg2])
+      t
     end
   end
diff --git a/nx/lib/nx/defn/grad.ex b/nx/lib/nx/defn/grad.ex
index 33a7a0deed..25b4a1178b 100644
--- a/nx/lib/nx/defn/grad.ex
+++ b/nx/lib/nx/defn/grad.ex
@@ -1157,7 +1157,7 @@ defmodule Nx.Defn.Grad do
     num_axes = tuple_size(window_dimensions)
 
-    indices = Nx.reshape(indices_to_flatten, Tuple.append(source.shape, num_axes))
+    indices = Nx.reshape(indices_to_flatten, Nx.Shared.tuple_append(source.shape, num_axes))
 
     dsource = Nx.gather(g, indices)
     dtensor = Nx.broadcast(0, tensor)
@@ -1490,7 +1490,7 @@ defmodule Nx.Defn.Grad do
   end
 
   defp grad_scatter_window__gather_windows(tensor, window_dimensions, strides, padding) do
-    tensor = Nx.pad(tensor, 0, Enum.map(padding, &Tuple.append(&1, 0)))
+    tensor = Nx.pad(tensor, 0, Enum.map(padding, &Nx.Shared.tuple_append(&1, 0)))
 
     shape_l = Tuple.to_list(tensor.shape)
     window_dims_l = Tuple.to_list(window_dimensions)
diff --git a/nx/lib/nx/defn/kernel.ex b/nx/lib/nx/defn/kernel.ex
index 9bfb478237..ab913ab61f 100644
--- a/nx/lib/nx/defn/kernel.ex
+++ b/nx/lib/nx/defn/kernel.ex
@@ -1398,8 +1398,8 @@ defmodule Nx.Defn.Kernel do
   ## Named hooks
 
   It is possible to give names to the hooks. This allows them
-  to be defined or overridden by calling `Nx.Defn.jit/2` or
-  `Nx.Defn.stream/2`. Let's see an example:
+  to be defined or overridden by calling `Nx.Defn.jit/2`.
+  Let's see an example:
 
       defmodule Hooks do
         import Nx.Defn
@@ -1437,9 +1437,8 @@ defmodule Nx.Defn.Kernel do
         {add, mult}
       end
 
-  If a hook with the same name is given to `Nx.Defn.jit/2`
-  or `Nx.Defn.stream/2`, then it will override the default
-  callback.
+  If a hook with the same name is given to `Nx.Defn.jit/2`,
+  then it will override the default callback.
 
   ## Hooks and tokens
 
diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex
index 450a48cae4..edced80246 100644
--- a/nx/lib/nx/lin_alg.ex
+++ b/nx/lib/nx/lin_alg.ex
@@ -1179,8 +1179,8 @@ defmodule Nx.LinAlg do
       #Nx.Tensor<
         f32[2][2]
         [
-          [3.9924824237823486, -1.0052783489227295],
-          [-3.0051186084747314, 1.0071179866790771]
+          [4.000002861022949, -1.0000008344650269],
+          [-3.000002384185791, 1.0000005960464478]
         ]
       >
 
@@ -1275,14 +1275,14 @@ defmodule Nx.LinAlg do
       iex> Nx.round(eigenvals)
       #Nx.Tensor<
         f32[2]
-        [1.0, 2.0]
+        [2.0, 1.0]
       >
       iex> eigenvecs
       #Nx.Tensor<
         f32[2][2]
         [
-          [1.0, 0.0],
-          [0.0, 1.0]
+          [0.0, 1.0],
+          [1.0, 0.0]
         ]
       >
 
@@ -1296,9 +1296,9 @@ defmodule Nx.LinAlg do
       #Nx.Tensor<
         f32[3][3]
         [
-          [0.4075949788093567, 0.9131628274917603, 0.0],
-          [0.40837883949279785, -0.18228201568126678, 0.8944271802902222],
-          [0.8167576789855957, -0.36456403136253357, -0.4472135901451111]
+          [0.40824827551841736, -0.18257419764995575, 0.8944271802902222],
+          [0.40824833512306213, 0.9128708839416504, 0.0],
+          [0.8164965510368347, -0.3651483952999115, -0.4472135901451111]
         ]
       >
 
@@ -1308,7 +1308,7 @@ defmodule Nx.LinAlg do
         f32[2][2]
         [
           [9.0, -1.0],
-          [1.0, 4.0]
+          [4.0, 1.0]
         ]
       >
       iex> eigenvecs
@@ -1316,12 +1316,12 @@ defmodule Nx.LinAlg do
         f32[2][2][2]
         [
           [
-            [0.5612090229988098, -0.8276740908622742],
-            [0.8276740908622742, 0.5612090229988098]
+            [0.5606288313865662, 0.8280671834945679],
+            [0.8280671834945679, -0.5606288313865662]
           ],
           [
-            [1.0, 0.0],
-            [0.0, 1.0]
+            [0.0, 1.0],
+            [1.0, 0.0]
           ]
         ]
       >
 
@@ -1334,7 +1334,7 @@ defmodule Nx.LinAlg do
         f32[2]
         [
           [9.0, -1.0],
-          [1.0, 4.0]
+          [4.0, 1.0]
         ]
       >
       iex> eigenvecs
@@ -1343,12 +1343,12 @@ defmodule Nx.LinAlg do
         f32[2][2]
         [
           [
-            [0.5612090229988098, -0.8276740908622742],
-            [0.8276740908622742, 0.5612090229988098]
+            [0.5606288313865662, 0.8280671834945679],
+            [0.8280671834945679, -0.5606288313865662]
           ],
           [
-            [1.0, 0.0],
-            [0.0, 1.0]
+            [0.0, 1.0],
+            [1.0, 0.0]
           ]
         ]
       >
 
@@ -1376,7 +1376,7 @@ defmodule Nx.LinAlg do
        %{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}}
 
     :eigh
-    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.Eigh.eigh/2)
+    |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.BlockEigh.eigh/2)
     |> Nx.vectorize(vectorized_axes)
   end
 
@@ -2161,19 +2161,19 @@ defmodule Nx.LinAlg do
       iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2]))
       #Nx.Tensor<
         f32[2]
-        [0.9977624416351318, 0.0011188983917236328]
+        [1.0000028610229492, -2.384185791015625e-6]
       >
 
      iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1]))
      #Nx.Tensor<
        f32[2]
-        [0.9966151118278503, -0.947966456413269]
+        [0.9999998211860657, -0.9500012993812561]
      >
 
      iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2, 3], [4, 5, 6]]), Nx.tensor([1, 2]))
      #Nx.Tensor<
        f32[3]
-        [-0.05534052848815918, 0.1111316829919815, 0.27760395407676697]
+        [-0.05555540323257446, 0.1111111044883728, 0.27777770161628723]
      >
 
   ## Error cases
diff --git a/nx/lib/nx/lin_alg/block_eigh.ex b/nx/lib/nx/lin_alg/block_eigh.ex
new file mode 100644
index 0000000000..1c6f7986d9
--- /dev/null
+++ b/nx/lib/nx/lin_alg/block_eigh.ex
@@ -0,0 +1,304 @@
+defmodule Nx.LinAlg.BlockEigh do
+  @moduledoc """
+  Parallel Jacobi symmetric eigendecomposition.
+
+  Reference implementation taken from XLA's eigh_expander
+  which is built on the approach in:
+  Brent, R. P., & Luk, F. T. (1985). The solution of singular-value
+  and symmetric eigenvalue problems on multiprocessor arrays.
+  SIAM Journal on Computing, 6(1), 69-84. https://doi.org/10.1137/0906007
+  """
+  require Nx
+
+  import Nx.Defn
+
+  defn eigh(matrix, opts \\ []) do
+    opts = keyword!(opts, eps: 1.0e-6, max_iter: 100)
+
+    matrix
+    |> Nx.revectorize([collapsed_axes: :auto],
+      target_shape: {Nx.axis_size(matrix, -2), Nx.axis_size(matrix, -1)}
+    )
+    |> decompose(opts)
+    |> revectorize_result(matrix)
+  end
+
+  defnp decompose(matrix, opts) do
+    {n, _} = Nx.shape(matrix)
+
+    if n > 1 do
+      m_decompose(matrix, opts)
+    else
+      {Nx.take_diagonal(Nx.real(matrix)), Nx.tensor([1], type: matrix.type)}
+    end
+  end
+
+  defnp m_decompose(matrix, opts) do
+    eps = opts[:eps]
+    max_iter = opts[:max_iter]
+
+    type = Nx.Type.to_floating(Nx.type(matrix))
+    matrix = Nx.as_type(matrix, type)
+    {n, _} = Nx.shape(matrix)
+    i_n = n - 1
+    mid = calculate_mid(i_n)
+    i_mid = mid - 1
+
+    tl = matrix[[0..i_mid, 0..i_mid]]
+    tr = matrix[[0..i_mid, mid..i_n]]
+    bl = matrix[[mid..i_n, 0..i_mid]]
+    br = matrix[[mid..i_n, mid..i_n]]
+
+    # Pad if not even
+    {tr, bl, br} =
+      if Nx.remainder(n, 2) == 1 do
+        tr = Nx.pad(tr, 0, [{0, 0, 0}, {0, 1, 0}])
+        bl = Nx.pad(bl, 0, [{0, 1, 0}, {0, 0, 0}])
+        br = Nx.pad(br, 0, [{0, 1, 0}, {0, 1, 0}])
+        {tr, bl, br}
+      else
+        {tr, bl, br}
+      end
+
+    # Initialize tensors to hold eigenvectors
+    v_tl = v_br = Nx.eye(mid, type: type)
+    v_tr = v_bl = Nx.broadcast(Nx.tensor(0, type: type), {mid, mid})
+
+    {frob_norm, off_norm} = norms(tl, tr, bl, br)
+
+    # Nested loop
+    # The outer loop performs the "sweep" operation until the norms converge
+    # or max iterations are hit. The Brent/Luk paper states that Log2(n) is
+    # a good estimate for convergence, but XLA chose a static number which wouldn't
+    # be reached until a matrix roughly greater than 20kx20k.
+    #
+    # The inner loop performs n - 1 "sweep" rounds, which is enough permutations to allow
+    # all sub matrices to share the needed values.
+    {{tl, br, v_tl, v_tr, v_bl, v_br}, _} =
+      while {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i = 0}},
+            off_norm > eps ** 2 * frob_norm and i < max_iter do
+        {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br} =
+          perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n)
+
+        {frob_norm, off_norm} = norms(tl, tr, bl, br)
+
+        {{tl, br, v_tl, v_tr, v_bl, v_br}, {frob_norm, off_norm, tr, bl, i + 1}}
+      end
+
+    # Recombine
+    w = Nx.concatenate([Nx.take_diagonal(tl), Nx.take_diagonal(br)])
+
+    v =
+      Nx.concatenate([
+        Nx.concatenate([v_tl, v_tr], axis: 1),
+        Nx.concatenate([v_bl, v_br], axis: 1)
+      ])
+      |> Nx.LinAlg.adjoint()
+
+    # trim padding
+    {w, v} =
+      if Nx.remainder(n, 2) == 1 do
+        {w[0..i_n], v[[0..i_n, 0..i_n]]}
+      else
+        {w, v}
+      end
+
+    sort_ind = Nx.argsort(Nx.abs(w), direction: :desc)
+
+    w = Nx.take(w, sort_ind) |> approximate_zeros(eps)
+    v = Nx.take(v, sort_ind, axis: 1) |> approximate_zeros(eps)
+
+    {w, v}
+  end
+
+  deftransformp calculate_mid(i_n) do
+    Range.size(0..i_n//2)
+  end
+
+  defnp calc_rot(tl, tr, br) do
+    complex? = tl |> Nx.type() |> Nx.Type.complex?()
+    br = Nx.take_diagonal(br) |> Nx.real()
+    tr = Nx.take_diagonal(tr)
+    tl = Nx.take_diagonal(tl) |> Nx.real()
+
+    {tr, w} =
+      if complex? do
+        abs_tr = Nx.abs(tr)
+        {abs_tr, Nx.select(abs_tr == 0, 1, Nx.conjugate(tr) / abs_tr)}
+      else
+        {tr, 1}
+      end
+
+    z_tr = Nx.equal(tr, 0)
+    s_tr = Nx.select(z_tr, 1, tr)
+    tau = Nx.select(z_tr, 0, (br - tl) / (2 * s_tr))
+
+    t = Nx.sqrt(1 + tau ** 2)
+
+    t = 1 / (tau + Nx.select(tau >= 0, t, -t))
+
+    pred = Nx.abs(tr) <= 1.0e-5 * Nx.min(Nx.abs(br), Nx.abs(tl))
+    t = Nx.select(pred, Nx.tensor(0, type: tl.type), t)
+
+    c = 1.0 / Nx.sqrt(1.0 + t ** 2)
+    s = if complex?, do: Nx.complex(t * c, 0) * w, else: t * c
+
+    rt1 = tl - t * tr
+    rt2 = br + t * tr
+    {rt1, rt2, c, s}
+  end
+
+  defnp sq_norm(tl, tr, bl, br) do
+    Nx.sum(Nx.abs(tl) ** 2 + Nx.abs(tr) ** 2 + Nx.abs(bl) ** 2 + Nx.abs(br) ** 2)
+  end
+
+  defnp off_norm(tl, tr, bl, br) do
+    {n, _} = Nx.shape(tl)
+    diag = Nx.broadcast(0, {n})
+    o_tl = Nx.put_diagonal(tl, diag)
+    o_br = Nx.put_diagonal(br, diag)
+
+    sq_norm(o_tl, tr, bl, o_br)
+  end
+
+  # Calculates the Frobenius norm and the norm of the off-diagonals from
+  # the submatrices. Used to calculate convergence.
+  defnp norms(tl, tr, bl, br) do
+    frob = sq_norm(tl, tr, bl, br)
+    off = off_norm(tl, tr, bl, br)
+
+    {frob, off}
+  end
+
+  deftransformp revectorize_result({eigenvals, eigenvecs}, a) do
+    shape = Nx.shape(a)
+
+    {
+      Nx.revectorize(eigenvals, a.vectorized_axes,
+        target_shape: Tuple.delete_at(shape, tuple_size(shape) - 1)
+      ),
+      Nx.revectorize(eigenvecs, a.vectorized_axes, target_shape: shape)
+    }
+  end
+
+  defnp perform_sweeps(tl, tr, bl, br, v_tl, v_tr, v_bl, v_br, mid, i_n) do
+    while {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br}, _n <- 0..i_n do
+      {rt1, rt2, c, s} = calc_rot(tl, tr, br)
+      # build row and column vectors for parallelized rotations
+      c_v = Nx.new_axis(c, 1)
+      s_v = Nx.new_axis(s, 1)
+      c_h = Nx.new_axis(c, 0)
+      s_h = Nx.new_axis(s, 0)
+
+      s_v_conj =
+        if Nx.type(s) |> Nx.Type.complex?() do
+          Nx.conjugate(s_v)
+        else
+          s_v
+        end
+
+      s_h_conj = Nx.transpose(s_v_conj)
+
+      # Each rotation group below is performed based on the same
+      # tl, bl, tr, br values, so we must do single-expr
+      # assignments (i.e. {tl, tr, bl, br} = ...)
+
+      # Rotate rows
+      {tl, tr, bl, br} = {
+        tl * c_v - bl * s_v_conj,
+        tr * c_v - br * s_v_conj,
+        tl * s_v + bl * c_v,
+        tr * s_v + br * c_v
+      }
+
+      # Rotate cols
+      {tl, tr, bl, br} = {
+        tl * c_h - tr * s_h,
+        tl * s_h_conj + tr * c_h,
+        bl * c_h - br * s_h,
+        bl * s_h_conj + br * c_h
+      }
+
+      # Store results and permute values across sub matrices
+      zero_diag = Nx.broadcast(0, {mid})
+      tl = Nx.put_diagonal(tl, rt1)
+      tr = Nx.put_diagonal(tr, zero_diag)
+      bl = Nx.put_diagonal(bl, zero_diag)
+      br = Nx.put_diagonal(br, rt2)
+
+      {tl, tr} = permute_cols_in_row(tl, tr)
+      {bl, br} = permute_cols_in_row(bl, br)
+      {tl, bl} = permute_rows_in_col(tl, bl)
+      {tr, br} = permute_rows_in_col(tr, br)
+
+      # Rotate to calc vectors
+      {v_tl, v_tr, v_bl, v_br} = {
+        v_tl * c_v - v_bl * s_v_conj,
+        v_tr * c_v - v_br * s_v_conj,
+        v_tl * s_v + v_bl * c_v,
+        v_tr * s_v + v_br * c_v
+      }
+
+      # permute for vectors
+      {v_tl, v_bl} = permute_rows_in_col(v_tl, v_bl)
+      {v_tr, v_br} = permute_rows_in_col(v_tr, v_br)
+
+      {tl, tr, bl, br, v_tl, v_tr, v_bl, v_br}
+    end
+  end
+
+  defnp approximate_zeros(matrix, eps), do: Nx.select(Nx.abs(matrix) <= eps, 0, matrix)
+
+  # https://github.com/openxla/xla/blob/main/xla/hlo/transforms/expanders/eigh_expander.cc#L200-L239
+  defnp permute_rows_in_col(top, bottom) do
+    {k, _} = Nx.shape(top)
+
+    {top_out, bottom_out} =
+      cond do
+        k == 2 ->
+          {Nx.concatenate([top[0..0], bottom[0..0]], axis: 0),
+           Nx.concatenate(
+             [
+               bottom[1..-1//1],
+               top[(k - 1)..(k - 1)]
+             ],
+             axis: 0
+           )}
+
+        k == 1 ->
+          {top, bottom}
+
+        true ->
+          {Nx.concatenate([top[0..0], bottom[0..0], top[1..(k - 2)]], axis: 0),
+           Nx.concatenate(
+             [
+               bottom[1..-1//1],
+               top[(k - 1)..(k - 1)]
+             ],
+             axis: 0
+           )}
+      end
+
+    {top_out, bottom_out}
+  end
+
+  defnp permute_cols_in_row(left, right) do
+    {k, _} = Nx.shape(left)
+
+    {left_out, right_out} =
+      cond do
+        k == 2 ->
+          {Nx.concatenate([left[[.., 0..0]], right[[.., 0..0]]], axis: 1),
+           Nx.concatenate([right[[.., 1..(k - 1)]], left[[.., (k - 1)..(k - 1)]]], axis: 1)}
+
+        k == 1 ->
+          {left, right}
+
+        true ->
+          {Nx.concatenate([left[[.., 0..0]], right[[.., 0..0]], left[[.., 1..(k - 2)]]], axis: 1),
+           Nx.concatenate([right[[.., 1..(k - 1)]], left[[.., (k - 1)..(k - 1)]]], axis: 1)}
+      end
+
+    {left_out, right_out}
+  end
+end
diff --git a/nx/lib/nx/lin_alg/svd.ex b/nx/lib/nx/lin_alg/svd.ex
index 338720d0db..7e1cad84f0 100644
--- a/nx/lib/nx/lin_alg/svd.ex
+++ b/nx/lib/nx/lin_alg/svd.ex
@@ -59,9 +59,9 @@ defmodule Nx.LinAlg.SVD do
     collapsed_axes = shape |> Tuple.delete_at(rank - 2) |> Tuple.delete_at(rank - 2)
 
-    u_shape = collapsed_axes |> Tuple.append(m) |> Tuple.append(:auto)
-    s_shape = Tuple.append(collapsed_axes, :auto)
-    vt_shape = Tuple.append(s_shape, n)
+    u_shape = collapsed_axes |> Nx.Shared.tuple_append(m) |> Nx.Shared.tuple_append(:auto)
+    s_shape = Nx.Shared.tuple_append(collapsed_axes, :auto)
+    vt_shape = Nx.Shared.tuple_append(s_shape, n)
 
     {{m, n}, u_shape, s_shape, vt_shape}
   end
diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex
index 5b429517cb..2523c5114d 100644
--- a/nx/lib/nx/random.ex
+++ b/nx/lib/nx/random.ex
@@ -746,7 +746,7 @@ defmodule Nx.Random do
         {dim}
 
       dims when is_tuple(dims) ->
-        Tuple.append(dims, dim)
+        Nx.Shared.tuple_append(dims, dim)
 
       _ ->
         raise ArgumentError,
@@ -1126,7 +1126,7 @@ defmodule Nx.Random do
     case type do
       {:c, _} ->
         type = Nx.Type.to_real(type)
-        data = fun.(key, type, Tuple.append(shape, 2))
+        data = fun.(key, type, Nx.Shared.tuple_append(shape, 2))
         to_complex = Nx.stack([1, Nx.Constants.i()])
         Nx.dot(data, to_complex)
diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex
index acb0523448..ab0a524bd7 100644
--- a/nx/lib/nx/serving.ex
+++ b/nx/lib/nx/serving.ex
@@ -1117,7 +1117,7 @@ defmodule Nx.Serving do
   end
 
   defp distributed_batched_run_with_retries!(name, input, retries) do
-    case :pg.get_members(Nx.Serving.PG, __MODULE__) do
+    case :pg.get_members(Nx.Serving.PG, name) do
       [] ->
         exit({:noproc, {__MODULE__, :distributed_batched_run, [name, input, [retries: retries]]}})
 
@@ -1332,7 +1332,7 @@ defmodule Nx.Serving do
       )
 
     serving_weight = max(1, weight * partitions_count)
-    :pg.join(Nx.Serving.PG, __MODULE__, List.duplicate(self(), serving_weight))
+    :pg.join(Nx.Serving.PG, name, List.duplicate(self(), serving_weight))
 
     for batch_key <- batch_keys do
       stack_init(batch_key)
diff --git a/nx/lib/nx/shape.ex b/nx/lib/nx/shape.ex
index 61d7eeb941..3992822f6d 100644
--- a/nx/lib/nx/shape.ex
+++ b/nx/lib/nx/shape.ex
@@ -1739,8 +1739,6 @@ defmodule Nx.Shape do
     end)
   end
 
-  defp assert_non_concat_dims_equal([], _axis), do: :ok
-
   defp assert_non_concat_dims_equal([s1 | shapes], axis) do
     s1_size = tuple_size(s1)
     template = Tuple.delete_at(s1, axis)
diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex
index 392e7416d3..072f7db792 100644
--- a/nx/lib/nx/shared.ex
+++ b/nx/lib/nx/shared.ex
@@ -6,6 +6,13 @@ defmodule Nx.Shared do
 
   ## Type macros
 
+  defmacro generated_case(expr, do: clauses) do
+    clauses =
+      Enum.map(clauses, fn {:->, meta, args} -> {:->, [generated: true] ++ meta, args} end)
+
+    {:case, [generated: true], [expr, [do: clauses]]}
+  end
+
   @doc """
   Match the cartesian product of all given types.
 
@@ -438,6 +445,13 @@ defmodule Nx.Shared do
 
   ## Helpers
 
+  @doc """
+  Appends an element to a tuple.
+  """
+  def tuple_append(tuple, elem) do
+    Tuple.insert_at(tuple, tuple_size(tuple), elem)
+  end
+
   @doc """
   Extracts the backend from the given options.
   """
diff --git a/nx/lib/nx/type.ex b/nx/lib/nx/type.ex
index 2097693d6a..c529d09c31 100644
--- a/nx/lib/nx/type.ex
+++ b/nx/lib/nx/type.ex
@@ -373,7 +373,9 @@ defmodule Nx.Type do
   bits. Otherwise it casts to f64.
 
   In the case of complex numbers, the maximum bit size is 128 bits
-  because they are composed of two floats.
+  because they are composed of two floats. Float types are promoted
+  to c64 by default, with the exception of f64, which is promoted to
+  c128 so that a single component can represent an f64 number properly.
 
   ## Examples
 
@@ -429,8 +431,15 @@ defmodule Nx.Type do
       iex> Nx.Type.merge({:f, 64}, {:bf, 16})
       {:f, 64}
 
+      iex> Nx.Type.merge({:f, 16}, {:c, 64})
+      {:c, 64}
+      iex> Nx.Type.merge({:f, 32}, {:c, 64})
+      {:c, 64}
+      iex> Nx.Type.merge({:f, 64}, {:c, 64})
+      {:c, 128}
       iex> Nx.Type.merge({:c, 64}, {:f, 32})
       {:c, 64}
+      iex> Nx.Type.merge({:c, 64}, {:c, 64})
       {:c, 64}
       iex> Nx.Type.merge({:c, 128}, {:c, 64})
 
@@ -443,6 +452,7 @@ defmodule Nx.Type do
   def merge(left, right) do
     case sort(left, right) do
       {{:u, size1}, {:s, size2}} -> {:s, max(min(size1 * 2, 64), size2)}
+      {{:f, size1}, {:c, size2}} -> {:c, max(size1 * 2, size2)}
       {_, type2} -> type2
     end
   end
diff --git a/nx/mix.exs b/nx/mix.exs
index effc66eb4f..a43972cf17 100644
--- a/nx/mix.exs
+++ b/nx/mix.exs
@@ -35,7 +35,7 @@ defmodule Nx.MixProject do
   defp deps do
     [
-      {:complex, "~> 0.5"},
+      {:complex, "~> 0.6"},
       {:telemetry, "~> 0.4.0 or ~> 1.0"},
       {:ex_doc, "~> 0.29", only: :docs}
     ]
diff --git a/nx/mix.lock b/nx/mix.lock
index fab455ebe4..2d7fd485c4 100644
--- a/nx/mix.lock
+++ b/nx/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
+  "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
  "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
  "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
  "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"},
diff --git a/nx/test/nx/complex_test.exs b/nx/test/nx/complex_test.exs
index a7ea35c47d..9b3bcf1196 100644
--- a/nx/test/nx/complex_test.exs
+++ b/nx/test/nx/complex_test.exs
@@ -231,6 +231,60 @@ defmodule Nx.ComplexTest do
         end
       end
     end
+
+    test "equal" do
+      one_r = 1
+      one_u8 = Nx.tensor(1, type: {:u, 8})
+      zero_r = 0
+      zero_u8 = Nx.tensor(0, type: {:u, 8})
+      one_c = Complex.new(1, 0)
+      zero_c = Complex.new(0, 0)
+
+      assert Nx.equal(one_r, one_r) == one_u8
+      assert Nx.equal(one_r, zero_r) == zero_u8
+
+      assert Nx.equal(one_c, one_c) == one_u8
+      assert Nx.equal(one_c, zero_c) == zero_u8
+
+      assert Nx.equal(one_r, one_c) == one_u8
+      assert Nx.equal(zero_r, one_c) == zero_u8
+
+      assert Nx.equal(one_c, one_r) == one_u8
+      assert Nx.equal(one_c, zero_r) == zero_u8
+
+      assert Nx.equal(:nan, one_r) == zero_u8
+      assert Nx.equal(:nan, one_c) == zero_u8
+
+      assert Nx.equal(one_r, :nan) == zero_u8
+      assert Nx.equal(one_c, :nan) == zero_u8
+    end
+
+    test "not_equal" do
+      one_r = 1
+      one_u8 = Nx.tensor(1, type: {:u, 8})
+      zero_r = 0
+      zero_u8 = Nx.tensor(0, type: {:u, 8})
+      one_c = Complex.new(1, 0)
+      zero_c = Complex.new(0, 0)
+
+      assert Nx.not_equal(one_r, one_r) == zero_u8
+      assert Nx.not_equal(one_r, zero_r) == one_u8
+
+      assert Nx.not_equal(one_c, one_c) == zero_u8
+      assert Nx.not_equal(one_c, zero_c) == one_u8
+
+      assert Nx.not_equal(one_r, one_c) == zero_u8
+      assert Nx.not_equal(zero_r, one_c) == one_u8
+
+      assert Nx.not_equal(one_c, one_r) == zero_u8
+      assert Nx.not_equal(one_c, zero_r) == one_u8
+
+      assert Nx.not_equal(:nan, one_r) == one_u8
+      assert Nx.not_equal(:nan, one_c) == one_u8
+
+      assert Nx.not_equal(one_r, :nan) == one_u8
+      assert Nx.not_equal(one_c, :nan) == one_u8
+    end
   end
 
   describe "LinAlg not yet implemented" do
diff --git a/nx/test/nx/defn/expr_test.exs b/nx/test/nx/defn/expr_test.exs
index e901c83392..e579b39211 100644
--- a/nx/test/nx/defn/expr_test.exs
+++ b/nx/test/nx/defn/expr_test.exs
@@ -198,6 +198,29 @@ defmodule Nx.Defn.ExprTest do
              c = metadata b, :stop_grad s32[1]
              """
     end
+
+    test "upcast float constants when operating against higher precision types" do
+      t_f32 = Nx.tensor([2, 2], type: :f32) |> Expr.tensor()
+      c_f64 = Expr.constant(Nx.tensor(0.7, type: :f64), 0.7, [])
+
+      assert %T{type: {:f, 64}, data: %Expr{op: :multiply, args: [^c_f64, ^t_f32]}} =
+               Nx.multiply(t_f32, c_f64)
+
+      t_f64 = Nx.tensor([2, 2], type: :f64) |> Expr.tensor()
+      c_f32 = Expr.constant(Nx.tensor(0.7, type: :f32), 0.7, [])
+
+      assert %T{type: {:f, 64}, data: %Expr{op: :multiply, args: [^c_f64, ^t_f64]}} =
+               Nx.multiply(t_f64, c_f32)
+
+      c_c64 = Expr.constant(Nx.tensor(0.7, type: :c64), 0.7, [])
+      c_c128 = Expr.constant(Nx.tensor(0.7, type: :c128), 0.7, [])
+
+      assert %T{type: {:c, 64}, data: %Expr{op: :multiply, args: [^c_c64, ^t_f32]}} =
+               Nx.multiply(t_f32, c_c64)
+
+      assert %T{type: {:c, 128}, data: %Expr{op: :multiply, args: [^c_c128, ^t_f64]}} =
+               Nx.multiply(t_f64, c_c64)
+    end
   end
 
   describe "inspect" do
diff --git a/nx/test/nx/defn/grad_test.exs b/nx/test/nx/defn/grad_test.exs
index 6ab7e8ef78..bf4ba60003 100644
--- a/nx/test/nx/defn/grad_test.exs
+++ b/nx/test/nx/defn/grad_test.exs
@@ -698,7 +698,10 @@ defmodule Nx.Defn.GradTest do
       lhs = grad_mean_conv_y_general_stride_rhs_dilated(x, y)
 
       rhs =
-        Nx.tensor([[[[7.4000006, 8.2], [7.4000006, 8.2]]], [[[7.4000006, 8.2], [7.4000006, 8.2]]]])
+        Nx.tensor([
+          [[[7.4000006, 8.2], [7.4000006, 8.2]]],
+          [[[7.4000006, 8.2], [7.4000006, 8.2]]]
+        ])
 
       assert_all_close(lhs, rhs)
     end
@@ -1981,8 +1984,8 @@ defmodule Nx.Defn.GradTest do
       assert_all_close(
         svd_grad(Nx.tensor([[3, 0], [1, 2]])),
         Nx.tensor([
-          [0.07228553295135498, 0.7500489950180054],
-          [1.113668441772461, 1.8945982456207275]
+          [1.368404507637024, -0.5419228672981262],
+          [-0.2197188436985016, 0.6067624092102051]
         ])
       )
     end
@@ -1991,8 +1994,8 @@ defmodule Nx.Defn.GradTest do
       assert_all_close(
         svd_composed_grad(Nx.tensor([[3, 0], [1, 2]])),
         Nx.tensor([
-          [22.44730567932129, 4.334394931793213],
-          [10.295409202575684, 9.27196216583252]
+          [22.86724090576172, 3.655829906463623],
+          [10.035255432128906, 8.769235610961914]
         ])
       )
     end
@@ -2001,9 +2004,9 @@ defmodule Nx.Defn.GradTest do
       assert_all_close(
         svd_composed_grad(Nx.tensor([[3, 0], [1, 2], [1, 1]])),
         Nx.tensor([
-          [25.990453720092773, 6.061026096343994],
-          [12.646490097045898, 10.775838851928711],
-          [10.656349182128906, 6.384178638458252]
+          [25.911056518554688, 6.1099162101745605],
+          [12.69705581665039, 10.84456729888916],
+          [10.668402671813965, 6.426826477050781]
         ])
       )
     end
diff --git a/nx/test/nx/defn_test.exs b/nx/test/nx/defn_test.exs
index dbb560839b..90fe9263dd 100644
--- a/nx/test/nx/defn_test.exs
+++ b/nx/test/nx/defn_test.exs
@@ -25,6 +25,10 @@ defmodule Nx.DefnTest do
     @tensor [1, 2, 3]
     defn(list_constant, do: Nx.tensor(@tensor))
 
+    defn complex_constant do
+      Complex.new(1, :infinity)
+    end
+
     test "from list" do
       assert %T{data: %Expr{op: :tensor}} = list_constant()
     end
@@ -35,6 +39,11 @@ defmodule Nx.DefnTest do
     test "from binary" do
       assert %T{data: %Expr{op: :tensor}} = binary_constant()
     end
+
+    test "complex literals" do
+      assert %T{data: %Expr{op: :constant, args: [%Complex{} = c]}} = complex_constant()
+      assert c == Complex.new(1, :infinity)
+    end
   end
 
   describe "Nx.tensor" do
@@ -1176,6 +1185,48 @@ defmodule Nx.DefnTest do
       )
     end
 
+    defn cond_upcast_float_literals(n) do
+      cond do
+        n == 1 -> 1.4
+        n == 2 -> 2
+        true -> n
+      end
+    end
+
+    test "upcasts float literals based on the accumulated clause type" do
+      for input_type <- [f: 32, f: 64] do
+        assert %T{
+                 type: ^input_type,
+                 data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
+               } =
+                 cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))
+
+        assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [1.4]}}} = clause1
+        assert {_, %T{type: ^input_type, data: %Expr{op: :constant, args: [2.0]}}} = clause2
+      end
+
+      for input_type <- [c: 64, c: 128] do
+        assert %T{
+                 type: ^input_type,
+                 data: %Expr{op: :cond, args: [[clause1, clause2], _last]}
+               } =
+                 cond_upcast_float_literals(Nx.tensor(10.0, type: input_type))
+
+        assert {_,
+                %T{
+                  type: ^input_type,
+                  data: %Expr{op: :constant, args: [%Complex{re: 1.4, im: +0.0}]}
+                }} = clause1
+
+        assert {_,
+                %T{
+                  type: ^input_type,
+                  data: %Expr{op: :constant, args: [%Complex{re: 2.0, im: +0.0}]}
+                }} =
+                 clause2
+      end
+    end
+
     defn cond_list(a) do
       if Nx.any(a), do: 1, else: -1
     end
@@ -1995,7 +2046,7 @@ defmodule Nx.DefnTest do
     defn while_mixed_return(a, b) do
       while {a, b}, Nx.less(a, 10) do
-        %{a: a, b: b}
+        %{"a" => a, "b" => b}
       end
     end
 
@@ -2003,7 +2054,7 @@ defmodule Nx.DefnTest do
       expected_error =
         [
           "the do-block in while must return tensors with the same shape, type, and names as the initial arguments.",
-          "\n\n\e\\[32m\n<<<<< Body \\(do-block\\) <<<<<\n%\\{a: #Nx.Tensor<\n s32\n >, b: #Nx.Tensor<\n s32\n >\\}",
+          "\n\n\e\\[32m\n<<<<< Body \\(do-block\\) <<<<<\n%\\{\"a\" => #Nx.Tensor<\n s32\n >, \"b\" => #Nx.Tensor<\n s32\n >\\}",
          "\n==========\n\e\\[31m\\{#Nx.Tensor<\n s32\n >, #Nx.Tensor<\n s32\n >\\}\n>>>>> Initial >>>>>\n\e\\[0m\n$"
        ]
        |> IO.iodata_to_binary()
diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs
index 36a48159e1..d8c8fe2bb4 100644
--- a/nx/test/nx/lin_alg_test.exs
+++ b/nx/test/nx/lin_alg_test.exs
@@ -574,11 +574,11 @@ defmodule Nx.LinAlgTest do
       assert_all_close(
         eigenvecs,
         Nx.tensor([
-          [0.112, -0.005, -0.831, -0.436, -0.328],
-          [0.395, 0.163, 0.530, -0.537, -0.497],
-          [0.427, 0.326, -0.133, 0.700, -0.452],
-          [0.603, -0.783, -0.007, 0.079, 0.130],
-          [0.534, 0.504, -0.104, -0.160, 0.651]
+          [0.112, 0.004, 0.828, -0.440, -0.328],
+          [0.395, -0.163, -0.533, -0.534, -0.497],
+          [0.427, -0.326, 0.137, 0.700, -0.452],
+          [0.603, 0.783, 0.008, 0.079, 0.130],
+          [0.534, -0.504, 0.103, -0.160, 0.651]
         ]),
         atol: 1.0e-3,
         rtol: 1.0e-3
@@ -600,28 +600,20 @@ defmodule Nx.LinAlgTest do
       # Eigenvalues
       assert eigenvals ==
-               Nx.tensor([Complex.new(-5, 0), Complex.new(3, 0), Complex.new(1, 0)])
+               Nx.tensor([
+                 Complex.new(-5, 0),
+                 Complex.new(3, 0),
+                 Complex.new(0.9999998807907104, 0)
+               ])
 
       # Eigenvectors
       assert_all_close(
         eigenvecs,
-        Nx.tensor([
-          [
-            Complex.new(-0.408, 0.0),
-            Complex.new(-0.0, 0.707),
-            Complex.new(0.577, 0.0)
-          ],
-          [
-            Complex.new(-0.0, -0.816),
-            Complex.new(0.0, 0.0),
-            Complex.new(0.0, -0.577)
-          ],
-          [
-            Complex.new(0.408, 0.0),
-            Complex.new(-0.0, 0.707),
-            Complex.new(-0.577, 0.0)
-          ]
-        ]),
+        ~MAT[
+          0.0000-0.4082i 0.7071-0.0i 00.5773-0.0000i
+          0.8164-0.0000i 0.0000+0.0i 00.0000-0.5773i
+          0.0000+0.4082i 0.7071-0.0i -0.5773-0.0000i
+        ],
         atol: 1.0e-3,
         rtol: 1.0e-3
       )
@@ -638,42 +630,56 @@ defmodule Nx.LinAlgTest do
       for type <- [f: 32, c: 64], reduce: key do
         key ->
           # Unitary matrix from a random matrix
-          {base, key} = Nx.Random.uniform(key, shape: {3, 3, 3}, type: type)
+          {base, key} = Nx.Random.uniform(key, shape: {2, 3, 3}, type: type)
           {q, _} = Nx.LinAlg.qr(base)
 
           # Different eigenvalues from random values
           evals_test =
-            [{100, 30}, {4, 6}, {0.7, 0.9}]
-            |> Enum.map(fn {low, up} ->
-              if :rand.uniform() - 0.5 > 0 do
-                {low, up}
-              else
-                {-up, -low}
-              end
-            end)
-            |> Enum.map(fn {low, up} ->
-              rand = :rand.uniform() * (up - low) + low
-              Nx.tensor([rand], type: :f64)
+            [100, 10, 1]
+            |> Enum.map(fn magnitude ->
+              sign =
+                if :rand.uniform() - 0.5 > 0 do
+                  1
+                else
+                  -1
+                end
+
+              rand = :rand.uniform() * magnitude * 0.1 + magnitude
+              rand * sign
             end)
-            |> Nx.concatenate()
+            |> Nx.tensor(type: type)
+
+          evals_test_diag =
+            evals_test
+            |> Nx.make_diagonal()
+            |> Nx.reshape({1, 3, 3})
+            |> Nx.tile([2, 1, 1])
 
           # Hermitian matrix with different eigenvalues
           # using A = A^* = Q^*.Λ.Q.
           a =
             q
             |> Nx.LinAlg.adjoint()
-            |> Nx.multiply(evals_test)
+            |> Nx.dot([2], [0], evals_test_diag, [1], [0])
             |> Nx.dot([2], [0], q, [1], [0])
 
           # Eigenvalues and eigenvectors
-          assert {evals, evecs} = Nx.LinAlg.eigh(a, max_iter: 10_000)
-          assert_all_close(evals_test, evals, atol: 1.0e-1)
+          assert {evals, evecs} = Nx.LinAlg.eigh(a, eps: 1.0e-8)
+
+          assert_all_close(evals_test, evals[0], atol: 1.0e-8)
+          assert_all_close(evals_test, evals[1], atol: 1.0e-8)
+
+          evals =
+            evals
+            |> Nx.vectorize(:x)
+            |> Nx.make_diagonal()
+            |> Nx.devectorize(keep_names: false)
 
           # Eigenvalue equation
-          evecs_evals = Nx.multiply(evecs, evals)
-          a_evecs = Nx.dot(a, [2], [0], evecs, [1], [0])
+          evecs_evals = Nx.dot(evecs, [2], [0], evals, [1], [0])
+          a_evecs = Nx.dot(evecs_evals, [2], [0], Nx.LinAlg.adjoint(evecs), [1], [0])
 
-          assert_all_close(evecs_evals, a_evecs, atol: 1.0e-1)
+          assert_all_close(a, a_evecs, atol: 1.0e-8)
          key
      end
    end
@@ -734,10 +740,10 @@ defmodule Nx.LinAlgTest do
       assert_all_close(
         u,
         Nx.tensor([
-          [0.141, 0.825, -0.001, 0.019],
-          [0.344, 0.426, 0.00200, 0.382],
-          [0.547, 0.028, 0.0, -0.822],
-          [0.75, -0.370, -0.001, 0.421]
+          [0.141, -0.825, -0.001, 0.019],
+          [0.344, -0.426, 0.00200, 0.382],
+          [0.547, -0.028, 0.0, -0.822],
+          [0.75, 0.370, -0.001, 0.421]
         ]),
         atol: 1.0e-3,
         rtol: 1.0e-3
@@ -747,8 +753,8 @@ defmodule Nx.LinAlgTest do
 
       assert_all_close(
         Nx.tensor([
-          [0.505, 0.575, 0.644],
-          [-0.761, -0.057, 0.647],
+          [0.504, 0.575, 0.644],
+          [0.761, 0.057, -0.647],
           [-0.408, 0.816, -0.408]
         ]),
         v,
@@ -801,9 +807,9 @@ defmodule Nx.LinAlgTest do
       assert_all_close(
         u,
         Nx.tensor([
-          [0.336, -0.407, -0.849],
-          [0.037, -0.895, 0.444],
-          [0.941, 0.181, 0.286]
+          [0.335, 0.408, 0.849],
+          [0.036, 0.895, -0.445],
+          [0.941, -0.18, -0.286]
         ]),
         atol: 1.0e-3,
         rtol: 1.0e-3
@@ -815,9 +821,9 @@ defmodule Nx.LinAlgTest do
 
       assert_all_close(
         Nx.tensor([
-          [0.035, 0.0869, 0.996],
-          [-0.091, -0.992, 0.09],
-          [-0.995, 0.094, 0.027]
+          [0.035, 0.0856, 0.996],
+          [0.092, 0.992, -0.089],
+          [0.995, -0.094, -0.027]
         ]),
         v,
         atol: 1.0e-3,
diff --git a/nx/test/nx/non_finite_test.exs b/nx/test/nx/non_finite_test.exs
index 563657f845..423456746b 100644
--- a/nx/test/nx/non_finite_test.exs
+++ b/nx/test/nx/non_finite_test.exs
@@ -11,7 +11,6 @@ defmodule Nx.NonFiniteTest do
 
   @arg Complex.new(:infinity, 3)
   @arg2 Complex.new(-2, 4)
-  @inf_inf Complex.new(:infinity, :infinity)
   @nan_nan Complex.new(:nan, :nan)
   @one Nx.tensor(1, type: {:u, 8})
 
@@ -19,11 +18,11 @@ defmodule Nx.NonFiniteTest do
 
   describe "unary operations" do
     test "exp" do
-      assert Nx.exp(@arg) == Nx.tensor(@inf_inf)
+      assert Nx.exp(@arg) == Nx.tensor(Complex.new(:neg_infinity, :infinity))
     end
 
     test "expm1" do
-      assert Nx.expm1(@arg) == Nx.tensor(@inf_inf)
+      assert Nx.expm1(@arg) == Nx.tensor(Complex.new(:neg_infinity, :infinity))
     end
 
     test "log" do
@@ -51,11 +50,11 @@ defmodule Nx.NonFiniteTest do
     end
 
     test "cosh" do
-      assert Nx.cosh(@arg) == Nx.tensor(@inf_inf)
+      assert Nx.cosh(@arg) == Nx.tensor(Complex.new(:neg_infinity, :infinity))
     end
 
     test "sinh" do
-      assert Nx.sinh(@arg) == Nx.tensor(@inf_inf)
+      assert Nx.sinh(@arg) == Nx.tensor(Complex.new(:neg_infinity, :infinity))
     end
 
     test "tanh" do
@@ -145,7 +144,7 @@ defmodule Nx.NonFiniteTest do
     end
 
     test "product" do
-      assert Nx.product(Nx.tensor([@arg, @arg2])) == Nx.multiply(@arg, @arg2)
+      assert Nx.product(Nx.tensor([@arg, @arg2])) == Nx.tensor(Complex.new(:nan, :nan))
 
       assert Nx.product(Nx.tensor(:infinity)) == Nx.tensor(:infinity)
     end
diff --git a/nx/test/nx/serving_test.exs b/nx/test/nx/serving_test.exs
index f6aa54368d..49d72573f0 100644
--- a/nx/test/nx/serving_test.exs
+++ b/nx/test/nx/serving_test.exs
@@ -1288,7 +1288,8 @@ defmodule Nx.ServingTest do
       ]
 
       Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts])
-      assert_receive {_, :join, Nx.Serving, _}
+      assert_receive {_, :join, name, _}
+      assert name == config.test
 
       batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])
 
@@ -1327,14 +1328,16 @@ defmodule Nx.ServingTest do
       opts2 = Keyword.put(opts, :distribution_weight, 4)
 
       Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :multiply, [parent, opts])
-      assert_receive {_, :join, Nx.Serving, pids}
+      assert_receive {_, :join, name, pids}
       assert length(pids) == 1
+      assert name == config.test
 
       Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts2])
-      assert_receive {_, :join, Nx.Serving, pids}
+      assert_receive {_, :join, name, pids}
       assert length(pids) == 4
+      assert name == config.test
 
-      members = :pg.get_members(Nx.Serving.PG, Nx.Serving)
+      members = :pg.get_members(Nx.Serving.PG, config.test)
       assert length(members) == 5
     end
 
@@ -1356,7 +1359,8 @@ defmodule Nx.ServingTest do
       args = [parent, opts]
 
       Node.spawn_link(:"secondary@127.0.0.1", DistributedServings, :add_five_round_about, args)
-      assert_receive {_, :join, Nx.Serving, _}
+      assert_receive {_, :join, name, _}
+      assert name == config.test
 
       batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])
 
@@ -1412,7 +1416,8 @@ defmodule Nx.ServingTest do
       ]
 
       Node.spawn_link(:"tertiary@127.0.0.1", DistributedServings, :multiply, [parent, opts])
-      assert_receive {_, :join, Nx.Serving, _}
+      assert_receive {_, :join, name, _}
+      assert name == config.test
 
       batch = Nx.Batch.concatenate([Nx.tensor([1, 2])])
 
diff --git a/torchx/CHANGELOG.md b/torchx/CHANGELOG.md
index db3c8f8078..f25ae812bf 100644
--- a/torchx/CHANGELOG.md
+++ b/torchx/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Changelog
 
+## v0.9.2 (2024-11-16)
+
+  * Update to latest Nx
+
 ## v0.9.1 (2024-10-08)
 
   * Update to latest Nx
diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex
index 18f453dc6d..77308ce94d 100644
--- a/torchx/lib/torchx/backend.ex
+++ b/torchx/lib/torchx/backend.ex
@@ -506,7 +506,7 @@ defmodule Torchx.Backend do
     result =
       if axes == [] do
-        aggregate_whole_tensor(t, keep_axes, &Torchx.product/1)
+        aggregate_whole_tensor(t, &Torchx.product/1)
       else
         aggregate_over_axes(t, axes, keep_axes, &Torchx.product/3)
       end
 
@@ -523,7 +523,7 @@ defmodule Torchx.Backend do
     result =
       if axes == [] do
-        aggregate_whole_tensor(t, keep_axes, &Torchx.any/1)
+        aggregate_whole_tensor(t, &Torchx.any/1)
       else
         aggregate_over_axes(t, axes, keep_axes, &Torchx.any/3)
       end
 
@@ -538,7 +538,7 @@ defmodule Torchx.Backend do
     result =
       if axes == [] do
-        aggregate_whole_tensor(t, keep_axes, &Torchx.all/1)
+        aggregate_whole_tensor(t, &Torchx.all/1)
       else
         aggregate_over_axes(t, axes, keep_axes, &Torchx.all/3)
       end
 
@@ -563,18 +563,10 @@ defmodule Torchx.Backend do
     |> to_nx(out)
   end
 
-  defp aggregate_whole_tensor(t, keep_axes, fun) when is_function(fun, 1) do
-    result =
-      t
-      |> from_nx()
-      |> then(fun)
-
-    if keep_axes do
-      shape = t.shape |> Tuple.delete_at(-1) |> Tuple.append(1)
-      Torchx.reshape(result, shape)
-    else
-      result
-    end
+  defp aggregate_whole_tensor(t, fun) when is_function(fun, 1) do
+    t
+    |> from_nx()
+    |> then(fun)
   end
 
   defp aggregate_over_axes(t, axes, keep_axes, fun) when is_function(fun, 3) do
diff --git a/torchx/mix.exs b/torchx/mix.exs
index 070718c3b0..8a0f3cd368 100644
--- a/torchx/mix.exs
+++ b/torchx/mix.exs
@@ -41,8 +41,8 @@ defmodule Torchx.MixProject do
   defp deps do
     [
-      {:nx, "~> 0.9.0"},
-      # {:nx, path: "../nx"},
+      # {:nx, "~> 0.9.0"},
+      {:nx, path: "../nx"},
       {:ex_doc, "~> 0.29", only: :docs}
     ]
   end
diff --git a/torchx/mix.lock b/torchx/mix.lock
index b3a93517e0..adcdad2238 100644
--- a/torchx/mix.lock
+++ b/torchx/mix.lock
@@ -1,5 +1,5 @@
 %{
-  "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
+  "complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
  "earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
  "ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
  "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"},
diff --git a/torchx/test/torchx/nx_linalg_doctest_test.exs b/torchx/test/torchx/nx_linalg_doctest_test.exs
index 9f3c6eca52..30e75dafc5 100644
--- a/torchx/test/torchx/nx_linalg_doctest_test.exs
+++ b/torchx/test/torchx/nx_linalg_doctest_test.exs
@@ -18,7 +18,7 @@ defmodule Torchx.NxLinAlgDoctestTest do
     invert: 1,
     determinant: 1,
     pinv: 2,
-    least_squares: 2
+    least_squares: 3
   ]
 
   # Results do not match but properties are respected
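
The complex-aware comparison rules added to `Nx.BinaryBackend` above treat a complex number with a zero imaginary part as equal to the corresponding real number, and redefine `not_equal` as `1 - equal`. A minimal sketch of the resulting behavior, mirroring the assertions added in `nx/test/nx/complex_test.exs` (comparisons return `u8` tensors):

```elixir
one_c = Complex.new(1, 0)

Nx.equal(one_c, 1)      #=> u8 tensor 1: imaginary part is 0 and real parts match
Nx.equal(one_c, 0)      #=> u8 tensor 0
Nx.not_equal(one_c, 1)  #=> u8 tensor 0, computed as 1 - equal
Nx.equal(:nan, one_c)   #=> u8 tensor 0: NaN never compares equal
```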