Merge branch 'main' into guide_fft
polvalente authored Jan 27, 2025
2 parents c9c3e62 + 9cfcd05 commit 26223c6
Showing 36 changed files with 764 additions and 485 deletions.
7 changes: 7 additions & 0 deletions exla/CHANGELOG.md
@@ -1,5 +1,12 @@
# Changelog

## v0.9.2 (2024-11-16)

### Enhancements

* Support cross-compilation for use with Nerves
* Optimize LU with a custom call

## v0.9.1 (2024-10-08)

### Enhancements
4 changes: 2 additions & 2 deletions exla/lib/exla/defn.ex
@@ -787,8 +787,8 @@ defmodule EXLA.Defn do
transform = Keyword.fetch!(opts, :transform_a)

case Value.get_typespec(b).shape do
{_} = b_shape ->
b_shape = Tuple.append(b_shape, 1)
{dim} ->
b_shape = {dim, 1}

b =
b
2 changes: 1 addition & 1 deletion exla/mix.lock
@@ -1,6 +1,6 @@
%{
"benchee": {:hex, :benchee, "1.1.0", "f3a43817209a92a1fade36ef36b86e1052627fd8934a8b937ac9ab3a76c43062", [:mix], [{:deep_merge, "~> 1.0", [hex: :deep_merge, repo: "hexpm", optional: false]}, {:statistex, "~> 1.0", [hex: :statistex, repo: "hexpm", optional: false]}], "hexpm", "7da57d545003165a012b587077f6ba90b89210fd88074ce3c60ce239eb5e6d93"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"complex": {:hex, :complex, "0.6.0", "b0130086a7a8c33574d293b2e0e250f4685580418eac52a5658a4bd148f3ccf1", [:mix], [], "hexpm", "0a5fa95580dcaf30fcd60fe1aaf24327c0fe401e98c24d892e172e79498269f9"},
"deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"},
"earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
8 changes: 8 additions & 0 deletions nx/CHANGELOG.md
@@ -1,5 +1,13 @@
# Changelog

## v0.9.2 (2024-11-16)

### Bug fixes

* [Nx] Fix deprecation warnings on latest Elixir
* [Nx.LinAlg] Fix `least_squares` implementation
* [Nx.Random] Fix `Nx.Random.shuffle` repeating a single value in certain cases on GPU

## v0.9.1 (2024-10-08)

### Deprecations
2 changes: 1 addition & 1 deletion nx/guides/advanced/aggregation.livemd
@@ -64,7 +64,7 @@ m = ~MAT[
]
```

First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](<https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the>), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
First, we'll compute the full-tensor aggregation. The calculations are developed below. We calculate an "array product" (aka [Hadamard product](https://en.wikipedia.org/wiki/Hadamard_product_(matrices)#:~:text=In%20mathematics%2C%20the%20Hadamard%20product,elements%20i%2C%20j%20of%20the), an element-wise product) of our tensor with the tensor of weights, then sum all the elements and divide by the sum of the weights.
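
A minimal sketch of that computation, assuming `m` is the matrix above and `w` is the weight tensor defined in the next cell:

```elixir
# Weighted mean: element-wise (Hadamard) product, summed, then divided by the sum of the weights.
m
|> Nx.multiply(w)
|> Nx.sum()
|> Nx.divide(Nx.sum(w))
```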

```elixir
w = ~MAT[
66 changes: 55 additions & 11 deletions nx/lib/nx.ex
@@ -417,7 +417,7 @@ defmodule Nx do
config: [nx: [default_backend: EXLA.Backend]]
)
Or by calling `Nx.global_default_backend/1` (less preferrable):
Or by calling `Nx.global_default_backend/1` (less preferable):
Nx.global_default_backend(EXLA.Backend)
@@ -614,7 +614,7 @@ defmodule Nx do
>
Certain backends and compilers support 8-bit floats. The precision
iomplementation of 8-bit floats may change per backend, so you must
implementation of 8-bit floats may change per backend, so you must
be careful when transferring data across. The binary backend implements
F8E5M2:
@@ -943,7 +943,7 @@ defmodule Nx do

for t <-
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
[:f8, :bf16, :f16, :f32, :f64] do
[:f8, :bf16, :f16, :f32, :f64, :c64, :c128] do
@doc """
Short-hand function for creating tensor of type `#{t}`.
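
With this change the comprehension also generates `c64/1` and `c128/1` shorthands; a usage sketch, assuming they defer to `Nx.tensor/2` with the corresponding `:type` like the other generated helpers:

```elixir
# Illustrative only: complex shorthands produced by the comprehension above.
Nx.c64([1, 2])                     # numbers are cast to c64 (1.0+0.0i, 2.0+0.0i)
Nx.c128([Complex.new(1.0, -2.0)])  # explicit Complex values also work
```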
@@ -1305,7 +1305,7 @@ defmodule Nx do
out =
case shape do
{n} ->
intermediate_shape = Tuple.duplicate(1, tuple_size(out_shape) - 1) |> Tuple.append(n)
intermediate_shape = Tuple.duplicate(1, tuple_size(out_shape) - 1) |> tuple_append(n)

backend.eye(
%T{type: type, shape: intermediate_shape, names: names},
@@ -1609,7 +1609,7 @@ defmodule Nx do
t
else
diag_length = div(Nx.size(t), Tuple.product(batch_shape))
Nx.reshape(t, Tuple.append(batch_shape, diag_length))
Nx.reshape(t, tuple_append(batch_shape, diag_length))
end
end
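
The `Tuple.append/2` to `tuple_append/2` swap here (and throughout this commit) lines up with the changelog entry about deprecation warnings on recent Elixir versions; a minimal private helper with the same behavior could look like this (a sketch, not the actual source):

```elixir
# Appends an element to a tuple without calling the deprecated Tuple.append/2.
defp tuple_append(tuple, value), do: Tuple.insert_at(tuple, tuple_size(tuple), value)
```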

@@ -7684,7 +7684,7 @@ defmodule Nx do
number of indices, while `updates` must have a compatible `{n, ...j}` shape, such that
`i + j = rank(tensor)`.
In case of repeating indices, the result is non-determinstic, since the operation happens
In case of repeating indices, the result is non-deterministic, since the operation happens
in parallel when running on devices such as the GPU.
See also: `indexed_add/3`, `put_slice/3`.
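
An illustrative call for the shape rule above, assuming this is the `indexed_put/3` documentation:

```elixir
# rank(tensor) = 2, indices have shape {n, i} = {2, 2}, so updates have shape {n} (j = 0).
target = Nx.tensor([[1, 2], [3, 4]])
indices = Nx.tensor([[0, 0], [1, 1]])
updates = Nx.tensor([10, 20])
Nx.indexed_put(target, indices, updates)
# => [[10, 2], [3, 20]]
```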
@@ -10365,9 +10365,9 @@ defmodule Nx do
if opts[:keep_axis] do
new_shape
|> Tuple.delete_at(tuple_size(new_shape) - 1)
|> Tuple.append(:auto)
|> tuple_append(:auto)
else
Tuple.append(new_shape, :auto)
tuple_append(new_shape, :auto)
end

reshaped_tensor = reshape(tensor, flattened_shape)
@@ -12919,6 +12919,11 @@ defmodule Nx do
of summing the element-wise products in the window across
each input channel.
> #### Kernel Reflection {: .info}
>
> See the note at the end of this section for more details
> on the convention for kernel reflection and conjugation.
The ranks of both `input` and `kernel` must match. By
default, both `input` and `kernel` are expected to have shapes
of the following form:
@@ -13000,6 +13005,45 @@ defmodule Nx do
in the same way as with `:feature_group_size`, however, the input
tensor will be split into groups along the batch dimension.
> #### Convolution vs Correlation {: .tip}
>
> `conv/3` does not perform reversion of the kernel.
> This means that if you come from a Signal Processing background,
> you might treat it as a cross-correlation operation instead of a convolution.
>
> This function is not exactly a cross-correlation function, as it does not
> perform conjugation of the kernel, as is done in [scipy.signal.correlate](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.correlate.html).
> This can be remedied via `Nx.conjugate/1`, as seen below:
>
> ```elixir
> kernel =
> if Nx.Type.complex?(Nx.type(kernel)) do
> Nx.conjugate(kernel)
> else
> kernel
> end
>
> Nx.conv(tensor, kernel)
> ```
>
> If you need the proper Signal Processing convolution, such as the one in
> [scipy.signal.convolve](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.convolve.html),
> you can use `reverse/2`, like in the example:
>
> ```elixir
> reversal_axes =
> case Nx.rank(kernel) do
> 0 -> []
> 1 -> [1]
> 2 -> [0, 1]
> _ -> Enum.drop(Nx.axes(kernel), 2)
> end
>
> kernel = Nx.reverse(kernel, axes: reversal_axes)
>
> Nx.conv(tensor, kernel)
> ```
## Examples
iex> left = Nx.iota({1, 1, 3, 3})
@@ -13554,7 +13598,7 @@ defmodule Nx do
end)
|> Nx.stack()
|> Nx.revectorize(vectorized_axes,
target_shape: Tuple.append(List.to_tuple(lengths), :auto)
target_shape: tuple_append(List.to_tuple(lengths), :auto)
)

Nx.gather(tensor, idx)
@@ -14288,7 +14332,7 @@ defmodule Nx do
Nx.Shared.optional(:take_along_axis, [tensor, indices, [axis: axis]], out, fn
tensor, indices, _opts ->
axes_range = axes(indices)
new_axis_shape = Tuple.append(shape(indices), 1)
new_axis_shape = tuple_append(shape(indices), 1)

full_indices =
axes_range
@@ -14471,7 +14515,7 @@ defmodule Nx do
indices = devectorize(indices, keep_names: false)

iota_shape =
indices.shape |> Tuple.delete_at(tuple_size(indices.shape) - 1) |> Tuple.append(1)
indices.shape |> Tuple.delete_at(tuple_size(indices.shape) - 1) |> tuple_append(1)

offset_axes = (offset - 1)..0//-1

2 changes: 1 addition & 1 deletion nx/lib/nx/backend.ex
@@ -138,7 +138,7 @@ defmodule Nx.Backend do
First we will attempt to call the optional callback itself
(one of the many callbacks defined below), then we attempt
to call this callback (which is also optional), then we
fallback to the default iomplementation.
fallback to the default implementation.
"""
@callback optional(atom, [term], fun) :: tensor
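
A rough sketch of the fallback chain described in the docstring above (a hypothetical helper for illustration, not the actual `Nx.Shared.optional/4` implementation):

```elixir
# Try the specific optional callback, then the backend's optional/3 hook,
# then fall back to the default implementation.
defp dispatch_optional(backend, callback, args, default_fun) do
  cond do
    function_exported?(backend, callback, length(args)) ->
      apply(backend, callback, args)

    function_exported?(backend, :optional, 3) ->
      backend.optional(callback, args, default_fun)

    true ->
      apply(default_fun, args)
  end
end
```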

76 changes: 30 additions & 46 deletions nx/lib/nx/binary_backend.ex
@@ -755,11 +755,29 @@ defmodule Nx.BinaryBackend do

defp element_equal(_, :nan, _), do: 0
defp element_equal(_, _, :nan), do: 0
defp element_equal(_, a, b), do: boolean_as_number(a == b)

defp element_equal(_, a, b) do
bool =
case {a, b} do
{%Complex{re: re_a, im: im_a}, b} when is_number(b) ->
re_a == b and im_a == 0

{a, %Complex{re: re_b, im: im_b}} when is_number(a) ->
a == re_b and im_b == 0

{a, b} ->
a == b
end

boolean_as_number(bool)
end

defp element_not_equal(_, :nan, _), do: 1
defp element_not_equal(_, _, :nan), do: 1
defp element_not_equal(_, a, b), do: boolean_as_number(a != b)

defp element_not_equal(out, a, b) do
1 - element_equal(out, a, b)
end

defp element_logical_and(_, a, b), do: boolean_as_number(as_boolean(a) and as_boolean(b))
defp element_logical_or(_, a, b), do: boolean_as_number(as_boolean(a) or as_boolean(b))
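
The new clauses let a complex value with a zero imaginary part compare equal to the corresponding real number; for example, something along these lines now yields 1 on the binary backend:

```elixir
# Mixed complex/real equality: 1.0+0.0i compared with 1.0 is treated as equal.
Nx.equal(Nx.tensor(Complex.new(1.0, 0.0)), Nx.tensor(1.0))
# => 1 (as a u8 scalar tensor)
```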
@@ -1240,25 +1258,6 @@ defmodule Nx.BinaryBackend do
output_batch_groups |> Enum.with_index() |> Enum.map(fn {x, i} -> {x, rem(i, groups)} end)
end

@impl true
def eigh(
{%{type: output_type} = eigenvals_holder, eigenvecs_holder},
%{type: input_type, shape: input_shape} = tensor,
opts
) do
bin = to_binary(tensor)
rank = tuple_size(input_shape)
n = elem(input_shape, rank - 1)

{eigenvals, eigenvecs} =
bin_batch_reduce(bin, n * n, input_type, {<<>>, <<>>}, fn matrix, {vals_acc, vecs_acc} ->
{vals, vecs} = B.Matrix.eigh(matrix, input_type, {n, n}, output_type, opts)
{vals_acc <> vals, vecs_acc <> vecs}
end)

{from_binary(eigenvals_holder, eigenvals), from_binary(eigenvecs_holder, eigenvecs)}
end

@impl true
def lu(
{%{type: p_type} = p_holder, %{type: l_type} = l_holder, %{type: u_type} = u_holder},
@@ -1492,7 +1491,7 @@ defmodule Nx.BinaryBackend do
dilations = opts[:window_dilations]

%T{shape: padded_shape, type: {_, size} = type} =
tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &Tuple.append(&1, 0)))
tensor = Nx.pad(tensor, acc, Enum.map(padding_config, &tuple_append(&1, 0)))

acc = scalar_to_number(acc)

@@ -1608,7 +1607,7 @@ defmodule Nx.BinaryBackend do
init_value = scalar_to_number(init_value)

%T{shape: padded_shape, type: {_, size} = type} =
tensor = Nx.pad(t, init_value, Enum.map(padding, &Tuple.append(&1, 0)))
tensor = Nx.pad(t, init_value, Enum.map(padding, &tuple_append(&1, 0)))

input_data = to_binary(tensor)
input_weighted_shape = weighted_shape(padded_shape, size, window_dimensions)
@@ -2077,7 +2076,7 @@ defmodule Nx.BinaryBackend do
for <<match!(x, 0) <- data>>, into: <<>> do
x = read!(x, 0)

case x do
generated_case x do
%Complex{re: re} when float_output? and real_output? ->
number_to_binary(re, output_type)

@@ -2253,17 +2252,16 @@ defmodule Nx.BinaryBackend do
end
end

output_data =
match_types [out.type] do
for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do
re = if abs(re) <= eps, do: 0, else: re
im = if abs(im) <= eps, do: 0, else: im
%{type: {_, output_size}} = out

<<write!(Complex.new(re, im), 0)>>
end
output_data =
for row <- result, %Complex{re: re, im: im} <- row, into: <<>> do
re = if abs(re) <= eps, do: 0, else: re
im = if abs(im) <= eps, do: 0, else: im
<<write_complex(re, im, div(output_size, 2))::binary>>
end

intermediate_shape = out.shape |> Tuple.delete_at(axis) |> Tuple.append(n)
intermediate_shape = out.shape |> Tuple.delete_at(axis) |> tuple_append(n)

permuted_output = from_binary(%{out | shape: intermediate_shape}, output_data)

@@ -2391,20 +2389,6 @@ defmodule Nx.BinaryBackend do
end
end

defp bin_zip_reduce(t1, [], t2, [], type, acc, fun) do
%{type: {_, s1}} = t1
%{type: {_, s2}} = t2
b1 = to_binary(t1)
b2 = to_binary(t2)

match_types [t1.type, t2.type] do
for <<d1::size(s1)-bitstring <- b1>>, <<d2::size(s2)-bitstring <- b2>>, into: <<>> do
{result, _} = fun.(d1, d2, acc)
scalar_to_binary!(result, type)
end
end
end

defp bin_zip_reduce(t1, [_ | _] = axes1, t2, [_ | _] = axes2, type, acc, fun) do
{_, s1} = t1.type
{_, s2} = t2.type