diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index 5c0c701..c508008 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -609,6 +609,15 @@ defmodule Candlex.Backend do ) end + @impl true + def reverse(%T{} = out, %T{} = tensor, axes) do + tensor + |> from_nx() + |> Native.reverse(axes) + |> unwrap!() + |> to_nx(out) + end + # Shape @impl true @@ -877,7 +886,6 @@ defmodule Candlex.Backend do :lu, :product, :qr, - :reverse, :sort ] do @impl true diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index b85f0b6..d6fddd7 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -50,6 +50,7 @@ defmodule Candlex.Native do def slice_scatter(_tensor, _src, _dim, _start), do: error() def pad_with_zeros(_tensor, _left, _right), do: error() def clamp(_tensor, _min, _max), do: error() + def reverse(_tensor, _axes), do: error() for op <- [ :abs, diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index f4a3398..98d47bc 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -62,6 +62,7 @@ rustler::init! { tensors::chunk, tensors::squeeze, tensors::clamp, + tensors::reverse, tensors::arange, tensors::to_type, tensors::broadcast_to, diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 627c4f4..709c522 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -133,6 +133,23 @@ pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result) -> Result { + let device = t.device(); + let t_dims = t.dims(); + let mut new_t = t.clone(); + + for i in dims { + new_t = new_t.index_select( + &Tensor::arange_step::(t_dims[i] as i64, 0, -1, device)? + .broadcast_sub(&Tensor::new(1i64, device)?)?, + i, + )?; + } + + Ok(ExTensor::new(new_t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn rsqrt(t: ExTensor) -> Result { Ok(ExTensor::new(t.sqrt()?.recip()?)) diff --git a/test/candlex_test.exs b/test/candlex_test.exs index 7a30982..1be9ecf 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -2291,6 +2291,58 @@ defmodule CandlexTest do |> assert_equal(t([[1], [1]])) end + test "reverse" do + t([1, 2, 3]) + |> Nx.reverse() + |> assert_equal(t([3, 2, 1])) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.reverse() + |> assert_equal( + t([ + [6, 5, 4], + [3, 2, 1] + ]) + ) + + t([1, 2, 3], names: [:x]) + |> Nx.reverse(axes: [:x]) + |> assert_equal(t([3, 2, 1])) + + t([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) + |> Nx.reverse(axes: [:x]) + |> assert_equal( + t([ + [4, 5, 6], + [1, 2, 3] + ]) + ) + + t([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) + |> Nx.reverse(axes: [:y]) + |> assert_equal( + t([ + [3, 2, 1], + [6, 5, 4] + ]) + ) + + Nx.iota({2, 2, 2}, type: :f32, names: [:x, :y, :z]) + |> Nx.reverse(axes: [:x, :z]) + |> assert_equal( + t([ + [ + [5.0, 4.0], + [7.0, 6.0] + ], + [ + [1.0, 0.0], + [3.0, 2.0] + ] + ]) + ) + end + if Candlex.Backend.cuda_available?() do test "different devices" do t([1, 2, 3], backend: {Candlex.Backend, device: :cpu})