Skip to content

Commit

Permalink
feat: all/2 accepts options (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy authored Nov 2, 2023
1 parent 9ddde1d commit d7c234e
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 9 deletions.
13 changes: 10 additions & 3 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,16 @@ defmodule Candlex.Backend do
# Aggregates

@impl true
def all(%T{} = out, %T{} = tensor, _opts) do
from_nx(tensor)
|> Native.all()
def all(%T{} = out, %T{} = tensor, opts) do
case opts[:axes] do
nil ->
from_nx(tensor)
|> Native.all()

axes ->
from_nx(tensor)
|> Native.all_within_dims(axes, opts[:keep_axes])
end
|> unwrap!()
|> to_nx(out)
end
Expand Down
1 change: 1 addition & 0 deletions lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ defmodule Candlex.Native do
def from_binary(_binary, _dtype, _shape, _device), do: error()
def to_binary(_tensor), do: error()
def all(_tensor), do: error()
def all_within_dims(_tensor, _dims, _keep_dims), do: error()
def where_cond(_tensor, _on_true, _on_false), do: error()
def narrow(_tensor, _dim, _start, _length), do: error()
def gather(_tensor, _indexes, _dim), do: error()
Expand Down
1 change: 1 addition & 0 deletions native/candlex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ rustler::init! {
tensors::less,
tensors::less_equal,
tensors::all,
tensors::all_within_dims,
tensors::sum,
tensors::dtype,
tensors::t_shape,
Expand Down
21 changes: 21 additions & 0 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,27 @@ pub fn all(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
Ok(ExTensor::new(Tensor::new(bool_scalar, device)?))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn all_within_dims(
ex_tensor: ExTensor,
dims: Vec<usize>,
keep_dims: bool,
) -> Result<ExTensor, CandlexError> {
let comparison = ex_tensor.ne(&ex_tensor.zeros_like()?)?;

let tensor = if keep_dims {
dims.iter()
.rev()
.fold(comparison, |t, dim| t.min_keepdim(*dim).unwrap())
} else {
dims.iter()
.rev()
.fold(comparison, |t, dim| t.min(*dim).unwrap())
};

Ok(ExTensor::new(tensor))
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn argmax(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result<ExTensor, CandlexError> {
let t = if keep_dim {
Expand Down
77 changes: 71 additions & 6 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2144,17 +2144,82 @@ defmodule CandlexTest do
end

test "all" do
t(0)
|> Nx.all()
|> assert_equal(t(0))

t(10)
|> Nx.all()
|> assert_equal(t(1))

t([0, 1, 2])
|> Nx.all()
|> assert_equal(t(0))

# t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y])
# |> Nx.all(axes: [:x])
# |> assert_equal(t([1, 0, 1]))
t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y])
|> Nx.all(axes: [:x])
|> assert_equal(t([1, 0, 1]))

t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y])
|> Nx.all(axes: [:y])
|> assert_equal(t([0, 1]))

# t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y])
# |> Nx.all(axes: [:y])
# |> assert_equal(t([0, 1]))
t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y])
|> Nx.all(axes: [:y], keep_axes: true)
|> assert_equal(
t([
[0],
[1]
])
)

tensor = Nx.tensor([[[1, 2], [0, 4]], [[5, 6], [7, 8]]], names: [:x, :y, :z])

tensor
|> Nx.all(axes: [:x, :y])
|> assert_equal(t([0, 1]))

tensor
|> Nx.all(axes: [:y, :z])
|> assert_equal(t([0, 1]))

tensor
|> Nx.all(axes: [:x, :z])
|> assert_equal(t([1, 0]))

tensor
|> Nx.all(axes: [:x, :y], keep_axes: true)
|> assert_equal(
t([
[
[0, 1]
]
])
)

tensor
|> Nx.all(axes: [:y, :z], keep_axes: true)
|> assert_equal(
t([
[
[0]
],
[
[1]
]
])
)

tensor
|> Nx.all(axes: [:x, :z], keep_axes: true)
|> assert_equal(
t([
[
[1],
[0]
]
])
)
end

if Candlex.Backend.cuda_available?() do
Expand Down

0 comments on commit d7c234e

Please sign in to comment.