From 7d0f261042b61a8260abd18345d0926b9552e0a8 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 13 Nov 2023 15:15:20 -0300 Subject: [PATCH] feat: Nx.reverse (#25) --- lib/candlex/backend.ex | 10 ++++++- lib/candlex/native.ex | 1 + native/candlex/src/lib.rs | 1 + native/candlex/src/tensors.rs | 16 +++++++++++ test/candlex_test.exs | 52 +++++++++++++++++++++++++++++++++++ 5 files changed, 79 insertions(+), 1 deletion(-) diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index 5c0c701..c508008 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -609,6 +609,15 @@ defmodule Candlex.Backend do ) end + @impl true + def reverse(%T{} = out, %T{} = tensor, axes) do + tensor + |> from_nx() + |> Native.reverse(axes) + |> unwrap!() + |> to_nx(out) + end + # Shape @impl true @@ -877,7 +886,6 @@ defmodule Candlex.Backend do :lu, :product, :qr, - :reverse, :sort ] do @impl true diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index b85f0b6..d6fddd7 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -50,6 +50,7 @@ defmodule Candlex.Native do def slice_scatter(_tensor, _src, _dim, _start), do: error() def pad_with_zeros(_tensor, _left, _right), do: error() def clamp(_tensor, _min, _max), do: error() + def reverse(_tensor, _axes), do: error() for op <- [ :abs, diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index f4a3398..98d47bc 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -62,6 +62,7 @@ rustler::init! { tensors::chunk, tensors::squeeze, tensors::clamp, + tensors::reverse, tensors::arange, tensors::to_type, tensors::broadcast_to, diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 627c4f4..780046b 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -133,6 +133,22 @@ pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result) -> Result { + let device = t.device(); + let t_dims = t.dims(); + let mut new_t = t.clone(); + + for dim in dims { + new_t = new_t.index_select( + &Tensor::arange_step::((t_dims[dim] as i64) - 1, -1, -1, device)?, + dim, + )?; + } + + Ok(ExTensor::new(new_t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn rsqrt(t: ExTensor) -> Result { Ok(ExTensor::new(t.sqrt()?.recip()?)) diff --git a/test/candlex_test.exs b/test/candlex_test.exs index 7a30982..1be9ecf 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -2291,6 +2291,58 @@ defmodule CandlexTest do |> assert_equal(t([[1], [1]])) end + test "reverse" do + t([1, 2, 3]) + |> Nx.reverse() + |> assert_equal(t([3, 2, 1])) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.reverse() + |> assert_equal( + t([ + [6, 5, 4], + [3, 2, 1] + ]) + ) + + t([1, 2, 3], names: [:x]) + |> Nx.reverse(axes: [:x]) + |> assert_equal(t([3, 2, 1])) + + t([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) + |> Nx.reverse(axes: [:x]) + |> assert_equal( + t([ + [4, 5, 6], + [1, 2, 3] + ]) + ) + + t([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) + |> Nx.reverse(axes: [:y]) + |> assert_equal( + t([ + [3, 2, 1], + [6, 5, 4] + ]) + ) + + Nx.iota({2, 2, 2}, type: :f32, names: [:x, :y, :z]) + |> Nx.reverse(axes: [:x, :z]) + |> assert_equal( + t([ + [ + [5.0, 4.0], + [7.0, 6.0] + ], + [ + [1.0, 0.0], + [3.0, 2.0] + ] + ]) + ) + end + if Candlex.Backend.cuda_available?() do test "different devices" do t([1, 2, 3], backend: {Candlex.Backend, device: :cpu})