Skip to content

Commit

Permalink
feat: Nx.reverse (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy authored Nov 13, 2023
1 parent d372c16 commit 7d0f261
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 1 deletion.
10 changes: 9 additions & 1 deletion lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,15 @@ defmodule Candlex.Backend do
)
end

@impl true
def reverse(%T{} = out, %T{} = tensor, axes) do
tensor
|> from_nx()
|> Native.reverse(axes)
|> unwrap!()
|> to_nx(out)
end

# Shape

@impl true
Expand Down Expand Up @@ -877,7 +886,6 @@ defmodule Candlex.Backend do
:lu,
:product,
:qr,
:reverse,
:sort
] do
@impl true
Expand Down
1 change: 1 addition & 0 deletions lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ defmodule Candlex.Native do
def slice_scatter(_tensor, _src, _dim, _start), do: error()
def pad_with_zeros(_tensor, _left, _right), do: error()
def clamp(_tensor, _min, _max), do: error()
def reverse(_tensor, _axes), do: error()

for op <- [
:abs,
Expand Down
1 change: 1 addition & 0 deletions native/candlex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ rustler::init! {
tensors::chunk,
tensors::squeeze,
tensors::clamp,
tensors::reverse,
tensors::arange,
tensors::to_type,
tensors::broadcast_to,
Expand Down
16 changes: 16 additions & 0 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,22 @@ pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result<ExTens
)?))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn reverse(t: ExTensor, dims: Vec<usize>) -> Result<ExTensor, CandlexError> {
let device = t.device();
let t_dims = t.dims();
let mut new_t = t.clone();

for dim in dims {
new_t = new_t.index_select(
&Tensor::arange_step::<i64>((t_dims[dim] as i64) - 1, -1, -1, device)?,
dim,
)?;
}

Ok(ExTensor::new(new_t))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn rsqrt(t: ExTensor) -> Result<ExTensor, CandlexError> {
Ok(ExTensor::new(t.sqrt()?.recip()?))
Expand Down
52 changes: 52 additions & 0 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2291,6 +2291,58 @@ defmodule CandlexTest do
|> assert_equal(t([[1], [1]]))
end

test "reverse" do
t([1, 2, 3])
|> Nx.reverse()
|> assert_equal(t([3, 2, 1]))

t([[1, 2, 3], [4, 5, 6]])
|> Nx.reverse()
|> assert_equal(
t([
[6, 5, 4],
[3, 2, 1]
])
)

t([1, 2, 3], names: [:x])
|> Nx.reverse(axes: [:x])
|> assert_equal(t([3, 2, 1]))

t([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
|> Nx.reverse(axes: [:x])
|> assert_equal(
t([
[4, 5, 6],
[1, 2, 3]
])
)

t([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
|> Nx.reverse(axes: [:y])
|> assert_equal(
t([
[3, 2, 1],
[6, 5, 4]
])
)

Nx.iota({2, 2, 2}, type: :f32, names: [:x, :y, :z])
|> Nx.reverse(axes: [:x, :z])
|> assert_equal(
t([
[
[5.0, 4.0],
[7.0, 6.0]
],
[
[1.0, 0.0],
[3.0, 2.0]
]
])
)
end

if Candlex.Backend.cuda_available?() do
test "different devices" do
t([1, 2, 3], backend: {Candlex.Backend, device: :cpu})
Expand Down

0 comments on commit 7d0f261

Please sign in to comment.