Commit b695f0d
Merge branch 'main' into precompiled-cuda
grzuy committed Nov 2, 2023
2 parents 0ea23d5 + 123d332
Showing 8 changed files with 253 additions and 42 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/ci.yml
@@ -2,8 +2,6 @@ name: CI
 on:
   pull_request:
   push:
-    branches:
-      - main
 
 jobs:
   main:
@@ -33,6 +31,6 @@ jobs:
       - run: mix deps.unlock --check-unused
         if: ${{ matrix.lint }}
       - run: mix deps.compile
-      # - run: mix compile --warnings-as-errors
-      #   if: ${{ matrix.lint }}
+      - run: mix compile --warnings-as-errors
+        if: ${{ matrix.lint }}
       - run: mix test
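Dropping the `branches` filter means pushes to any branch now trigger CI, and the compile step with `--warnings-as-errors` is re-enabled for lint runs. As a hypothetical illustration (not code from this repository), a module like the following compiles with an "unused variable" warning, which that flag promotes to a build failure:

```elixir
# Hypothetical example: `mix compile --warnings-as-errors` turns the
# unused-variable warning on `product` into a compilation error.
defmodule Sketch do
  def add(a, b) do
    product = a * b
    a + b
  end
end
```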
50 changes: 46 additions & 4 deletions lib/candlex/backend.ex
@@ -118,9 +118,31 @@ defmodule Candlex.Backend do
   # Aggregates
 
   @impl true
-  def all(%T{} = out, %T{} = tensor, _opts) do
-    from_nx(tensor)
-    |> Native.all()
+  def all(%T{} = out, %T{} = tensor, opts) do
+    case opts[:axes] do
+      nil ->
+        from_nx(tensor)
+        |> Native.all()
+
+      axes ->
+        from_nx(tensor)
+        |> Native.all_within_dims(axes, opts[:keep_axes])
+    end
     |> unwrap!()
     |> to_nx(out)
   end
+
+  @impl true
+  def any(%T{} = out, %T{} = tensor, opts) do
+    case opts[:axes] do
+      nil ->
+        from_nx(tensor)
+        |> Native.any()
+
+      axes ->
+        from_nx(tensor)
+        |> Native.any_within_dims(axes, opts[:keep_axes])
+    end
+    |> unwrap!()
+    |> to_nx(out)
+  end
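A usage sketch of the two dispatch paths above, assuming Candlex is configured as the default Nx backend (values illustrative):

```elixir
Nx.default_backend(Candlex.Backend)

t = Nx.tensor([[1, 0], [2, 3]])

# Without :axes the whole tensor reduces through Native.all/1.
Nx.all(t)                              # => 0 (u8 scalar)

# With :axes dispatch goes through Native.all_within_dims/3.
Nx.all(t, axes: [1])                   # => [0, 1]
Nx.any(t, axes: [0], keep_axes: true)  # => [[1, 1]]
```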
@@ -506,6 +528,25 @@ defmodule Candlex.Backend do
   end
 
   @impl true
+  def dot(
+        %T{type: _out_type} = out,
+        %T{shape: left_shape, type: _left_type} = left,
+        [left_axis] = _left_axes,
+        [] = _left_batched_axes,
+        %T{shape: right_shape, type: _right_type} = right,
+        [0] = _right_axes,
+        [] = _right_batched_axes
+      )
+      when tuple_size(left_shape) >= 1 and tuple_size(right_shape) == 1 and
+             left_axis == tuple_size(left_shape) - 1 do
+    {left, right} = maybe_upcast(left, right)
+
+    from_nx(left)
+    |> Native.dot(from_nx(right))
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
   def dot(
         %T{type: _out_type} = out,
         %T{shape: left_shape, type: _left_type} = left,
@@ -516,6 +557,8 @@ defmodule Candlex.Backend do
         [] = _right_batched_axes
       )
      when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do
+    {left, right} = maybe_upcast(left, right)
+
     Native.matmul(
       from_nx(left),
       from_nx(right)
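The new clause handles contracting the last axis of the left tensor against a 1-D right tensor; the existing clause below it keeps handling the 2-D by 2-D case via `Native.matmul/2`. In both, `maybe_upcast/2` promotes mixed input types to a common type before the native call. A sketch, assuming the Candlex backend is active:

```elixir
m = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
v = Nx.tensor([10.0, 20.0])

# {2, 2} . {2} matches the new matrix-vector clause (Native.dot/2).
Nx.dot(m, v)  # => [50.0, 110.0]
```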
@@ -827,7 +870,6 @@ defmodule Candlex.Backend do
   end
 
   for op <- [
-        :any,
         :argsort,
         :eigh,
         :fft,
4 changes: 4 additions & 0 deletions lib/candlex/native.ex
@@ -28,6 +28,9 @@ defmodule Candlex.Native do
   def from_binary(_binary, _dtype, _shape, _device), do: error()
   def to_binary(_tensor), do: error()
   def all(_tensor), do: error()
+  def all_within_dims(_tensor, _dims, _keep_dims), do: error()
+  def any(_tensor), do: error()
+  def any_within_dims(_tensor, _dims, _keep_dims), do: error()
   def where_cond(_tensor, _on_true, _on_false), do: error()
   def narrow(_tensor, _dim, _start, _length), do: error()
   def gather(_tensor, _indexes, _dim), do: error()
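These follow the usual Rustler stub pattern: each function raises unless the compiled NIF replaces it when the native library loads. A minimal sketch of the convention (module and crate names are illustrative, not part of this diff):

```elixir
defmodule MyApp.Native do
  use Rustler, otp_app: :my_app, crate: "my_app"

  # Overridden by the Rust NIF at load time; this body only runs
  # if the native library failed to load.
  def any_within_dims(_tensor, _dims, _keep_dims), do: error()

  defp error, do: :erlang.nif_error(:nif_not_loaded)
end
```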
@@ -92,6 +95,7 @@ defmodule Candlex.Native do
     :bitwise_or,
     :bitwise_xor,
     :divide,
+    :dot,
     :equal,
     :greater,
     :greater_equal,
5 changes: 3 additions & 2 deletions mix.exs
@@ -32,11 +32,12 @@ defmodule Candlex.MixProject do
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:nx, "~> 0.6.2"},
# {:nx, "~> 0.6.2"},
{:nx, git: "https://github.com/elixir-nx/nx", sparse: "nx"},
{:rustler_precompiled, "~> 0.7.0"},

# Optional
{:rustler, "~> 0.30.0", optional: true},
{:rustler, "~> 0.29", optional: true},

# Dev
{:ex_doc, "~> 0.30.9", only: :dev, runtime: false}
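Switching `:nx` from a Hex release to a git dependency picks up backend callbacks not yet published, and `sparse: "nx"` checks out only the `nx/` sub-directory of the elixir-nx monorepo. If reproducibility matters before the next release, the dependency could also be pinned to the commit recorded in `mix.lock` below (an illustrative variant, not what this diff does):

```elixir
{:nx,
 git: "https://github.com/elixir-nx/nx",
 sparse: "nx",
 ref: "7706e8601e40916c02f8773df7802b3bfab43054"}
```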
2 changes: 1 addition & 1 deletion mix.lock
@@ -8,7 +8,7 @@
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
"nx": {:hex, :nx, "0.6.2", "f1d137f477b1a6f84f8db638f7a6d5a0f8266caea63c9918aa4583db38ebe1d6", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ac913b68d53f25f6eb39bddcf2d2cd6ea2e9bcb6f25cf86a79e35d0411ba96ad"},
"nx": {:git, "https://github.com/elixir-nx/nx", "7706e8601e40916c02f8773df7802b3bfab43054", [sparse: "nx"]},
"rustler": {:hex, :rustler, "0.30.0", "cefc49922132b072853fa9b0ca4dc2ffcb452f68fb73b779042b02d545e097fb", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "9ef1abb6a7dda35c47cfc649e6a5a61663af6cf842a55814a554a84607dee389"},
"rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
Expand Down
4 changes: 4 additions & 0 deletions native/candlex/src/lib.rs
@@ -42,6 +42,9 @@ rustler::init! {
         tensors::less,
         tensors::less_equal,
         tensors::all,
+        tensors::all_within_dims,
+        tensors::any,
+        tensors::any_within_dims,
         tensors::sum,
         tensors::dtype,
         tensors::t_shape,
@@ -68,6 +71,7 @@ rustler::init! {
         tensors::permute,
         tensors::slice_scatter,
         tensors::pad_with_zeros,
+        tensors::dot,
         tensors::matmul,
         tensors::abs,
         tensors::acos,
86 changes: 71 additions & 15 deletions native/candlex/src/tensors.rs
@@ -154,22 +154,38 @@ pub fn arange(

 #[rustler::nif(schedule = "DirtyCpu")]
 pub fn all(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
-    let device = ex_tensor.device();
-    let t = ex_tensor.flatten_all()?;
-    let dims = t.shape().dims();
-    let on_true = Tensor::ones(dims, DType::U8, device)?;
-    let on_false = Tensor::zeros(dims, DType::U8, device)?;
-
-    let bool_scalar = match t
-        .where_cond(&on_true, &on_false)?
-        .min(0)?
-        .to_scalar::<u8>()?
-    {
-        0 => 0u8,
-        _ => 1u8,
-    };
-
-    Ok(ExTensor::new(Tensor::new(bool_scalar, device)?))
+    Ok(ExTensor::new(_all(
+        &ex_tensor.flatten_all()?,
+        vec![0],
+        false,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn all_within_dims(
+    ex_tensor: ExTensor,
+    dims: Vec<usize>,
+    keep_dims: bool,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_all(ex_tensor.deref(), dims, keep_dims)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn any(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_any(
+        &ex_tensor.flatten_all()?,
+        vec![0],
+        false,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn any_within_dims(
+    ex_tensor: ExTensor,
+    dims: Vec<usize>,
+    keep_dims: bool,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_any(ex_tensor.deref(), dims, keep_dims)?))
 }
 
 #[rustler::nif(schedule = "DirtyCpu")]
@@ -346,6 +362,14 @@ pub fn divide(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError>
     ))
 }
 
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn dot(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        left.mul(&right.broadcast_as(left.shape())?)?
+            .sum(left.rank() - 1)?,
+    ))
+}
+
 macro_rules! unary_nif {
     ($nif_name:ident, $native_fn_name:ident) => {
         #[rustler::nif(schedule = "DirtyCpu")]
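The NIF expresses the matrix-vector product as a broadcast multiply followed by a sum over the last axis, rather than a dedicated kernel. The same identity at the Nx level (illustrative):

```elixir
m = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
v = Nx.tensor([10.0, 20.0])

# Broadcast v across the rows of m, multiply elementwise, then
# collapse the last axis: same result as Nx.dot(m, v).
m
|> Nx.multiply(v)
|> Nx.sum(axes: [-1])
# => [50.0, 110.0]
```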
@@ -446,6 +470,38 @@ custom_binary_nif!(pow, Pow);
 custom_binary_nif!(right_shift, Shr);
 custom_binary_nif!(remainder, Remainder);
 
+fn _any(tensor: &Tensor, dims: Vec<usize>, keep_dims: bool) -> Result<Tensor, CandlexError> {
+    let comparison = tensor.ne(&tensor.zeros_like()?)?;
+
+    let result = if keep_dims {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.max_keepdim(*dim).unwrap())
+    } else {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.max(*dim).unwrap())
+    };
+
+    Ok(result)
+}
+
+fn _all(tensor: &Tensor, dims: Vec<usize>, keep_dims: bool) -> Result<Tensor, CandlexError> {
+    let comparison = tensor.ne(&tensor.zeros_like()?)?;
+
+    let result = if keep_dims {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.min_keepdim(*dim).unwrap())
+    } else {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.min(*dim).unwrap())
+    };
+
+    Ok(result)
+}
+
 fn tuple_to_vec(term: Term) -> Result<Vec<usize>, rustler::Error> {
     rustler::types::tuple::get_tuple(term)?
         .iter()
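Both helpers reduce truthiness (`x != 0`): `min` over an axis gives `all`, `max` gives `any`. Folding the axes in reverse matters when `keep_dims` is false, since each reduction removes an axis and starting from the highest index keeps the remaining indices valid. An Nx-level sketch of the same fold (illustrative, assuming the axes arrive sorted ascending as Nx provides them):

```elixir
defmodule ReduceSketch do
  # Mirrors the Rust _any helper: non-zero entries become 1, then each
  # requested axis collapses under a max-reduction, highest axis first
  # so lower axis indices stay valid as axes disappear.
  def any_within_dims(t, dims, keep_axes? \\ false) do
    dims
    |> Enum.sort(:desc)
    |> Enum.reduce(Nx.not_equal(t, 0), fn dim, acc ->
      Nx.reduce_max(acc, axes: [dim], keep_axes: keep_axes?)
    end)
  end
end
```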
(The diff for the eighth changed file did not load in this view.)