From 19924366f30681f114b63d70cbcae6f64c9b46be Mon Sep 17 00:00:00 2001
From: Gonzalo <456459+grzuy@users.noreply.github.com>
Date: Wed, 25 Oct 2023 11:15:33 -0300
Subject: [PATCH] feat: initial

---
 .gitignore                                  |    3 +
 config/config.exs                           |   17 +
 lib/candlex.ex                              |   13 -
 lib/candlex/backend.ex                      |  979 +++++++++
 lib/candlex/native.ex                       |  108 +
 mix.exs                                     |    8 +-
 mix.lock                                    |    8 +
 native/candlex/.cargo/config.toml           |    5 +
 native/candlex/.gitignore                   |    1 +
 native/candlex/Cargo.lock                   |  911 ++++++++
 native/candlex/Cargo.toml                   |   24 +
 native/candlex/build.rs                     |  250 +++
 native/candlex/src/devices.rs               |    4 +
 native/candlex/src/error.rs                 |   21 +
 native/candlex/src/kernels.rs               |    4 +
 native/candlex/src/kernels/custom_binary.cu |  111 +
 native/candlex/src/kernels/custom_unary.cu  |  112 +
 native/candlex/src/kernels/strides.cuh      |   34 +
 native/candlex/src/lib.rs                   |  115 +
 native/candlex/src/ops.rs                   |  458 ++++
 native/candlex/src/tensors.rs               |  527 +++++
 test/candlex_test.exs                       | 2192 ++++++++++++++++++-
 test/support/nx_case.ex                     |   45 +
 test/test_helper.exs                        |    1 +
 24 files changed, 5933 insertions(+), 18 deletions(-)
 create mode 100644 config/config.exs
 create mode 100644 lib/candlex/backend.ex
 create mode 100644 lib/candlex/native.ex
 create mode 100644 mix.lock
 create mode 100644 native/candlex/.cargo/config.toml
 create mode 100644 native/candlex/.gitignore
 create mode 100644 native/candlex/Cargo.lock
 create mode 100644 native/candlex/Cargo.toml
 create mode 100644 native/candlex/build.rs
 create mode 100644 native/candlex/src/devices.rs
 create mode 100644 native/candlex/src/error.rs
 create mode 100644 native/candlex/src/kernels.rs
 create mode 100644 native/candlex/src/kernels/custom_binary.cu
 create mode 100644 native/candlex/src/kernels/custom_unary.cu
 create mode 100644 native/candlex/src/kernels/strides.cuh
 create mode 100644 native/candlex/src/lib.rs
 create mode 100644 native/candlex/src/ops.rs
 create mode 100644 native/candlex/src/tensors.rs
 create mode 100644 test/support/nx_case.ex

diff --git a/.gitignore b/.gitignore
index 82bb469..1fb4d17 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,3 +24,6 @@ candlex-*.tar
 
 # Temporary files, for example, from tests.
 /tmp/
+
+# Shared objects built by Rust.
+*.so
diff --git a/config/config.exs b/config/config.exs
new file mode 100644
index 0000000..aecd8ea
--- /dev/null
+++ b/config/config.exs
@@ -0,0 +1,17 @@
+import Config
+
+enable_cuda =
+  case System.get_env("CUDA") do
+    nil -> System.find_executable("nvcc") && System.find_executable("nvidia-smi")
+    "false" -> false
+    _ -> true
+  end
+
+crate_features =
+  if enable_cuda do
+    [:cuda]
+  else
+    []
+  end
+
+config :candlex, crate_features: crate_features
diff --git a/lib/candlex.ex b/lib/candlex.ex
index a569e05..ee5cfce 100644
--- a/lib/candlex.ex
+++ b/lib/candlex.ex
@@ -2,17 +2,4 @@ defmodule Candlex do
   @moduledoc """
   Documentation for `Candlex`.
   """
-
-  @doc """
-  Hello world.
-
-  ## Examples
-
-      iex> Candlex.hello()
-      :world
-
-  """
-  def hello do
-    :world
-  end
 end
diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex
new file mode 100644
index 0000000..9a04f9f
--- /dev/null
+++ b/lib/candlex/backend.ex
@@ -0,0 +1,979 @@
+defmodule Candlex.Backend do
+  @moduledoc """
+  An opaque Nx backend with bindings to candle.
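+
+  A minimal usage sketch (it assumes the Rust NIF has been compiled; any
+  other Nx expression works the same way):
+
+      Nx.tensor([1, 2, 3], backend: Candlex.Backend)
+      |> Nx.multiply(2)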
+ """ + + defstruct [:device, :resource] + + @behaviour Nx.Backend + + alias Nx.Tensor, as: T + alias Candlex.Native + + @device_cuda :cuda + @device_cpu :cpu + + @impl true + def init(opts) do + Keyword.validate!(opts, [:device]) + end + + # Creation + + @impl true + def constant(%T{} = tensor, scalar, backend_options) do + tensor + |> Nx.BinaryBackend.constant(scalar, []) + |> Nx.BinaryBackend.backend_transfer(__MODULE__, backend_options) + end + + @impl true + def from_binary(%T{shape: shape, type: type} = tensor, binary, backend_options) do + binary + |> Native.from_binary(to_candle_dtype(type), shape, device_option(backend_options)) + |> unwrap!() + |> to_nx(tensor) + end + + @impl true + def iota(%T{shape: {}} = out, nil, backend_options) do + constant(out, 0, backend_options) + end + + def iota(%T{shape: shape, type: type} = out, nil, backend_options) do + Native.arange(0, Nx.size(shape), to_candle_dtype(type), shape, device_option(backend_options)) + |> unwrap!() + |> to_nx(out) + end + + def iota(%T{shape: shape, type: type} = out, axis, backend_options) do + # Build in one dimension, then broadcast + axis_size = elem(shape, axis) + + Native.arange( + 0, + axis_size, + to_candle_dtype(type), + Tuple.duplicate(1, Nx.rank(shape)) |> put_elem(axis, axis_size), + device_option(backend_options) + ) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def eye(%T{shape: shape, type: type} = _out, backend_options) do + iota = Nx.iota(shape, backend: {__MODULE__, backend_options}) + + Nx.equal(Nx.tril(iota), Nx.triu(iota)) + |> Nx.as_type(type) + end + + # Backend + + @impl true + def backend_transfer(tensor, backend, backend_options) do + if backend == __MODULE__ && same_device?(tensor, device_option(backend_options)) do + tensor + else + try do + backend_copy(tensor, backend, backend_options) + after + backend_deallocate(tensor) + end + end + end + + @impl true + def backend_copy(%T{} = tensor, Candlex.Backend, backend_options) do + tensor + |> from_nx() + |> Native.to_device(device_option(backend_options)) + |> unwrap!() + |> to_nx(tensor) + end + + def backend_copy(%T{} = tensor, backend, backend_options) do + backend.from_binary(tensor, to_binary(tensor), backend_options) + end + + @impl true + def backend_deallocate(%T{} = _tensor) do + true + end + + # Conversion + + @impl true + def to_binary(tensor, _limit \\ nil) do + # TODO: don't ignore limit + + from_nx(tensor) + |> Native.to_binary() + |> unwrap!() + end + + # Aggregates + + @impl true + def all(%T{} = out, %T{} = tensor, _opts) do + from_nx(tensor) + |> Native.all() + |> unwrap!() + |> to_nx(out) + end + + @impl true + def sum(%T{type: out_type} = out, %T{} = t, opts) do + axes = opts[:axes] || Nx.axes(t) + keep_axes = opts[:keep_axes] || false + + t + |> from_nx() + |> Native.sum(axes, keep_axes) + |> unwrap!() + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + |> to_nx(out) + end + + for op <- [:argmax, :argmin] do + @impl true + def unquote(op)(%T{} = out, %T{shape: {}} = _tensor, _opts) do + out + |> constant(0, []) + end + + def unquote(op)(%T{type: type} = out, %T{} = tensor, opts) do + axis = opts[:axis] || -1 + keep_axis = opts[:keep_axis] || false + + tensor + |> from_nx() + |> Native.unquote(op)(axis, keep_axis) + |> unwrap!() + # candle argmax/argmin changes to u32 + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end + end + + @impl true + def reduce_max(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> 
from_binary(to_binary(tensor), []) + end + + def reduce_max(%T{} = out, %T{} = tensor, opts) do + axis = + case opts[:axes] do + nil -> 0 + [] -> 0 + [axis] -> axis + axes -> raise "doesn't support axes option with more than 1 axis, '#{inspect(axes)}'" + end + + keep_axis = opts[:keep_axes] || false + + tensor + |> from_nx() + |> Native.reduce_max(axis, keep_axis) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def reduce_min(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> from_binary(to_binary(tensor), []) + end + + def reduce_min(%T{} = out, %T{} = tensor, opts) do + axis = + case opts[:axes] do + nil -> 0 + [] -> 0 + [axis] -> axis + axes -> raise "doesn't support axes option with more than 1 axis, '#{inspect(axes)}'" + end + + keep_axis = opts[:keep_axes] || false + + tensor + |> from_nx() + |> Native.reduce_min(axis, keep_axis) + |> unwrap!() + |> to_nx(out) + end + + # Element-wise + + @impl true + def clip(%T{} = out, %T{} = t, %T{} = min, %T{} = max) do + [t, min, max] = maybe_upcast([t, min, max]) + + t + |> from_nx() + |> Native.clamp(from_nx(min), from_nx(max)) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def select(%T{shape: shape, type: type} = out, pred, on_true, on_false) do + on_true = + on_true + |> from_nx() + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + + on_false = + on_false + |> from_nx() + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + + pred + |> from_nx() + |> Native.where_cond(on_true, on_false) + |> unwrap!() + |> to_nx(out) + end + + # Binary ops + + for op <- [:add, :divide, :max, :min, :multiply, :subtract] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_transfer_device(left, right) + {left, right} = maybe_upcast(left, right) + + from_nx(left) + |> Native.unquote(op)(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + end + + for op <- [:atan2, :pow, :quotient, :remainder] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_upcast(left, right) + {left, right} = maybe_broadcast_bin_args(out.shape, left, right) + + left + |> Native.unquote(op)(right) + |> unwrap!() + |> to_nx(out) + end + end + + for op <- [ + :bitwise_and, + :bitwise_or, + :bitwise_xor, + :equal, + :greater, + :greater_equal, + :left_shift, + :less, + :less_equal, + :logical_and, + :logical_or, + :logical_xor, + :not_equal, + :right_shift + ] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_transfer_device(left, right) + {left, right} = maybe_upcast(left, right) + {left, right} = maybe_broadcast_bin_args(out.shape, left, right) + + left + |> Native.unquote(op)(right) + |> unwrap!() + # TODO: Do this conditionally or as part of native op + |> Native.to_type(to_candle_dtype(out.type)) + |> unwrap!() + |> to_nx(out) + end + end + + # Unary ops + + for op <- [ + :abs, + :acos, + :acosh, + :asin, + :asinh, + :atan, + :atanh, + :bitwise_not, + :cbrt, + :ceil, + :cos, + :cosh, + :erf, + :erfc, + :erf_inv, + :exp, + :expm1, + :floor, + :is_infinity, + :is_nan, + :log, + :log1p, + :negate, + :round, + :rsqrt, + :sigmoid, + :sign, + :sin, + :sinh, + :sqrt, + :tan, + :tanh + ] do + @impl true + def unquote(op)(%T{} = out, %T{} = tensor) do + tensor + |> from_nx() + |> Native.unquote(op)() + |> unwrap!() + |> to_nx(out) + end + end + + # Indexed + + @impl true + def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = 
indices) do
+    tensor
+    |> from_nx()
+    |> Native.gather(from_nx(Nx.flatten(indices)), 0)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  def gather(%T{} = _out, %T{} = _tensor, %T{} = _indices) do
+    raise("unsupported gather for tensor of rank greater than 1")
+  end
+
+  @impl true
+  def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates) do
+    {tensor, updates} = maybe_upcast(tensor, updates)
+
+    tensor
+    |> from_nx()
+    |> Native.index_add(from_nx(Nx.flatten(indices)), from_nx(updates), 0)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  def indexed_add(%T{} = _out, %T{} = _tensor, %T{} = _indices, %T{} = _updates) do
+    raise("unsupported indexed_add for tensor of rank greater than 1")
+  end
+
+  @impl true
+  def put_slice(%T{} = out, %T{} = t, [_ | _] = start_indices, slice) do
+    [last_start_index | leading_start_indices] = Enum.reverse(start_indices)
+
+    if Enum.all?(leading_start_indices, fn i -> Nx.to_number(Nx.equal(i, 0)) == 1 end) do
+      t
+      |> from_nx()
+      |> Native.slice_scatter(
+        from_nx(slice),
+        length(start_indices) - 1,
+        Nx.to_number(last_start_index)
+      )
+      |> unwrap!()
+      |> to_nx(out)
+    else
+      raise "for now, put_slice requires every start index except the last one to be 0"
+    end
+  end
+
+  @impl true
+  def slice(
+        %T{shape: _output_shape} = out,
+        %T{shape: input_shape} = t,
+        starts,
+        lengths,
+        _strides
+      ) do
+    t
+    |> from_nx()
+    |> narrow(starts, lengths, 0, input_shape)
+    # TODO: Support strides
+    # |> stride(output_shape, lengths, strides)
+    |> to_nx(out)
+  end
+
+  @impl true
+  def take(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do
+    if Nx.rank(indexes) > 1 do
+      raise "only indexes of rank=1 supported for now"
+    end
+
+    tensor
+    |> from_nx()
+    |> Native.index_select(from_nx(indexes), axis)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  @impl true
+  def take_along_axis(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do
+    tensor
+    |> from_nx()
+    |> Native.gather(from_nx(indexes), axis)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  # N-dim
+
+  @impl true
+  def concatenate(%T{} = out, tensors, axis) do
+    tensors
+    |> maybe_upcast()
+    |> Enum.map(&from_nx/1)
+    |> Native.concatenate(axis)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  @impl true
+  def conv(%T{type: out_type} = out, %T{shape: shape} = tensor, %T{} = kernel, opts) do
+    # TODO: Support more opts
+    unsupported_option!(opts, :batch_group_size, 1)
+    unsupported_option!(opts, :feature_group_size, 1)
+
+    # For now we assume:
+    # strides = opts[:strides] # [1, 1]
+    # padding = opts[:padding] # [{0, 0}, {0, 0}]
+    # input_dilation = opts[:input_dilation] # [1, 1]
+    # kernel_dilation = opts[:kernel_dilation] # [1, 1]
+
+    input_permutation = opts[:input_permutation]
+    kernel_permutation = opts[:kernel_permutation]
+
+    output_permutation =
+      case opts[:output_permutation] do
+        nil ->
+          nil
+
+        l ->
+          # The permutation that Nx.Shape expects is actually the reverse permutation
+          # for the given input
+          l |> Enum.with_index() |> Enum.sort() |> Enum.map(&elem(&1, 1))
+      end
+
+    native_tensor =
+      tensor
+      |> from_nx()
+      |> permute(input_permutation)
+      |> Native.to_type(to_candle_dtype(out_type))
+      |> unwrap!()
+
+    native_kernel =
+      kernel
+      |> from_nx()
+      |> permute(kernel_permutation)
+      |> Native.to_type(to_candle_dtype(out_type))
+      |> unwrap!()
+
+    native_result =
+      case Nx.rank(shape) do
+        3 -> Native.conv1d(native_tensor, native_kernel)
+        4 -> Native.conv2d(native_tensor, native_kernel)
+        rank -> raise("unsupported conv for tensor of rank #{rank}, only 3 or 4 supported")
+      end
+
+    native_result
+    |> unwrap!()
+    |> permute(output_permutation)
+    |> to_nx(out)
+  end
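+
+  # candle's matmul covers the plain 2-D case: the dot/7 clauses below only
+  # transpose 2-D inputs so that the contraction always happens as axes
+  # [1] x [0] before delegating to Native.matmul/2.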
+  @impl true
+  def dot(
+        %T{type: _out_type} = out,
+        %T{shape: left_shape, type: _left_type} = left,
+        [1] = _left_axes,
+        [] = _left_batched_axes,
+        %T{shape: right_shape, type: _right_type} = right,
+        [0] = _right_axes,
+        [] = _right_batched_axes
+      )
+      when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do
+    Native.matmul(
+      from_nx(left),
+      from_nx(right)
+    )
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  def dot(
+        out,
+        %T{shape: left_shape} = left,
+        [0],
+        left_batched_axes,
+        right,
+        right_axes,
+        right_batched_axes
+      )
+      when tuple_size(left_shape) == 2 do
+    dot(
+      out,
+      left |> Nx.transpose(axes: [1, 0]),
+      [1],
+      left_batched_axes,
+      right,
+      right_axes,
+      right_batched_axes
+    )
+  end
+
+  def dot(
+        out,
+        left,
+        left_axes,
+        left_batched_axes,
+        %T{shape: right_shape} = right,
+        [1],
+        right_batched_axes
+      )
+      when tuple_size(right_shape) == 2 do
+    dot(
+      out,
+      left,
+      left_axes,
+      left_batched_axes,
+      right |> Nx.transpose(axes: [1, 0]),
+      [0],
+      right_batched_axes
+    )
+  end
+
+  # Shape
+
+  @impl true
+  def broadcast(out, %T{} = t, shape, axes) do
+    t
+    |> maybe_reshape(shape, axes)
+    |> from_nx()
+    |> Native.broadcast_to(shape)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  @impl true
+  def pad(%T{} = out, %T{} = _t, _pad_value, []) do
+    out
+  end
+
+  def pad(%T{} = out, %T{} = t, %T{shape: {}} = pad_value, [{low, high, 0 = _inner}]) do
+    if Nx.to_number(Nx.equal(pad_value, 0)) != 1 do
+      raise "only pad_value=0 supported for now"
+    end
+
+    t
+    |> from_nx()
+    |> Native.pad_with_zeros(low, high)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  @impl true
+  def reshape(%T{shape: shape} = out, %T{} = t) do
+    from_nx(t)
+    |> Native.reshape(shape)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  @impl true
+  def squeeze(%T{} = out, %T{} = t, axes) do
+    # sort the axes desc so we don't have to decrease the axis numbers after each squeeze
+    for axis <- Enum.sort(axes, :desc), reduce: from_nx(t) do
+      ref ->
+        ref
+        |> Native.squeeze(axis)
+        |> unwrap!()
+    end
+    |> to_nx(out)
+  end
+
+  @impl true
+  def transpose(out, %T{} = t, axes) do
+    from_nx(t)
+    |> Native.permute(axes)
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  # Type
+
+  @impl true
+  def as_type(%T{type: type} = out, %T{} = t) do
+    from_nx(t)
+    |> Native.to_type(to_candle_dtype(type))
+    |> unwrap!()
+    |> to_nx(out)
+  end
+
+  @impl true
+  def bitcast(out, tensor) do
+    out
+    |> from_binary(to_binary(tensor), [])
+  end
+
+  # Inspect
+
+  @impl true
+  def inspect(%T{} = tensor, inspect_opts) do
+    limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1
+
+    tensor
+    |> to_binary(min(limit, Nx.size(tensor)))
+    |> then(&Nx.Backend.inspect(tensor, &1, inspect_opts))
+    |> maybe_add_signature(tensor)
+  end
+
+  defp maybe_add_signature(result, %T{data: %__MODULE__{device: device, resource: ref}})
+       when is_reference(ref) do
+    Inspect.Algebra.concat([
+      "Candlex.Backend(#{device})",
+      Inspect.Algebra.line(),
+      result
+    ])
+  end
+
+  defp narrow(t, [start | starts], [length | lengths], axis, shape) do
+    dim = elem(shape, axis)
+    start = min(start, dim - length)
+
+    if start == 0 and length == dim do
+      # Nothing to narrow at this step
+      t
+    else
+      t
+      |> Native.narrow(axis, start, length)
+      |> unwrap!()
+    end
+    |> narrow(starts, lengths, axis + 1, shape)
+  end
+
+  defp narrow(t, [], [], _axis, _shape), do: t
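+
+  # maybe_reshape/3 prepares a tensor for Native.broadcast_to/2: broadcasting
+  # {2} into {3, 2} with axes [1], for instance, first reshapes to {1, 2}, so
+  # the native broadcast only ever has to expand size-1 dimensions.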
+  defp maybe_reshape(%T{shape: {}} = t, target_shape, _axes) do
+    shape =
+      1
+      |> List.duplicate(tuple_size(target_shape))
+      |> List.to_tuple()
+
+    t
+    |> Nx.reshape(shape)
+  end
+
+  defp maybe_reshape(%T{shape: shape} = t, target_shape, axes) do
+    base_broadcast_shape = 1 |> List.duplicate(tuple_size(target_shape)) |> List.to_tuple()
+
+    new_shape =
+      shape
+      |> Tuple.to_list()
+      |> Enum.zip(axes)
+      |> Enum.reduce(base_broadcast_shape, fn {dim_size, target_axis}, shape_acc ->
+        shape_acc
+        |> Tuple.delete_at(target_axis)
+        |> Tuple.insert_at(target_axis, dim_size)
+      end)
+
+    t
+    |> Nx.reshape(new_shape)
+  end
+
+  defp maybe_upcast(%T{type: t} = left, %T{type: t} = right) do
+    {left, right}
+  end
+
+  defp maybe_upcast(left, right) do
+    type = Nx.Type.merge(left.type, right.type)
+
+    {Nx.as_type(left, type), Nx.as_type(right, type)}
+  end
+
+  defp maybe_upcast([first | _] = tensors) do
+    type =
+      tensors
+      |> Enum.reduce(
+        first.type,
+        fn tensor, type ->
+          Nx.Type.merge(type, tensor.type)
+        end
+      )
+
+    tensors
+    |> Enum.map(fn tensor ->
+      Nx.as_type(tensor, type)
+    end)
+  end
+
+  defp maybe_broadcast_bin_args(out_shape, l, r) do
+    {
+      case l.shape do
+        ^out_shape ->
+          from_nx(l)
+
+        _ ->
+          l |> from_nx() |> Native.broadcast_to(out_shape) |> unwrap!()
+      end,
+      case r.shape do
+        ^out_shape -> from_nx(r)
+        _ -> r |> from_nx() |> Native.broadcast_to(out_shape) |> unwrap!()
+      end
+    }
+  end
+
+  defp maybe_transfer_device(
+         %T{data: %__MODULE__{device: device}} = l,
+         %T{data: %__MODULE__{device: device}} = r
+       ) do
+    {l, r}
+  end
+
+  defp maybe_transfer_device(
+         %T{data: %__MODULE__{device: device}} = l,
+         %T{data: %__MODULE__{device: _other_device}} = r
+       ) do
+    {
+      l,
+      r |> Nx.backend_transfer({__MODULE__, device: device})
+    }
+  end
+
+  defp maybe_transfer_device(%T{} = l, %T{data: %__MODULE__{device: device}} = r) do
+    {
+      l |> Nx.backend_transfer({__MODULE__, device: device}),
+      r
+    }
+  end
+
+  defp maybe_transfer_device(%T{data: %__MODULE__{device: device}} = l, %T{} = r) do
+    {
+      l,
+      r |> Nx.backend_transfer({__MODULE__, device: device})
+    }
+  end
+
+  ## Conversions
+
+  @impl true
+  def to_batched(%T{shape: out_shape} = out, %T{shape: shape} = t, opts) do
+    leftover = opts[:leftover]
+    first_dimension = 0
+    batch_size = elem(out_shape, first_dimension)
+    axis_total = elem(shape, first_dimension)
+    remainder = rem(axis_total, batch_size)
+    num_batches = div(axis_total, batch_size)
+    native_tensor = from_nx(t)
+
+    cond do
+      remainder == 0 ->
+        native_tensor
+        |> Native.chunk(num_batches)
+        |> unwrap!()
+
+      remainder > 0 && leftover == :repeat ->
+        [
+          native_tensor,
+          Native.narrow(native_tensor, first_dimension, 0, batch_size - remainder)
+          |> unwrap!()
+        ]
+        |> Native.concatenate(first_dimension)
+        |> unwrap!()
+        |> Native.chunk(num_batches + 1)
+        |> unwrap!()
+
+      true ->
+        raise "not implemented"
+    end
+    |> Stream.map(&to_nx(&1, out))
+  end
+
+  for op <- [
+        :cholesky,
+        :conjugate,
+        :count_leading_zeros,
+        :imag,
+        :population_count,
+        :real
+      ] do
+    @impl true
+    def unquote(op)(_out, _tensor) do
+      raise "unsupported Candlex.Backend.#{unquote(op)} function"
+    end
+  end
+
+  for op <- [
+        :any,
+        :argsort,
+        :eigh,
+        :fft,
+        :ifft,
+        :lu,
+        :product,
+        :qr,
+        :reverse,
+        :sort
+      ] do
+    @impl true
+    def unquote(op)(_out, _tensor, _) do
+      raise "unsupported Candlex.Backend.#{unquote(op)} function"
+    end
+  end
+
+  for op <- [
+        :indexed_put,
+        :map,
+        :triangular_solve,
+        :window_max,
+        :window_min,
+        :window_product,
+        :window_sum
+      ] do
+    @impl true
+    def unquote(op)(_out, _tensor, _, _) do
+      raise "unsupported Candlex.Backend.#{unquote(op)} function"
+    end
+  end
+
+  @impl true
+  def reduce(_out, _tensor, _, _, _) do
+    raise "unsupported Candlex.Backend.reduce function"
+  end
+
+  for op <- [
+        :window_reduce,
+        :window_scatter_max,
+        :window_scatter_min
+      ] do
+    @impl true
+    def unquote(op)(_out, _tensor, _, _, _, _) do
+      raise "unsupported Candlex.Backend.#{unquote(op)} function"
+    end
+  end
+
+  defp permute(native_tensor, permutation) do
+    native_tensor
+    |> Native.permute(permutation)
+    |> unwrap!()
+  end
+
+  defp from_nx(%T{data: %__MODULE__{} = data}), do: data
+
+  defp from_nx(%T{} = tensor) do
+    tensor
+    |> Nx.backend_transfer(__MODULE__)
+    |> from_nx()
+  end
+
+  defp to_nx(%__MODULE__{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} = t)
+       when is_reference(ref) do
+    {:ok, candle_dtype} = Native.dtype(backend_tensor)
+    {:ok, candle_shape} = Native.t_shape(backend_tensor)
+
+    case {nx_type, from_candle_dtype(candle_dtype)} do
+      # u64 tensors are stored as i64 on the candle side
+      {{:u, 64}, {:s, 64}} ->
+        :ok
+
+      {type, type} ->
+        :ok
+
+      {type, other_type} ->
+        raise "tensor type mismatch, Nx (#{inspect(type)}) and Candle (#{inspect(other_type)})"
+    end
+
+    if nx_shape != candle_shape do
+      raise "tensor shape mismatch, Nx (#{inspect(nx_shape)}) and Candle (#{inspect(candle_shape)})"
+    end
+
+    %{t | data: backend_tensor}
+  end
+
+  defp to_candle_dtype({:s, 8} = t), do: unsupported_dtype(t)
+  defp to_candle_dtype({:s, 16} = t), do: unsupported_dtype(t)
+  defp to_candle_dtype({:s, 32} = t), do: unsupported_dtype(t)
+  defp to_candle_dtype({:s, 64}), do: "i64"
+  defp to_candle_dtype({:u, 8}), do: "u8"
+  defp to_candle_dtype({:u, 16} = t), do: unsupported_dtype(t)
+  defp to_candle_dtype({:u, 32}), do: "u32"
+  defp to_candle_dtype({:u, 64}), do: "i64"
+  defp to_candle_dtype({:f, 16}), do: "f16"
+  defp to_candle_dtype({:f, 32}), do: "f32"
+  defp to_candle_dtype({:f, 64}), do: "f64"
+  defp to_candle_dtype({:bf, 16}), do: "bf16"
+  defp to_candle_dtype({:c, 64} = t), do: unsupported_dtype(t)
+  defp to_candle_dtype({:c, 128} = t), do: unsupported_dtype(t)
+
+  defp from_candle_dtype("i64"), do: {:s, 64}
+  defp from_candle_dtype("u8"), do: {:u, 8}
+  defp from_candle_dtype("u32"), do: {:u, 32}
+  defp from_candle_dtype("f16"), do: {:f, 16}
+  defp from_candle_dtype("bf16"), do: {:bf, 16}
+  defp from_candle_dtype("f32"), do: {:f, 32}
+  defp from_candle_dtype("f64"), do: {:f, 64}
+
+  defp device_option(nil) do
+    default_device()
+  end
+
+  defp device_option(backend_options) do
+    backend_options[:device] || default_device()
+  end
+
+  defp default_device do
+    if cuda_available?() do
+      @device_cuda
+    else
+      @device_cpu
+    end
+  end
+
+  defp same_device?(%T{data: %__MODULE__{device: device}}, device) do
+    true
+  end
+
+  defp same_device?(_t, _d) do
+    false
+  end
+
+  def cuda_available? 
do + Native.is_cuda_available() + end + + defp unsupported_dtype(t) do + raise("Unsupported candle dtype for #{inspect(t)}") + end + + defp unsupported_option!(opts, key, acceptable_default) do + if opts[key] != nil and opts[key] != acceptable_default do + raise "#{inspect(key)} option with #{inspect(opts[key])} is not supported" + end + end + + defp unwrap!({:ok, result}), do: result + defp unwrap!({:error, error}), do: raise("Candlex: #{error}") +end diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex new file mode 100644 index 0000000..4ccc8d3 --- /dev/null +++ b/lib/candlex/native.ex @@ -0,0 +1,108 @@ +defmodule Candlex.Native do + @moduledoc false + + use Rustler, otp_app: :candlex, features: Application.compile_env(:candlex, :crate_features, []) + + # Rustler will override all the below stub functions with real NIFs + def from_binary(_binary, _dtype, _shape, _device), do: error() + def to_binary(_tensor), do: error() + def all(_tensor), do: error() + def where_cond(_tensor, _on_true, _on_false), do: error() + def narrow(_tensor, _dim, _start, _length), do: error() + def gather(_tensor, _indexes, _dim), do: error() + def index_select(_tensor, _indexes, _dim), do: error() + def index_add(_tensor, _indexes, _source, _dim), do: error() + def chunk(_tensor, _num_chunks), do: error() + def squeeze(_tensor, _dim), do: error() + def arange(_start, _end, _dtype, _shape, _device), do: error() + def broadcast_to(_tensor, _shape), do: error() + def reshape(_tensor, _shape), do: error() + def to_type(_tensor, _dtype), do: error() + def dtype(_tensor), do: error() + def t_shape(_tensor), do: error() + def concatenate(_tensors, _axis), do: error() + def conv1d(_tensor, _kernel), do: error() + def conv2d(_tensor, _kernel), do: error() + def slice_scatter(_tensor, _src, _dim, _start), do: error() + def pad_with_zeros(_tensor, _left, _right), do: error() + def clamp(_tensor, _min, _max), do: error() + + for op <- [ + :abs, + :acos, + :acosh, + :asin, + :asinh, + :atan, + :atanh, + :bitwise_not, + :cbrt, + :ceil, + :cos, + :cosh, + :erf, + :erfc, + :erf_inv, + :exp, + :expm1, + :floor, + :is_infinity, + :is_nan, + :log, + :log1p, + :negate, + :round, + :rsqrt, + :sigmoid, + :sign, + :sin, + :sinh, + :sqrt, + :tan, + :tanh + ] do + def unquote(op)(_tensor), do: error() + end + + for op <- [ + :add, + :atan2, + :bitwise_and, + :bitwise_or, + :bitwise_xor, + :divide, + :equal, + :greater, + :greater_equal, + :left_shift, + :less, + :less_equal, + :logical_and, + :logical_or, + :logical_xor, + :matmul, + :max, + :min, + :multiply, + :not_equal, + :pow, + :quotient, + :remainder, + :right_shift, + :subtract + ] do + def unquote(op)(_left, _right), do: error() + end + + def sum(_tensor, _dims, _keep_dims), do: error() + def permute(_tensor, _dims), do: error() + + for op <- [:argmax, :argmin, :reduce_max, :reduce_min] do + def unquote(op)(_tensor, _dim, _keep_dim), do: error() + end + + def is_cuda_available(), do: error() + def to_device(_tensor, _device), do: error() + + defp error(), do: :erlang.nif_error(:nif_not_loaded) +end diff --git a/mix.exs b/mix.exs index 59d4d2e..374d923 100644 --- a/mix.exs +++ b/mix.exs @@ -6,6 +6,7 @@ defmodule Candlex.MixProject do app: :candlex, version: "0.1.0", elixir: "~> 1.15", + elixirc_paths: elixirc_paths(Mix.env()), start_permanent: Mix.env() == :prod, deps: deps() ] @@ -18,11 +19,14 @@ defmodule Candlex.MixProject do ] end + defp elixirc_paths(:test), do: ["lib", "test/support"] + defp elixirc_paths(_), do: ["lib"] + # Run "mix help deps" to learn about 
dependencies. defp deps do [ - # {:dep_from_hexpm, "~> 0.3.0"}, - # {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"} + {:nx, "~> 0.6.2"}, + {:rustler, "~> 0.29.1"} ] end end diff --git a/mix.lock b/mix.lock new file mode 100644 index 0000000..5d69de1 --- /dev/null +++ b/mix.lock @@ -0,0 +1,8 @@ +%{ + "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "nx": {:hex, :nx, "0.6.2", "f1d137f477b1a6f84f8db638f7a6d5a0f8266caea63c9918aa4583db38ebe1d6", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "ac913b68d53f25f6eb39bddcf2d2cd6ea2e9bcb6f25cf86a79e35d0411ba96ad"}, + "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, + "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, +} diff --git a/native/candlex/.cargo/config.toml b/native/candlex/.cargo/config.toml new file mode 100644 index 0000000..20f03f3 --- /dev/null +++ b/native/candlex/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.'cfg(target_os = "macos")'] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] diff --git a/native/candlex/.gitignore b/native/candlex/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/native/candlex/.gitignore @@ -0,0 +1 @@ +/target diff --git a/native/candlex/Cargo.lock b/native/candlex/Cargo.lock new file mode 100644 index 0000000..78a1854 --- /dev/null +++ b/native/candlex/Cargo.lock @@ -0,0 +1,911 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "candle-core" +version = "0.3.0" +source = "git+https://github.com/huggingface/candle#8f7973958c55324a24f0c514e7ac6ded6681980f" +dependencies = [ + "byteorder", + "candle-gemm", + "candle-kernels", + "cudarc", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-gemm" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef9b07a4b0ba1a304b44432006580980ddff9748c201261c279437e7b11bba68" +dependencies = [ + "candle-gemm-c32", + "candle-gemm-c64", + "candle-gemm-common", + "candle-gemm-f16", + "candle-gemm-f32", + "candle-gemm-f64", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-c32" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f595241dad99811de285e029889f57c29dd98e33de7a8a6b881867b1488d7d4a" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-c64" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "648f22fd8f5a4f330e29d791845b514966421308a6a2b5fedb949ee07e54c77f" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-common" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e03c01b4ca3b9d71e4eb89e42946a08f8b0d2f1b861f7fa2ea0966233f1e0b08" +dependencies = [ + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f16" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"97f8af2a482131713d28a337abff6debf26c529afa1837caf2ba190909b2107c" +dependencies = [ + "candle-gemm-common", + "candle-gemm-f32", + "dyn-stack", + "half", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f32" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938927961e2f0c0a6064fcf3524ea3f7f455fe5708419532a6fea9aea1ab45ae" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f64" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d192d7126e59b81ef4cf13cd9f194e6dbdc09171f65d0074d059dc009ac06775" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-kernels" +version = "0.3.0" +source = "git+https://github.com/huggingface/candle#8f7973958c55324a24f0c514e7ac6ded6681980f" +dependencies = [ + "glob", + "rayon", +] + +[[package]] +name = "candlex" +version = "0.1.0" +dependencies = [ + "anyhow", + "candle-core", + "half", + "num-traits", + "rustler", + "statrs", + "thiserror", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "cudarc" +version = "0.9.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc4cab390f4a32340211f015292a4551742a63e528e9ade9e0bde0d1a989d2a1" +dependencies = [ + "half", +] + +[[package]] +name = "dyn-stack" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fe7f8d7bcc523381d3c437b82cf74805de3931de0da69309ae0fe1bdf7a256e" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "getrandom" +version = "0.2.10" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.148" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" + +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", + "stable_deref_trait", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "nalgebra" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "rand_distr", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum 
= "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.67" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + +[[package]] +name = "regex" +version = "1.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + +[[package]] +name = "rustler" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0884cb623b9f43d3e2c51f9071c5e96a5acf3e6e6007866812884ff0cb983f1e" +dependencies = [ + "lazy_static", + "rustler_codegen", + "rustler_sys", +] + +[[package]] +name = "rustler_codegen" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50e277af754f2560cf4c4ebedb68c1a735292fb354505c6133e47ec406e699cf" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "rustler_sys" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7c0740e5322b64e2b952d8f0edce5f90fcf6f6fe74cca3f6e78eb3de5ea858" +dependencies = [ + "regex", + "unreachable", +] + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "safe_arch" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.188" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.188" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "serde_json" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "simba" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "statrs" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d08e5e1748192713cc281da8b16924fb46be7b0c2431854eadc785823e5696e" +dependencies = [ + "approx", + "lazy_static", + "nalgebra", + "num-traits", + "rand", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "unicode-xid", +] + +[[package]] +name = "thiserror" +version = "1.0.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.49" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wide" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebecebefc38ff1860b4bc47550bbfa63af5746061cf0d29fcd7fa63171602598" +dependencies = [ + "bytemuck", + "safe_arch", +] + +[[package]] +name = "yoke" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e38c508604d6bbbd292dadb3c02559aa7fff6b654a078a36217cad871636e4" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5e19fb6ed40002bab5403ffa37e53e0e56f914a4450c8765f533018db1db35f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655b0814c5c0b19ade497851070c640773304939a6c0fd5f5fb43da0696d05b7" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "synstructure", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", +] diff --git a/native/candlex/Cargo.toml b/native/candlex/Cargo.toml new file mode 100644 index 0000000..79f7cc2 --- /dev/null +++ b/native/candlex/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "candlex" +version = "0.1.0" +authors = [] +edition = "2021" + +[lib] +name = "candlex" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +candle-core = { git = "https://github.com/huggingface/candle" } +half = "2.3.1" +num-traits = "0.2.16" +rustler = "0.29.1" +statrs = "0.16.0" +thiserror = "1.0.47" + +[build-dependencies] +anyhow = "1.0.75" + +[features] +cuda = ["candle-core/cuda"] diff --git a/native/candlex/build.rs b/native/candlex/build.rs new file mode 100644 index 0000000..86a9c61 --- /dev/null +++ b/native/candlex/build.rs @@ -0,0 +1,250 @@ +#![allow(unused)] +use anyhow::{Context, Result}; +use std::io::Write; +use std::path::PathBuf; + +struct KernelDirectories { + kernel_dir: &'static str, + rust_target: &'static str, + include_dirs: &'static [&'static str], +} + +const DIRS: [KernelDirectories; 1] = [KernelDirectories { + kernel_dir: "src/kernels/", + rust_target: "src/kernels.rs", + include_dirs: &[], +}]; + +impl KernelDirectories { + fn maybe_build_ptx( + &self, + cu_file: &std::path::Path, + ptx_file: &std::path::Path, + compute_cap: usize, + ) -> Result<()> { + let should_compile = if ptx_file.exists() { + cu_file + .metadata()? + .modified()? + .duration_since(ptx_file.metadata()?.modified()?) 
                .is_ok()
        } else {
            true
        };

        // should_compile is true when no .ptx exists yet, or when the .cu
        // source's mtime is at least as recent as the .ptx's (duration_since
        // returns Ok in that case), i.e. the source changed since the last build.
        if should_compile {
            #[cfg(feature = "cuda")]
            {
                let mut command = std::process::Command::new("nvcc");
                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
                let include_dirs: Vec<String> =
                    self.include_dirs.iter().map(|c| format!("-I{c}")).collect();

                command
                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
                    .arg("--ptx")
                    .args(["--default-stream", "per-thread"])
                    .args(["--output-directory", out_dir.to_str().unwrap()])
                    .arg(format!("-I/{}", self.kernel_dir))
                    .args(include_dirs)
                    .arg(cu_file);

                let output = command
                    .spawn()
                    .context("failed spawning nvcc")?
                    .wait_with_output()?;

                if !output.status.success() {
                    anyhow::bail!(
                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
                        String::from_utf8_lossy(&output.stdout),
                        String::from_utf8_lossy(&output.stderr)
                    )
                }
            }

            #[cfg(not(feature = "cuda"))]
            std::fs::OpenOptions::new()
                .create(true)
                .write(true)
                .open(ptx_file)?;
        }

        Ok(())
    }

    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
        println!("cargo:rerun-if-changed={}", self.kernel_dir);
        let kernel_dir = PathBuf::from(self.kernel_dir);
        let out_dir = out_dir.join(self.kernel_dir);
        if !out_dir.exists() {
            std::fs::create_dir_all(&out_dir)?;
        }
        let mut cu_files = vec![];
        let mut cuh_files = vec![];
        for file in std::fs::read_dir(kernel_dir)?.flatten() {
            let file = file.path();
            match file.extension().and_then(|v| v.to_str()) {
                Some("cu") => cu_files.push(file),
                Some("cuh") => cuh_files.push(file),
                _ => {}
            }
        }

        let mut ptx_paths = vec![];
        for cu_file in cu_files.iter() {
            let file_stem = cu_file
                .file_stem()
                .with_context(|| format!("no stem {cu_file:?}"))?;
            let file_stem = file_stem.to_string_lossy().into_owned();
            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
            ptx_paths.push(ptx_file);
        }

        let regenerate_rs_file = true;
        if regenerate_rs_file {
            let mut file = std::fs::File::create(self.rust_target)?;
            for ptx_path in ptx_paths {
                let name = ptx_path
                    .file_stem()
                    .context("empty stem")?
                    .to_string_lossy();
                file.write_all(b"#[rustfmt::skip]\n")?;
                let const_definition = format!(
                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
                    name.to_uppercase().replace('.', "_"),
                    self.kernel_dir,
                );
                file.write_all(const_definition.as_bytes())?;
                file.write_all(b"\n")?;
            }
        }

        Ok(())
    }
}

fn main() -> Result<()> {
    println!("cargo:rerun-if-changed=build.rs");

    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
    let out_dir = PathBuf::from(out_dir);
    #[cfg(feature = "cuda")]
    set_cuda_include_dir()?;
    #[cfg(feature = "cuda")]
    let compute_cap = compute_cap()?;
    #[cfg(not(feature = "cuda"))]
    let compute_cap = 0;

    for dir in DIRS {
        dir.process(&out_dir, compute_cap)?
    }

    Ok(())
}

fn set_cuda_include_dir() -> Result<()> {
    // NOTE: copied from cudarc build.rs.
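    // Probe the well-known environment variables first, then the common
    // install prefixes, and accept the first root that actually contains
    // include/cuda.h.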
+fn set_cuda_include_dir() -> Result<()> {
+    // NOTE: copied from cudarc build.rs.
+    let env_vars = [
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDNN_LIB",
+    ];
+    let env_vars = env_vars
+        .into_iter()
+        .map(std::env::var)
+        .filter_map(Result::ok)
+        .map(Into::<PathBuf>::into);
+
+    let roots = [
+        "/usr",
+        "/usr/local/cuda",
+        "/opt/cuda",
+        "/usr/lib/cuda",
+        "C:/Program Files/NVIDIA GPU Computing Toolkit",
+        "C:/CUDA",
+    ];
+    let roots = roots.into_iter().map(Into::<PathBuf>::into);
+    let root = env_vars
+        .chain(roots)
+        .find(|path| path.join("include").join("cuda.h").is_file())
+        .context("cannot find include/cuda.h")?;
+    println!(
+        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
+        root.join("include").display()
+    );
+    Ok(())
+}
+
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+    // Grab compute cap from nvidia-smi
+    let mut compute_cap = {
+        let out = std::process::Command::new("nvidia-smi")
+            .arg("--query-gpu=compute_cap")
+            .arg("--format=csv")
+            .output()
+            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        let mut lines = out.lines();
+        assert_eq!(
+            lines.next().context("missing line in stdout")?,
+            "compute_cap"
+        );
+        let cap = lines
+            .next()
+            .context("missing line in stdout")?
+            .replace('.', "");
+        cap.parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?
+    };
+
+    // Grab available GPU codes from nvcc and select the highest one
+    let max_nvcc_code = {
+        let out = std::process::Command::new("nvcc")
+            .arg("--list-gpu-code")
+            .output()
+            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+        let out = std::str::from_utf8(&out.stdout).unwrap();
+
+        let out = out.lines().collect::<Vec<&str>>();
+        let mut codes = Vec::with_capacity(out.len());
+        for code in out {
+            let code = code.split('_').collect::<Vec<&str>>();
+            if !code.is_empty() && code.contains(&"sm") {
+                if let Ok(num) = code[1].parse::<usize>() {
+                    codes.push(num);
+                }
+            }
+        }
+        codes.sort();
+        if !codes.contains(&compute_cap) {
+            anyhow::bail!(
+                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+            );
+        }
+        *codes.last().unwrap()
+    };
+
+    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
+    // then choose the highest gpu code in nvcc
+    if compute_cap > max_nvcc_code {
+        println!(
+            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+        );
+        compute_cap = max_nvcc_code;
+    }
+
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
+    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        compute_cap = compute_cap_str
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+    }
+    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+    Ok(compute_cap)
+}
diff --git a/native/candlex/src/devices.rs b/native/candlex/src/devices.rs
new file mode 100644
index 0000000..43a615a
--- /dev/null
+++ b/native/candlex/src/devices.rs
@@ -0,0 +1,4 @@
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn is_cuda_available() -> bool {
+    candle_core::utils::cuda_is_available()
+}
diff --git a/native/candlex/src/error.rs b/native/candlex/src/error.rs
new file mode 100644
index 0000000..6cdccce
--- /dev/null
+++ b/native/candlex/src/error.rs
@@ -0,0 +1,21 @@
+use rustler::{Encoder, Env, Term};
+use thiserror::Error;
+
+// Defines the atoms for each value of CandlexError.
+rustler::atoms! {
+    candle,
+}
+
+#[derive(Error, Debug)]
+pub enum CandlexError {
+    #[error("Candle Error: {0}")]
+    Candle(#[from] candle_core::Error),
+    #[error("Generic Error: {0}")]
+    Other(String),
+}
+
+impl Encoder for CandlexError {
+    fn encode<'b>(&self, env: Env<'b>) -> Term<'b> {
+        format!("{self}").encode(env)
+    }
+}
diff --git a/native/candlex/src/kernels.rs b/native/candlex/src/kernels.rs
new file mode 100644
index 0000000..13317b3
--- /dev/null
+++ b/native/candlex/src/kernels.rs
@@ -0,0 +1,4 @@
+#[rustfmt::skip]
+pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx"));
+#[rustfmt::skip]
+pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx"));
diff --git a/native/candlex/src/kernels/custom_binary.cu b/native/candlex/src/kernels/custom_binary.cu
new file mode 100644
index 0000000..d9d9090
--- /dev/null
+++ b/native/candlex/src/kernels/custom_binary.cu
@@ -0,0 +1,111 @@
+#include <math.h>
+#include <stdint.h>
+#include "strides.cuh"
+
+#define DEVICE_FN_FLOAT_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ float FN_NAME##g(float a, float b) { return FN_NAME##f(a, b); }
+
+#define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ double FN_NAME##g(double a, double b) { return FN_NAME(a, b); }
+
+DEVICE_FN_FLOAT_WRAPPER(atan2)
+DEVICE_FN_DOUBLE_WRAPPER(atan2)
+DEVICE_FN_FLOAT_WRAPPER(fmod)
+DEVICE_FN_DOUBLE_WRAPPER(fmod)
+DEVICE_FN_FLOAT_WRAPPER(pow)
+DEVICE_FN_DOUBLE_WRAPPER(pow)
+
+#define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *dims_and_strides, \
+    const TYPENAME *lhs, \
+    const TYPENAME *rhs, \
+    OUT_TYPENAME *out \
+) { \
+    const size_t *dims = dims_and_strides; \
+    const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
+    const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
+    bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
+    bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
+    if (lhs_cont && rhs_cont) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            TYPENAME x = lhs[i]; \
+            TYPENAME y = rhs[i]; \
+            out[i] = FUNC; \
+        } \
+    } else if (lhs_cont) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned int tmp_i = i; \
+            unsigned int rhs_i = 0; \
+            for (int d = num_dims - 1; d >= 0; d--) { \
+                unsigned int i_dim = tmp_i % dims[d]; \
+                rhs_i += i_dim * rhs_strides[d]; \
+                tmp_i /= dims[d]; \
+            } \
+            TYPENAME x = lhs[i]; \
+            TYPENAME y = rhs[rhs_i]; \
+            out[i] = FUNC; \
+        } \
+    } else if (rhs_cont) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned int tmp_i = i; \
+            unsigned int lhs_i = 0; \
+            for (int d = num_dims - 1; d >= 0; d--) { \
+                unsigned int i_dim = tmp_i % dims[d]; \
+                lhs_i += i_dim * lhs_strides[d]; \
+                tmp_i /= dims[d]; \
+            } \
+            TYPENAME x = lhs[lhs_i]; \
+            TYPENAME y = rhs[i]; \
+            out[i] = FUNC; \
+        } \
+    } else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned int tmp_i = i; \
+            unsigned int lhs_i = 0; \
+            unsigned int rhs_i = 0; \
+            for (int d = num_dims - 1; d >= 0; d--) { \
+                unsigned int i_dim = tmp_i % dims[d]; \
+                lhs_i += i_dim * lhs_strides[d]; \
+                rhs_i += i_dim * rhs_strides[d]; \
+                tmp_i /= dims[d]; \
+            } \
+            TYPENAME x = lhs[lhs_i]; \
+            TYPENAME y = rhs[rhs_i]; \
+            out[i] = FUNC; \
+        } \
+    } \
+} \
+
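+// For reference (illustrative note): an instantiation such as
+// CUSTOM_BINARY_OP(float, atan2_f32, atan2g(x, y)) below expands to an
+// extern "C" kernel named atan2_f32 with the signature above; the Rust side
+// (ops.rs) looks kernels up by that exact name in the compiled PTX.
+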
+#define CUSTOM_BINARY_OP(TYPENAME, FN_NAME, FUNC) \
+  CUSTOM_BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)
+
+CUSTOM_BINARY_OP(float, atan2_f32, atan2g(x, y))
+CUSTOM_BINARY_OP(double, atan2_f64, atan2g(x, y))
+CUSTOM_BINARY_OP(uint32_t, bit_and_u32, x & y)
+CUSTOM_BINARY_OP(int64_t, bit_and_i64, x & y)
+CUSTOM_BINARY_OP(uint32_t, bit_or_u32, x | y)
+CUSTOM_BINARY_OP(int64_t, bit_or_i64, x | y)
+CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y)
+CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y)
+CUSTOM_BINARY_OP(float, pow_f32, powg(x, y))
+CUSTOM_BINARY_OP(double, pow_f64, powg(x, y))
+CUSTOM_BINARY_OP(uint8_t, remainder_u8, x % y)
+CUSTOM_BINARY_OP(int64_t, remainder_i64, x % y)
+CUSTOM_BINARY_OP(float, remainder_f32, fmodg(x, y))
+CUSTOM_BINARY_OP(double, remainder_f64, fmodg(x, y))
+CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y)
+CUSTOM_BINARY_OP(int64_t, shl_i64, x << y)
+CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y)
+CUSTOM_BINARY_OP(int64_t, shr_i64, x >> y)
+
+CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_and_u8, x && y)
+CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_and_i64, x && y)
+CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_and_f32, x && y)
+CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_or_u8, x || y)
+CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_or_i64, x || y)
+CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_or_f32, x || y)
+CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_xor_i64, !x != !y)
+CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_xor_f32, !x != !y)
diff --git a/native/candlex/src/kernels/custom_unary.cu b/native/candlex/src/kernels/custom_unary.cu
new file mode 100644
index 0000000..40bc3ae
--- /dev/null
+++ b/native/candlex/src/kernels/custom_unary.cu
@@ -0,0 +1,112 @@
+#define _USE_MATH_DEFINES
+#include <math.h>
+#include <math_constants.h>
+#include <stdint.h>
+#include "strides.cuh"
+
+#define DEVICE_FN_FLOAT_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ float FN_NAME##g(float a) { return FN_NAME##f(a); }
+
+#define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ double FN_NAME##g(double a) { return FN_NAME(a); }
+
+DEVICE_FN_FLOAT_WRAPPER(acos)
+DEVICE_FN_DOUBLE_WRAPPER(acos)
+DEVICE_FN_FLOAT_WRAPPER(acosh)
+DEVICE_FN_DOUBLE_WRAPPER(acosh)
+DEVICE_FN_FLOAT_WRAPPER(asin)
+DEVICE_FN_DOUBLE_WRAPPER(asin)
+DEVICE_FN_FLOAT_WRAPPER(asinh)
+DEVICE_FN_DOUBLE_WRAPPER(asinh)
+DEVICE_FN_FLOAT_WRAPPER(atan)
+DEVICE_FN_DOUBLE_WRAPPER(atan)
+DEVICE_FN_FLOAT_WRAPPER(atanh)
+DEVICE_FN_DOUBLE_WRAPPER(atanh)
+DEVICE_FN_FLOAT_WRAPPER(cbrt)
+DEVICE_FN_DOUBLE_WRAPPER(cbrt)
+DEVICE_FN_FLOAT_WRAPPER(cosh)
+DEVICE_FN_DOUBLE_WRAPPER(cosh)
+DEVICE_FN_FLOAT_WRAPPER(erfc)
+DEVICE_FN_DOUBLE_WRAPPER(erfc)
+DEVICE_FN_FLOAT_WRAPPER(erfinv)
+DEVICE_FN_DOUBLE_WRAPPER(erfinv)
+DEVICE_FN_FLOAT_WRAPPER(exp)
+DEVICE_FN_DOUBLE_WRAPPER(exp)
+DEVICE_FN_FLOAT_WRAPPER(expm1)
+DEVICE_FN_DOUBLE_WRAPPER(expm1)
+DEVICE_FN_FLOAT_WRAPPER(log1p)
+DEVICE_FN_DOUBLE_WRAPPER(log1p)
+DEVICE_FN_FLOAT_WRAPPER(sinh)
+DEVICE_FN_DOUBLE_WRAPPER(sinh)
+DEVICE_FN_FLOAT_WRAPPER(tan)
+DEVICE_FN_DOUBLE_WRAPPER(tan)
+
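+// The ##g ("generic") suffix introduced by the wrappers above lets a single
+// FUNC expression such as tang(x) resolve to the float (tanf) or double (tan)
+// libm variant depending on TYPENAME.
+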
+#define CUSTOM_UNARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+    const size_t numel, \
+    const size_t num_dims, \
+    const size_t *info, \
+    const TYPENAME *inp, \
+    OUT_TYPENAME *out \
+) { \
+    const size_t *dims = info; \
+    const size_t *strides = info + num_dims; \
+    if (is_contiguous(num_dims, dims, strides)) { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            TYPENAME x = inp ? inp[i] : out[i]; \
+            out[i] = FUNC; \
+        } \
+    } \
+    else { \
+        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+            TYPENAME x = inp ? inp[strided_i] : out[i]; \
+            out[i] = FUNC; \
+        } \
+    } \
+} \
+
+#define CUSTOM_UNARY_OP(TYPENAME, FN_NAME, FUNC) \
+  CUSTOM_UNARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)
+
+CUSTOM_UNARY_OP(float, acos_f32, acosg(x))
+CUSTOM_UNARY_OP(double, acos_f64, acosg(x))
+CUSTOM_UNARY_OP(float, acosh_f32, acoshg(x))
+CUSTOM_UNARY_OP(double, acosh_f64, acoshg(x))
+CUSTOM_UNARY_OP(float, asin_f32, asing(x))
+CUSTOM_UNARY_OP(double, asin_f64, asing(x))
+CUSTOM_UNARY_OP(float, asinh_f32, asinhg(x))
+CUSTOM_UNARY_OP(double, asinh_f64, asinhg(x))
+CUSTOM_UNARY_OP(float, atan_f32, atang(x))
+CUSTOM_UNARY_OP(double, atan_f64, atang(x))
+CUSTOM_UNARY_OP(float, atanh_f32, atanhg(x))
+CUSTOM_UNARY_OP(double, atanh_f64, atanhg(x))
+CUSTOM_UNARY_OP(uint8_t, bit_not_u8, ~x)
+CUSTOM_UNARY_OP(uint32_t, bit_not_u32, ~x)
+CUSTOM_UNARY_OP(int64_t, bit_not_i64, ~x)
+CUSTOM_UNARY_OP(float, cbrt_f32, cbrtg(x))
+CUSTOM_UNARY_OP(double, cbrt_f64, cbrtg(x))
+CUSTOM_UNARY_OP(float, cosh_f32, coshg(x))
+CUSTOM_UNARY_OP(double, cosh_f64, coshg(x))
+CUSTOM_UNARY_OP(float, erfc_f32, erfcg(x))
+CUSTOM_UNARY_OP(double, erfc_f64, erfcg(x))
+CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x))
+CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x))
+CUSTOM_UNARY_OP(float, expm1_f32, expm1g(x))
+CUSTOM_UNARY_OP(double, expm1_f64, expm1g(x))
+CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x))
+CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x))
+CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x)))
+CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x)))
+CUSTOM_UNARY_OP(int64_t, sign_i64, x > 0 ? 1 : (x == 0 ? 0 : -1))
+// Match the CPU signum path: signbit would yield 0/1 instead of -1/0/1.
+CUSTOM_UNARY_OP(float, sign_f32, x > 0 ? 1 : (x < 0 ? -1 : x))
+CUSTOM_UNARY_OP(double, sign_f64, x > 0 ? 1 : (x < 0 ? -1 : x))
+CUSTOM_UNARY_OP(float, sinh_f32, sinhg(x))
+CUSTOM_UNARY_OP(double, sinh_f64, sinhg(x))
+CUSTOM_UNARY_OP(float, tan_f32, tang(x))
+CUSTOM_UNARY_OP(double, tan_f64, tang(x))
+
+CUSTOM_UNARY_OP_OUT(float, uint8_t, is_inf_f32, isinf(x) ? 1 : 0)
+CUSTOM_UNARY_OP_OUT(double, uint8_t, is_inf_f64, isinf(x) ? 1 : 0)
+CUSTOM_UNARY_OP_OUT(float, uint8_t, is_nan_f32, isnan(x) ? 1 : 0)
+CUSTOM_UNARY_OP_OUT(double, uint8_t, is_nan_f64, isnan(x) ? 1 : 0)
diff --git a/native/candlex/src/kernels/strides.cuh b/native/candlex/src/kernels/strides.cuh
new file mode 100644
index 0000000..c95123d
--- /dev/null
+++ b/native/candlex/src/kernels/strides.cuh
@@ -0,0 +1,34 @@
+// TODO: This is often used to check that the data is contiguous so that
+// kernels can be easily mapped. However this only returns true for row
+// major, if all the inputs are column major, we could apply the fast path
+// too (but we wouldn't if some of them are row major and some column major).
+__device__ bool is_contiguous(
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides
+) {
+    size_t acc = 1;
+    for (unsigned int d = 0; d < num_dims; d++) {
+        unsigned int dim_idx = num_dims - 1 - d;
+        if (acc != strides[dim_idx]) {
+            return false;
+        }
+        acc *= dims[dim_idx];
+    }
+    return true;
+}
+
+__device__ unsigned int get_strided_index(
+    unsigned int idx,
+    const size_t num_dims,
+    const size_t *dims,
+    const size_t *strides
+) {
+    unsigned int strided_i = 0;
+    for (unsigned int d = 0; d < num_dims; d++) {
+        unsigned int dim_idx = num_dims - 1 - d;
+        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+        idx /= dims[dim_idx];
+    }
+    return strided_i;
+}
diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs
new file mode 100644
index 0000000..8d2ae71
--- /dev/null
+++ b/native/candlex/src/lib.rs
@@ -0,0 +1,115 @@
+mod atoms {
+    rustler::atoms! {
+        cpu,
+        cuda
+    }
+}
+
+mod devices;
+mod error;
+#[cfg(feature = "cuda")]
+mod kernels;
+mod ops;
+mod tensors;
+
+use rustler::{Env, Term};
+use tensors::TensorRef;
+
+fn load(env: Env, _info: Term) -> bool {
+    rustler::resource!(TensorRef, env);
+    true
+}
+
+rustler::init! {
+    "Elixir.Candlex.Native",
+    [
+        tensors::from_binary,
+        tensors::to_binary,
+        tensors::add,
+        tensors::atan2,
+        tensors::subtract,
+        tensors::multiply,
+        tensors::divide,
+        tensors::quotient,
+        tensors::remainder,
+        tensors::pow,
+        tensors::max,
+        tensors::min,
+        tensors::equal,
+        tensors::not_equal,
+        tensors::greater,
+        tensors::greater_equal,
+        tensors::less,
+        tensors::less_equal,
+        tensors::all,
+        tensors::sum,
+        tensors::dtype,
+        tensors::t_shape,
+        tensors::argmax,
+        tensors::argmin,
+        tensors::reduce_max,
+        tensors::reduce_min,
+        tensors::negate,
+        tensors::where_cond,
+        tensors::narrow,
+        tensors::gather,
+        tensors::index_select,
+        tensors::index_add,
+        tensors::chunk,
+        tensors::squeeze,
+        tensors::clamp,
+        tensors::arange,
+        tensors::to_type,
+        tensors::broadcast_to,
+        tensors::reshape,
+        tensors::concatenate,
+        tensors::conv1d,
+        tensors::conv2d,
+        tensors::permute,
+        tensors::slice_scatter,
+        tensors::pad_with_zeros,
+        tensors::matmul,
+        tensors::abs,
+        tensors::acos,
+        tensors::acosh,
+        tensors::asin,
+        tensors::asinh,
+        tensors::atan,
+        tensors::atanh,
+        tensors::cbrt,
+        tensors::ceil,
+        tensors::cos,
+        tensors::cosh,
+        tensors::sigmoid,
+        tensors::sign,
+        tensors::sin,
+        tensors::sinh,
+        tensors::erf,
+        tensors::erfc,
+        tensors::erf_inv,
+        tensors::exp,
+        tensors::expm1,
+        tensors::floor,
+        tensors::is_infinity,
+        tensors::is_nan,
+        tensors::round,
+        tensors::log,
+        tensors::log1p,
+        tensors::rsqrt,
+        tensors::sqrt,
+        tensors::tan,
+        tensors::tanh,
+        tensors::bitwise_not,
+        tensors::bitwise_and,
+        tensors::bitwise_or,
+        tensors::bitwise_xor,
+        tensors::logical_and,
+        tensors::logical_or,
+        tensors::logical_xor,
+        tensors::left_shift,
+        tensors::right_shift,
+        tensors::to_device,
+        devices::is_cuda_available
+    ],
+    load = load
+}
diff --git a/native/candlex/src/ops.rs b/native/candlex/src/ops.rs
new file mode 100644
index 0000000..36450c5
--- /dev/null
+++ b/native/candlex/src/ops.rs
@@ -0,0 +1,458 @@
+#[cfg(feature = "cuda")]
+use candle_core::CudaStorage;
+use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape};
+use num_traits::cast::FromPrimitive;
+use num_traits::Float;
+
+fn erfc<T: Float + FromPrimitive>(v: T) -> T {
+    FromPrimitive::from_f64(statrs::function::erf::erfc(v.to_f64().unwrap())).unwrap()
+}
+
+fn erf_inv<T: Float + FromPrimitive>(v: T) -> T {
+    FromPrimitive::from_f64(statrs::function::erf::erf_inv(v.to_f64().unwrap())).unwrap()
+}
+
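+// Illustrative sketch (not part of the original patch): the structs generated
+// by the macros below are ordinary candle custom ops, so they can be applied
+// directly to a candle tensor, which is what the NIFs in tensors.rs do via
+// apply_op1_no_bwd / apply_op2_no_bwd:
+//
+//     use candle_core::{Device, Tensor};
+//     let t = Tensor::new(&[0.1f32, 0.5], &Device::Cpu)?;
+//     let y = t.apply_op1_no_bwd(&Acos)?;
+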
+macro_rules! custom_unary_op {
+    ($struct_name:ident, $name:expr, $cpu_closure:expr, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp1 for $struct_name {
+            // Box<dyn> does not support const yet, so use a function to get the name.
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                storage: &CpuStorage,
+                layout: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match storage {
+                    $(
+                        CpuStorage::$dtypes(vec) => {
+                            Ok(
+                                (
+                                    CpuStorage::$dtypes(candle_core::cpu_backend::unary_map(vec, layout, $cpu_closure)),
+                                    layout.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                storage: &CudaStorage,
+                layout: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig};
+                use candle_core::cuda_backend::{kernel_name, Map1, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map1 for $struct_name {
+                    fn f<T: DeviceRepr + WithDType>(
+                        &self,
+                        src: &CudaSlice<T>,
+                        device: &CudaDevice,
+                        layout: &Layout,
+                    ) -> Result<CudaSlice<T>, candle_core::Error> {
+                        let src = src.slice(layout.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_UNARY)?;
+                        let dims = layout.shape().dims();
+                        let elem_count = layout.shape().elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count as u32);
+                        let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?;
+                        // SAFETY: Set later by running the kernel.
+                        let dst = unsafe { device.alloc::<T>(elem_count) }.w()?;
+                        let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst);
+                        // SAFETY: ffi.
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(dst)
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = storage.device();
+                let slice = $struct_name.map(&storage.slice, device, layout)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        layout.shape().clone()
+                    )
+                )
+            }
+        }
+    };
+}
+
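+// The `*_bool_op` variants below follow the same structure, but always
+// produce U8 (0/1) output storage regardless of the input dtype, matching
+// Nx's convention for predicate results.
+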
+macro_rules! custom_unary_bool_op {
+    ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp1 for $struct_name {
+            // Box<dyn> does not support const yet, so use a function to get the name.
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                storage: &CpuStorage,
+                layout: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match storage {
+                    $(
+                        CpuStorage::$dtypes(vec) => {
+                            Ok(
+                                (
+                                    CpuStorage::U8(
+                                        candle_core::cpu_backend::unary_map(vec, layout, |v| u8::from(v.$fn_name()))
+                                    ),
+                                    layout.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                storage: &CudaStorage,
+                layout: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+                use candle_core::cuda_backend::{kernel_name, CudaStorageSlice, Map1Any, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map1Any for $struct_name {
+                    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> CudaStorageSlice>(
+                        &self,
+                        src: &CudaSlice<T>,
+                        device: &CudaDevice,
+                        layout: &Layout,
+                        _wrap: W,
+                    ) -> Result<CudaStorageSlice, candle_core::Error> {
+                        let src = src.slice(layout.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_UNARY)?;
+                        let dims = layout.shape().dims();
+                        let elem_count = layout.shape().elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count as u32);
+                        let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?;
+                        // SAFETY: Set later by running the kernel.
+                        let dst = unsafe { device.alloc::<u8>(elem_count) }.w()?;
+                        let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst);
+                        // SAFETY: ffi.
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(CudaStorageSlice::U8(dst))
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = storage.device();
+                let slice = $struct_name.map(&storage.slice, device, layout)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        layout.shape().clone()
+                    )
+                )
+            }
+        }
+    };
+}
+
+macro_rules! custom_binary_op {
+    ($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp2 for $struct_name {
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                s1: &CpuStorage,
+                l1: &Layout,
+                s2: &CpuStorage,
+                l2: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match (s1, s2) {
+                    $(
+                        (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => {
+                            Ok(
+                                (
+                                    CpuStorage::$dtypes(
+                                        candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $cpu_closure)
+                                    ),
+                                    l1.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    _ => {
+                        Err(Error::DTypeMismatchBinaryOp {
+                            lhs: s1.dtype(),
+                            rhs: s2.dtype(),
+                            op: self.name(),
+                        }
+                        .bt())
+                    }
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                s1: &CudaStorage,
+                l1: &Layout,
+                s2: &CudaStorage,
+                l2: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+                use candle_core::cuda_backend::{kernel_name, Map2, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map2 for $struct_name {
+                    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+                        &self,
+                        src1: &CudaSlice<T>,
+                        layout1: &Layout,
+                        src2: &CudaSlice<T>,
+                        layout2: &Layout,
+                        device: &CudaDevice,
+                    ) -> Result<CudaSlice<T>, candle_core::Error> {
+                        let shape1 = layout1.shape();
+                        let dims1 = shape1.dims();
+                        let elem_count1 = shape1.elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count1 as u32);
+                        let dims_and_strides = device
+                            .htod_copy([dims1, layout1.stride(), layout2.stride()].concat())
+                            .w()?;
+                        let src1 = src1.slice(layout1.start_offset()..);
+                        let src2 = src2.slice(layout2.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_BINARY)?;
+                        // SAFETY: Set later by running the kernel.
+                        let out = unsafe { device.alloc::<T>(elem_count1) }.w()?;
+                        let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out);
+                        // SAFETY: ffi
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(out)
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = s1.device();
+                let slice = $struct_name.map(&s1.slice, l1, &s2.slice, l2, device)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        l1.shape().clone()
+                    )
+                )
+            }
+        }
+    }
+}
+
+macro_rules! custom_binary_bool_op {
+    ($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp2 for $struct_name {
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                s1: &CpuStorage,
+                l1: &Layout,
+                s2: &CpuStorage,
+                l2: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match (s1, s2) {
+                    $(
+                        (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => {
+                            Ok(
+                                (
+                                    CpuStorage::U8(
+                                        candle_core::cpu_backend::binary_map(
+                                            l1,
+                                            l2,
+                                            lhs,
+                                            rhs,
+                                            |v1, v2| u8::from($cpu_closure(v1, v2))
+                                        )
+                                    ),
+                                    l1.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    _ => {
+                        Err(Error::DTypeMismatchBinaryOp {
+                            lhs: s1.dtype(),
+                            rhs: s2.dtype(),
+                            op: self.name(),
+                        }
+                        .bt())
+                    }
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                s1: &CudaStorage,
+                l1: &Layout,
+                s2: &CudaStorage,
+                l2: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+                use candle_core::cuda_backend::{kernel_name, CudaStorageSlice, Map2Any, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map2Any for $struct_name {
+                    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+                        &self,
+                        src1: &CudaSlice<T>,
+                        layout1: &Layout,
+                        src2: &CudaSlice<T>,
+                        layout2: &Layout,
+                        device: &CudaDevice,
+                    ) -> Result<CudaStorageSlice, candle_core::Error> {
+                        let shape1 = layout1.shape();
+                        let dims1 = shape1.dims();
+                        let elem_count1 = shape1.elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count1 as u32);
+                        let dims_and_strides = device
+                            .htod_copy([dims1, layout1.stride(), layout2.stride()].concat())
+                            .w()?;
+                        let src1 = src1.slice(layout1.start_offset()..);
+                        let src2 = src2.slice(layout2.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_BINARY)?;
+                        // SAFETY: Set later by running the kernel.
+                        let out = unsafe { device.alloc::<u8>(elem_count1) }.w()?;
+                        let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out);
+                        // SAFETY: ffi
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(CudaStorageSlice::U8(out))
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = s1.device();
+                let slice = $struct_name.map(&s1.slice, l1, &s2.slice, l2, device)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        l1.shape().clone()
+                    )
+                )
+            }
+        }
+    }
+}
+
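+// For reference: each invocation below defines a zero-sized struct (e.g.
+// `Acos`) implementing CustomOp1/CustomOp2. The CPU path maps the given
+// closure over the storage; the CUDA path launches the kernel whose name
+// combines the op name with the dtype suffix (e.g. "acos_f32") from the PTX
+// constants in kernels.rs.
+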
+custom_unary_op!(Acos, "acos", |v| v.acos(), (BF16, F16, F32, F64));
+custom_unary_op!(Acosh, "acosh", |v| v.acosh(), (BF16, F16, F32, F64));
+custom_unary_op!(Asin, "asin", |v| v.asin(), (BF16, F16, F32, F64));
+custom_unary_op!(Asinh, "asinh", |v| v.asinh(), (BF16, F16, F32, F64));
+custom_unary_op!(Atan, "atan", |v| v.atan(), (BF16, F16, F32, F64));
+custom_unary_op!(Atanh, "atanh", |v| v.atanh(), (BF16, F16, F32, F64));
+custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64));
+custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64));
+custom_unary_op!(Cosh, "cosh", |v| v.cosh(), (BF16, F16, F32, F64));
+custom_unary_op!(Erfc, "erfc", |v| erfc(v), (BF16, F16, F32, F64));
+custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64));
+custom_unary_op!(Expm1, "expm1", |v| v.exp_m1(), (BF16, F16, F32, F64));
+custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64));
+custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64));
+custom_unary_op!(Sign, "sign", |v| v.signum(), (I64, BF16, F16, F32, F64));
+custom_unary_op!(Sinh, "sinh", |v| v.sinh(), (BF16, F16, F32, F64));
+custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64));
+custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64));
+custom_unary_bool_op!(IsNan, "is_nan", is_nan, (F32, F64));
+
+custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64));
+custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64));
+custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64));
+custom_binary_op!(Atan2, "atan2", |v1, v2| v1.atan2(v2), (F32, F64));
+custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64));
+custom_binary_op!(
+    Remainder,
+    "remainder",
+    |v1, v2| v1 % v2,
+    (U8, I64, F32, F64)
+);
+custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64));
+custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64));
+custom_binary_bool_op!(
+    LogicalAnd,
+    "logical_and",
+    |v1, v2| if v1 as i8 != 0 && v2 as i8 != 0 { 1 } else { 0 },
+    (U8, U32, I64, F32, F64)
+);
+custom_binary_bool_op!(
+    LogicalOr,
+    "logical_or",
+    |v1, v2| if v1 as i8 == 0 && v2 as i8 == 0 { 0 } else { 1 },
+    (U8, U32, I64, F32, F64)
+);
+custom_binary_bool_op!(
+    LogicalXor,
+    "logical_xor",
+    |v1, v2| if (v1 as i8 != 0) == (v2 as i8 != 0) {
+        0
+    } else {
+        1
+    },
+    (U8, U32, I64, F32, F64)
+);
diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs
new file mode 100644
index 0000000..3ea396a
--- /dev/null
+++ b/native/candlex/src/tensors.rs
@@ -0,0 +1,527 @@
+use crate::atoms;
+use crate::error::CandlexError;
+use crate::ops::{
+    Acos, Acosh, Asin, Asinh, Atan, Atan2, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh,
+    ErfInv, Erfc, Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Remainder,
+    Shl, Shr, Sigmoid, Sign, Sinh, Tan,
+};
+use candle_core::{DType, Device, Tensor};
+use half::{bf16, f16};
+use rustler::{Atom, Binary, Encoder, Env, NewBinary, NifStruct, ResourceArc, Term};
+use std::ops::Deref;
+use std::result::Result;
+use std::str::FromStr;
+
+pub(crate) struct TensorRef(Tensor);
+
+#[derive(NifStruct)]
+#[module = "Candlex.Backend"]
+pub struct ExTensor {
+    device: Atom,
+    resource: ResourceArc<TensorRef>,
+}
+
+impl ExTensor {
+    pub fn new(tensor: Tensor) -> Self {
+        let dev_string = match tensor.device() {
+            Device::Cpu => atoms::cpu(),
+            Device::Cuda(_) => atoms::cuda(),
+        };
+
+        Self {
+            device: dev_string,
+            resource: ResourceArc::new(TensorRef(tensor)),
+        }
+    }
+}
+
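+// Note: ResourceArc is rustler's reference-counted resource handle, so the
+// wrapped candle Tensor stays alive for as long as any Elixir term holds a
+// reference to it and is released by the BEAM garbage collector afterwards.
+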
+// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct.
+impl Deref for ExTensor {
+    type Target = Tensor;
+
+    fn deref(&self) -> &Self::Target {
+        &self.resource.0
+    }
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn from_binary(
+    binary: Binary,
+    dtype_str: &str,
+    shape: Term,
+    device: Atom,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(Tensor::from_raw_buffer(
+        binary.as_slice(),
+        // TODO: Handle DTypeParseError
+        DType::from_str(dtype_str).unwrap(),
+        // TODO: Handle rustler::Error
+        &tuple_to_vec(shape).unwrap(),
+        &device_from_atom(device)?,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn to_device(ex_tensor: ExTensor, device: Atom) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        ex_tensor.to_device(&device_from_atom(device)?)?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Result<Binary, CandlexError> {
+    let bytes = tensor_bytes(ex_tensor.flatten_all()?)?;
+    let mut binary = NewBinary::new(env, bytes.len());
+    binary.as_mut_slice().copy_from_slice(bytes.as_slice());
+
+    Ok(binary.into())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn narrow(
+    t: ExTensor,
+    dim: usize,
+    start: usize,
+    length: usize,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.narrow(dim, start, length)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn gather(t: ExTensor, indexes: ExTensor, dim: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.gather(indexes.deref(), dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn index_select(t: ExTensor, indexes: ExTensor, dim: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.index_select(indexes.deref(), dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn index_add(
+    t: ExTensor,
+    indexes: ExTensor,
+    source: ExTensor,
+    dim: usize,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.index_add(
+        indexes.deref(),
+        source.deref(),
+        dim,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn chunk(t: ExTensor, num_chunks: usize) -> Result<Vec<ExTensor>, CandlexError> {
+    Ok(t.chunk(num_chunks, 0)?
+        .into_iter()
+        .map(|t| ExTensor::new(t))
+        .collect())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn squeeze(t: ExTensor, dim: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.squeeze(dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.clamp(
+        &min_val.broadcast_as(t.shape())?,
+        &max_val.broadcast_as(t.shape())?,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn rsqrt(t: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.sqrt()?.recip()?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn arange(
+    start: i64,
+    end: i64,
+    dtype_str: &str,
+    shape: Term,
+    device: Atom,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        Tensor::arange(start, end, &device_from_atom(device)?)?
+            .to_dtype(DType::from_str(dtype_str).unwrap())?
+            .reshape(tuple_to_vec(shape).unwrap())?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn all(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+    let device = ex_tensor.device();
+    let t = ex_tensor.flatten_all()?;
+    let dims = t.shape().dims();
+    let on_true = Tensor::ones(dims, DType::U8, device)?;
+    let on_false = Tensor::zeros(dims, DType::U8, device)?;
+
+    let bool_scalar = match t
+        .where_cond(&on_true, &on_false)?
+        .min(0)?
+        .to_scalar::<u8>()?
+    {
+        0 => 0u8,
+        _ => 1u8,
+    };
+
+    Ok(ExTensor::new(Tensor::new(bool_scalar, device)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn argmax(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.argmax_keepdim(dim)?
+    } else {
+        ex_tensor.argmax(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.argmin_keepdim(dim)?
+    } else {
+        ex_tensor.argmin(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn reduce_max(
+    ex_tensor: ExTensor,
+    dim: usize,
+    keep_dim: bool,
+) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.max_keepdim(dim)?
+    } else {
+        ex_tensor.max(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn reduce_min(
+    ex_tensor: ExTensor,
+    dim: usize,
+    keep_dim: bool,
+) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.min_keepdim(dim)?
+    } else {
+        ex_tensor.min(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn sum(
+    ex_tensor: ExTensor,
+    dims: Vec<usize>,
+    keep_dims: bool,
+) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dims {
+        ex_tensor.sum_keepdim(dims)?
+    } else {
+        ex_tensor.sum(dims)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn permute(ex_tensor: ExTensor, dims: Vec<usize>) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(ex_tensor.permute(dims)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn broadcast_to(t: ExTensor, shape: Term) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn reshape(t: ExTensor, shape: Term) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.reshape(tuple_to_vec(shape).unwrap())?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn slice_scatter(
+    t: ExTensor,
+    src: ExTensor,
+    dim: usize,
+    start: usize,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.slice_scatter(src.deref(), dim, start)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn pad_with_zeros(t: ExTensor, left: usize, right: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.pad_with_zeros(0, left, right)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn where_cond(
+    t: ExTensor,
+    on_true: ExTensor,
+    on_false: ExTensor,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.where_cond(&on_true, &on_false)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn to_type(t: ExTensor, dtype_str: &str) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        t.to_dtype(DType::from_str(dtype_str).unwrap())?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn dtype(t: ExTensor) -> Result<&'static str, CandlexError> {
+    Ok(t.dtype().as_str())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn t_shape(env: Env, t: ExTensor) -> Result<Term, CandlexError> {
+    Ok(vec_to_tuple(env, t.shape().clone().into_dims()).unwrap())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn concatenate(ex_tensors: Vec<ExTensor>, dim: usize) -> Result<ExTensor, CandlexError> {
+    let tensors = ex_tensors
+        .iter()
+        .map(|t| t.deref())
+        .collect::<Vec<&Tensor>>();
+    Ok(ExTensor::new(Tensor::cat(&tensors[..], dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn conv1d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
+    let padding = 0;
+    let stride = 1;
+    let dilation = 1;
+    let groups = 1;
+
+    Ok(ExTensor::new(tensor.conv1d(
+        kernel.deref(),
+        padding,
+        stride,
+        dilation,
+        groups,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn conv2d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
+    let padding = 0;
+    let stride = 1;
+    let dilation = 1;
+    let groups = 1;
+
+    Ok(ExTensor::new(tensor.conv2d(
+        kernel.deref(),
+        padding,
+        stride,
+        dilation,
+        groups,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn divide(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        // Need to force float in case we receive integers, given
+        // candle rounds down integer division.
+        left.to_dtype(DType::F32)?
+            .broadcast_div(&right.to_dtype(DType::F32)?)?,
+    ))
+}
+
+macro_rules! unary_nif {
+    ($nif_name:ident, $native_fn_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(ex_tensor.$native_fn_name()?))
+        }
+    };
+    ($nif_name:ident) => {
+        unary_nif!($nif_name, $nif_name);
+    };
+}
+
+macro_rules! binary_nif {
+    ($nif_name:ident, $native_fn_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(left.$native_fn_name(right.deref())?))
+        }
+    };
+}
+
+macro_rules! custom_unary_nif {
+    ($nif_name:ident, $custom_op_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(ex_tensor.apply_op1_no_bwd(&$custom_op_name)?))
+        }
+    };
+}
+
+macro_rules! custom_binary_nif {
+    ($nif_name:ident, $custom_op_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(
+                left.apply_op2_no_bwd(right.deref(), &$custom_op_name)?,
+            ))
+        }
+    };
+}
+
+unary_nif!(negate, neg);
+unary_nif!(abs);
+unary_nif!(ceil);
+unary_nif!(cos);
+unary_nif!(erf);
+unary_nif!(exp);
+unary_nif!(floor);
+unary_nif!(round);
+unary_nif!(sin);
+unary_nif!(log);
+unary_nif!(sqrt);
+unary_nif!(tanh);
+
+custom_unary_nif!(acos, Acos);
+custom_unary_nif!(acosh, Acosh);
+custom_unary_nif!(asin, Asin);
+custom_unary_nif!(asinh, Asinh);
+custom_unary_nif!(atan, Atan);
+custom_unary_nif!(atanh, Atanh);
+custom_unary_nif!(bitwise_not, BitNot);
+custom_unary_nif!(cbrt, Cbrt);
+custom_unary_nif!(cosh, Cosh);
+custom_unary_nif!(erfc, Erfc);
+custom_unary_nif!(erf_inv, ErfInv);
+custom_unary_nif!(expm1, Expm1);
+custom_unary_nif!(is_infinity, IsInf);
+custom_unary_nif!(is_nan, IsNan);
+custom_unary_nif!(log1p, Log1p);
+custom_unary_nif!(sigmoid, Sigmoid);
+custom_unary_nif!(sign, Sign);
+custom_unary_nif!(sinh, Sinh);
+custom_unary_nif!(tan, Tan);
+
+binary_nif!(add, broadcast_add);
+binary_nif!(subtract, broadcast_sub);
+binary_nif!(multiply, broadcast_mul);
+binary_nif!(quotient, broadcast_div);
+binary_nif!(max, broadcast_maximum);
+binary_nif!(min, broadcast_minimum);
+binary_nif!(equal, eq);
+binary_nif!(not_equal, ne);
+binary_nif!(greater, gt);
+binary_nif!(greater_equal, ge);
+binary_nif!(less, lt);
+binary_nif!(less_equal, le);
+binary_nif!(matmul, broadcast_matmul);
+
+custom_binary_nif!(atan2, Atan2);
+custom_binary_nif!(bitwise_and, BitAnd);
+custom_binary_nif!(bitwise_or, BitOr);
+custom_binary_nif!(bitwise_xor, BitXor);
+custom_binary_nif!(left_shift, Shl);
+custom_binary_nif!(logical_and, LogicalAnd);
+custom_binary_nif!(logical_or, LogicalOr);
+custom_binary_nif!(logical_xor, LogicalXor);
+custom_binary_nif!(pow, Pow);
+custom_binary_nif!(right_shift, Shr);
+custom_binary_nif!(remainder, Remainder);
+
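+// Illustrative expansion (for reference, assuming the macros above):
+// `binary_nif!(add, broadcast_add)` generates
+//
+//     #[rustler::nif(schedule = "DirtyCpu")]
+//     pub fn add(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+//         Ok(ExTensor::new(left.broadcast_add(right.deref())?))
+//     }
+//
+// i.e. a thin wrapper that delegates straight to the corresponding candle
+// Tensor method.
+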
+fn tuple_to_vec(term: Term) -> Result<Vec<usize>, rustler::Error> {
+    Ok(rustler::types::tuple::get_tuple(term)?
+        .iter()
+        .map(|elem| elem.decode())
+        .collect::<Result<_, _>>()?)
+}
+
+fn vec_to_tuple(env: Env, vec: Vec<usize>) -> Result<Term, rustler::Error> {
+    Ok(rustler::types::tuple::make_tuple(
+        env,
+        &vec.into_iter()
+            .map(|elem| elem.encode(env))
+            .collect::<Vec<Term>>(),
+    ))
+}
+
+static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+
+fn device_from_atom(atom: Atom) -> Result<Device, CandlexError> {
+    if atom == atoms::cpu() {
+        Ok(Device::Cpu)
+    } else if atom == atoms::cuda() {
+        let mut cuda_device = CUDA_DEVICE.lock().unwrap();
+
+        if let Some(device) = cuda_device.as_ref() {
+            Ok(device.clone())
+        } else {
+            let new_cuda_device = Device::new_cuda(0)?;
+            *cuda_device = Some(new_cuda_device.clone());
+
+            Ok(new_cuda_device)
+        }
+    } else {
+        Err(CandlexError::Other(format!(
+            "unsupported device {:?}",
+            atom
+        )))
+    }
+}
+
+fn tensor_bytes(tensor: Tensor) -> Result<Vec<u8>, CandlexError> {
+    Ok(match tensor.dtype() {
+        DType::I64 => tensor
+            .to_vec1::<i64>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::U8 => tensor
+            .to_vec1::<u8>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::U32 => tensor
+            .to_vec1::<u32>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::F16 => tensor
+            .to_vec1::<f16>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::F32 => tensor
+            .to_vec1::<f32>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::F64 => tensor
+            .to_vec1::<f64>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::BF16 => tensor
+            .to_vec1::<bf16>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+    })
+}
diff --git a/test/candlex_test.exs b/test/candlex_test.exs
index 9a9f9e2..1fe8c0d 100644
--- a/test/candlex_test.exs
+++ b/test/candlex_test.exs
@@ -1,8 +1,2194 @@
 defmodule CandlexTest do
-  use ExUnit.Case
+  use Nx.Case, async: true
   doctest Candlex
 
-  test "greets the world" do
-    assert Candlex.hello() == :world
+  describe "creation" do
+    test "tensor" do
+      check(255, type: :u8)
+      check(100_002, type: :u32)
+      check(100_102, type: :u64)
+      check(-101, type: :s64)
+      check(1.16, type: :f16)
+      check(1.32, type: :f32)
+      check([1, 2, 3], type: :f32)
+      check(-0.002, type: :f64)
+      check([1, 2], type: :u32)
+      check([[1, 2], [3, 4]], type: :u32)
+      check([[1, 2, 3, 4], [5, 6, 7, 8]], type: :u32)
+      check([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], type: :u32)
+      check([0, 255], type: :u8)
+      check([-0.5, 0.88], type: :f32)
+      check([-0.5, 0.88], type: :f64)
+      check(2.16, type: :bf16)
+    end
+
+    test "named dimensions" do
+      check([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
+
+      t([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
+      |> assert_equal(t([[1, 2, 3], [4, 5, 6]]))
+    end
+
+    test "tensor tensor" do
+      t(t([1, 2, 3]))
+      |> assert_equal(t([1, 2, 3]))
+    end
+
+    test "tril" do
+      t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+      |> Nx.tril()
+      |> assert_equal(t([[1, 0, 0], [4, 5, 0], [7, 8, 9]]))
+    end
+
+    test "triu" do
+      t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+      |> Nx.triu()
+      |> assert_equal(t([[1, 2, 3], [0, 5, 6], [0, 0, 9]]))
+    end
+
+    test "addition" do
+      t([1, 2, 3])
+      |> Nx.add(t([10, 20, 30]))
+      |> assert_equal(t([11, 22, 33]))
+
+      t([1, 2, 3], type: :u64)
+      |> Nx.add(t([10, 20, 30], type: :u64))
+      |> assert_equal(t([11, 22, 33]))
+
+      Nx.add(1, 2.2)
+      |> assert_equal(t(3.2))
+
+      t([1, 2, 3])
+      |> Nx.add(1.0)
+      |> assert_equal(t([2.0, 3.0, 4.0]))
+    end
+
+    test "iota" do
+      Nx.iota({})
+      |> assert_equal(t(0))
+
+      Nx.iota({}, type: :f32)
+      |> assert_equal(t(0.0))
+
+      Nx.iota({5})
+      |> assert_equal(t([0, 1, 2, 3, 4]))
+
+      Nx.iota({5}, type: :u64)
+      |> assert_equal(t([0, 1, 2, 3, 4]))
+
+      Nx.iota({5}, type: :f32)
+      |>
assert_equal(t([0.0, 1.0, 2.0, 3.0, 4.0])) + + Nx.iota({2, 3}) + |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) + + Nx.iota({3, 3}, axis: 1) + |> assert_equal(t( + [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ] + )) + + Nx.iota({3, 3}, axis: -1) + |> assert_equal(t( + [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ] + )) + + Nx.iota({3, 4, 3}, axis: 0, type: :f64) + |> assert_equal(t( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0] + ], + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0] + ], + [ + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0] + ] + ] + )) + + Nx.iota({1, 3, 2}, axis: 2) + |> assert_equal(t( + [ + [ + [0, 1], + [0, 1], + [0, 1] + ] + ] + )) + end + + test "max" do + Nx.max(1, 2) + |> assert_equal(t(2)) + + Nx.max(1, t([1.0, 2.0, 3.0], names: [:data])) + |> assert_equal(t([1.0, 2.0, 3.0])) + + t([[1], [2]], type: :f32, names: [:x, nil]) + |> Nx.max(t([[10, 20]], type: :f32, names: [nil, :y])) + |> assert_equal(t([[10.0, 20.0], [10.0, 20.0]])) + end + + test "min" do + Nx.min(1, 2) + |> assert_equal(t(1)) + + Nx.min(1, t([1.0, 2.0, 3.0], names: [:data])) + |> assert_equal(t([1.0, 1.0, 1.0])) + + t([[1], [2]], type: :f32, names: [:x, nil]) + |> Nx.min(t([[10, 20]], type: :f32, names: [nil, :y])) + |> assert_equal(t([[1.0, 1.0], [2.0, 2.0]])) + end + + test "multiply" do + t([1, 2]) + |> Nx.multiply(t([3, 4])) + |> assert_equal(t([3, 8])) + + t([1, 2], type: :u64) + |> Nx.multiply(t([3, 4], type: :u64)) + |> assert_equal(t([3, 8])) + + t([[1], [2]]) + |> Nx.multiply(t([3, 4])) + |> assert_equal(t([[3, 4], [6, 8]])) + + t([1, 2]) + |> Nx.multiply(t([[3], [4]])) + |> assert_equal(t([[3, 6], [4, 8]])) + end + + test "divide/2" do + 1.0 + |> Nx.divide(2) + |> assert_equal(t(0.5)) + + t([1.0, 2, 3]) + |> Nx.divide(1) + |> assert_equal(t([1.0, 2.0, 3.0])) + + t([[1.0], [2]]) + |> Nx.divide(t([[10, 20]])) + |> assert_equal(t( + [ + [0.10000000149011612, 0.05000000074505806], + [0.20000000298023224, 0.10000000149011612] + ] + )) + + 1 + |> Nx.divide(2) + |> assert_equal(t(0.5)) + + t([1, 2, 3]) + |> Nx.divide(2) + |> assert_equal(t([0.5, 1.0, 1.5])) + + t([[1], [2]]) + |> Nx.divide(t([[10, 20]])) + |> assert_equal(t( + [ + [0.10000000149011612, 0.05000000074505806], + [0.20000000298023224, 0.10000000149011612] + ] + )) + end + + test "remainder" do + Nx.remainder(1, 2) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.remainder(2) + |> assert_equal(t([1, 0, 1])) + + 2 + |> Nx.remainder(t([1.0, 2.0, 3.0])) + |> assert_equal(t([0.0, 0.0, 2.0])) + + t([[10], [20]], names: [:x, :y]) + |> Nx.remainder(t([[3, 4]], names: [nil, :y])) + |> assert_equal(t( + [ + [1, 2], + [2, 0] + ] + )) + + left = t(-11) + right = t(10, type: :u8) + + Nx.remainder(left, right) + |> assert_equal(t(-1)) + + left + |> Nx.add(t(20)) + |> Nx.remainder(right) + |> assert_equal(t(9)) + + positive_left = t(9, type: :u8) + Nx.remainder(positive_left, right) + |> assert_equal(t(9)) + + positive_left + |> Nx.add(Nx.tensor(20, type: :u8)) + |> Nx.remainder(right) + |> assert_equal(t(9)) + end + + test "quotient" do + Nx.quotient(11, 2) + |> assert_equal(t(5)) + + t([2, 4, 5]) + |> Nx.quotient(2) + |> assert_equal(t([1, 2, 2])) + + 10 + |> Nx.quotient(t([1, 2, 3])) + |> assert_equal(t([10, 5, 3])) + + t([[10, 20]], names: [nil, :y]) + |> Nx.quotient(t([[1], [2]], names: [:x, nil])) + |> assert_equal(t( + [ + [10, 20], + [5, 10] + ] + )) + + t([[10, 20]]) + |> Nx.quotient(t([[1], [2]])) + |> assert_equal(t( + [ + [10, 20], + [5, 10] + ] + )) + + 
t([[10, 20]], type: :u8) + |> Nx.quotient(t([[1], [2]], type: :u32)) + |> assert_equal(t( + [ + [10, 20], + [5, 10] + ] + )) + end + + test "sign" do + t([-2, -1, 0, 1, 2]) + |> Nx.sign() + |> assert_equal(t([-1, -1, 0, 1, 1])) + end + + test "atan2" do + Nx.atan2(1.0, 2.0) + |> assert_close(t(0.46364760398864746)) + + t([1.0, 2, 3]) + |> Nx.atan2(1) + |> assert_close(t([0.7853981852531433, 1.1071487665176392, 1.249045729637146])) + + 1.0 + |> Nx.atan2(t([1.0, 2.0, 3.0])) + |> assert_close(t([0.7853981852531433, 0.46364760398864746, 0.32175055146217346])) + + t([[-0.0], [0.0]], type: :f64) + |> Nx.atan2(t([-0.0, 0.0], type: :f64)) + |> assert_close(t( + [ + [-3.141592653589793, -0.0], + [3.141592653589793, 0.0] + ] + )) + end + + test "broadcast" do + Nx.broadcast(1, {1, 2, 3}) + |> assert_equal(t([[[1, 1, 1], [1, 1, 1]]])) + + t([1, 2, 3]) + |> Nx.broadcast({3, 2}, axes: [0]) + |> assert_equal(t([[1, 1], [2, 2], [3, 3]])) + end + + test "access" do + tensor = t([[1, 2], [3, 4]]) + + assert_equal(tensor[0], t([1, 2])) + assert_equal(tensor[1], t([3, 4])) + end + + test "concatenate" do + [t([1, 2, 3])] + |> Nx.concatenate() + |> assert_equal(t([1, 2, 3])) + + [t([1, 2, 3]), t([4, 5, 6])] + |> Nx.concatenate() + |> assert_equal(t([1, 2, 3, 4, 5, 6])) + + t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :f32) + t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) + t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) + + [t1, t2, t3] + |> Nx.concatenate(axis: :x) + |> assert_equal( + t([ + [ + [0.0, 1.0], + [2.0, 3.0] + ], + [ + [4.0, 5.0], + [6.0, 7.0] + ], + [ + [0.0, 1.0], + [2.0, 3.0] + ], + [ + [0.0, 1.0], + [2.0, 3.0] + ] + ]) + ) + end + + test "greater" do + Nx.greater(1, 2) + |> assert_equal(t(0)) + + Nx.greater(1, t([1, 2, 3])) + |> assert_equal(t([0, 0, 0])) + + t([1, 2, 3]) + |> Nx.greater(t([1, 2, 2])) + |> assert_equal(t([0, 0, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.greater(t([1, 2, 3])) + |> assert_equal(t( + [ + [0, 0, 0], + [1, 1, 1] + ] + )) + end + + test "less" do + Nx.less(1, 2) + |> assert_equal(t(1)) + + Nx.less(1, t([1, 2, 3])) + |> assert_equal(t([0, 1, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 2.0, 1.0]]) + |> Nx.less(t([1, 2, 3])) + |> assert_equal(t([[0, 0, 0], [0, 0, 1]])) + end + + test "less_equal" do + Nx.less_equal(1, 2) + |> assert_equal(t(1)) + + Nx.less_equal(1, t([1, 2, 3])) + |> assert_equal(t([1, 1, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.less_equal(t([1, 2, 3])) + |> assert_equal(t([[1, 1, 1], [0, 0, 0]])) + end + + test "bitcast" do + t([0, 0, 0], type: :s64) + |> Nx.bitcast(:f64) + |> assert_equal(t([0.0, 0.0, 0.0])) + + t([0, 0, 0], type: :u32) + |> Nx.bitcast(:f32) + |> assert_equal(t([0.0, 0.0, 0.0])) + + t([0, 0, 0], type: :u32) + |> Nx.bitcast(:u32) + |> assert_equal(t([0, 0, 0])) + end + + test "eye" do + Nx.eye(2) + |> assert_equal(t([[1, 0], [0, 1]])) + + Nx.eye(3, type: :f32) + |> assert_equal( + t([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ]) + ) + + Nx.eye({1, 2}) + |> assert_equal(t([[1, 0]])) + + Nx.eye({2, 4, 3}) + |> assert_equal( + t([ + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 0] + ], + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 0] + ] + ]) + ) + + # assert_equal doesn't yet work with vectorized axes + # Nx.eye({3}, vectorized_axes: [x: 1, y: 2]) + # |> assert_equal(t( + # [ + # [ + # [1, 0, 0], + # [1, 0, 0] + # ] + # ] + # )) + + # Nx.eye({2, 3}, vectorized_axes: [x: 2]) + # |> assert_equal(t( + # [ + # [ + # [1, 0, 0], + # [0, 1, 0] + # ], + # [ + # [1, 0, 0], + # 
[0, 1, 0] + # ] + # ] + # )) + end + + test "dot/2" do + # Dot product of scalars + + Nx.dot(5, 5) + |> assert_equal(t(25)) + + Nx.dot(-2.0, 5.0) + |> assert_equal(t(-10.0)) + + Nx.dot(2, 2.0) + |> assert_equal(t(4.0)) + + # Dot product of vectors + + # TODO: + # t([1, 2, 3]) + # |> Nx.dot(t([4, 5, 6])) + # |> assert_equal(t(32)) + + # t([1.0, 2.0, 3.0]) + # |> Nx.dot(t([1, 2, 3])) + # |> assert_equal(t(14.0)) + + # Dot product of matrices (2-D tensors) + + # TODO: Candle matmul doesn't support integers yet + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) + # |> assert_equal(t( + # [ + # [58, 64], + # [139, 154] + # ] + # )) + + t([[1.0, 2, 3], [4, 5, 6]]) + |> Nx.dot(t([[7.0, 8], [9, 10], [11, 12]])) + |> assert_equal(t( + [ + [58.0, 64], + [139, 154] + ] + )) + + # Dot product of vector and n-D tensor + + # t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k]) + # |> Nx.dot(t([5.0, 10], names: [:x])) + # |> assert_equal(t( + # [ + # [25, 55], + # [85, 115] + # ] + # )) + + # t([5.0, 10], names: [:x]) + # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j])) + # |> assert_equal(t( + # [45, 60, 75] + # )) + + # t([[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]], names: [:shard, :batch, :x, :y, :z]) + # |> Nx.dot(t([2.0, 2.0], names: [:data])) + # |> assert_equal(t( + # [ + # [ + # [ + # [6.0, 14.0], + # [22.0, 30.0] + # ] + # ] + # ] + # )) + + # Dot product of n-D and m-D tensors + + # t([[[1.0, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:x, :y, :z]) + # |> Nx.dot(t([[[1.0, 2, 3], [3, 4, 5], [5, 6, 7]]], names: [:i, :j, :k])) + # |> assert_equal(t( + # [ + # [ + # [ + # [22, 28, 34] + # ], + # [ + # [49, 64, 79] + # ], + # [ + # [76, 100, 124] + # ] + # ], + # [ + # [ + # [22, 28, 34] + # ], + # [ + # [49, 64, 79] + # ], + # [ + # [76, 100, 124] + # ] + # ] + # ] + # )) + end + + test "dot/6" do + # Contracting along axes + + t1 = t([[1.0, 2], [3, 4]], names: [:x, :y]) + t2 = t([[10.0, 20], [30, 40]], names: [:height, :width]) + + t1 + |> Nx.dot([0], [], t2, [0], []) + |> assert_equal(t( + [ + [100, 140], + [140, 200] + ] + )) + + # TODO: + t1 + |> Nx.dot([0], [], t2, [1], []) + |> assert_equal(t( + [ + [70, 150], + [100, 220] + ] + )) + + t1 + |> Nx.dot([1], [], t2, [0], []) + |> assert_equal(t( + [ + [70, 100], + [150, 220] + ] + )) + + # t1 + # |> Nx.dot([1], [], t2, [1], []) + # |> assert_equal(t( + # [ + # [50, 110], + # [110, 250] + # ] + # )) + + # t1 + # |> Nx.dot([0, 1], [], t2, [0, 1], []) + # |> assert_equal(t(300)) + end + + test "negate" do + # TODO: candle doesn't support unary functions for integers yet + # Nx.negate(1) + # |> assert_equal(t(-1)) + + Nx.negate(1.0) + |> assert_equal(t(-1.0)) + + t([1.0, 2.0, -3.0], type: :f32) + |> Nx.negate() + |> assert_equal(t([-1.0, -2.0, 3.0])) + end + + test "sin" do + Nx.sin(1.0) + |> assert_close(t(0.8414709568023682)) + + t([1.0, 2.0, 3.0]) + |> Nx.sin() + |> assert_close(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) + end + + test "sinh" do + Nx.sinh(1.0) + |> assert_close(t(1.175201177597046)) + + t([1.0, 2, 3]) + |> Nx.sinh() + |> assert_close(t([1.175201177597046, 3.6268603801727295, 10.017874717712402])) + end + + test "exp" do + Nx.exp(1.0) + |> assert_equal(t(2.7182817459106445)) + + t([1.0, 2, 3]) + |> Nx.exp() + |> assert_equal(t([2.7182817459106445, 7.389056205749512, 20.08553695678711])) + end + + test "expm1" do + Nx.expm1(1.0) + |> assert_close(t(1.718281865119934)) + + t([1.0, 2, 3]) + |> Nx.expm1() + |> 
assert_close(t([1.718281865119934, 6.389056205749512, 19.08553695678711])) + end + + test "cos" do + Nx.cos(1.0) + |> assert_close(t(0.5403022766113281)) + + t([1.0, 2, 3]) + |> Nx.cos() + |> assert_close(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) + end + + test "cosh" do + Nx.cosh(1.0) + |> assert_close(t(1.5430806875228882)) + + t([1.0, 2, 3]) + |> Nx.cosh() + |> assert_close(t([1.5430806875228882, 3.762195587158203, 10.067662239074707])) + end + + test "log" do + Nx.log(1.0) + |> assert_equal(t(0.0)) + + t([1.0, 2, 3]) + |> Nx.log() + |> assert_equal(t([0.0, 0.6931471824645996, 1.0986123085021973])) + end + + test "tanh" do + Nx.tanh(1.0) + |> assert_equal(t(0.7615941762924194)) + + t([1.0, 2, 3]) + |> Nx.tanh() + |> assert_equal(t([0.7615941762924194, 0.9640275835990906, 0.9950547814369202])) + end + + test "abs" do + t([-2.0, -1, 0, 1, 2]) + |> Nx.abs() + |> assert_equal(t([2, 1, 0, 1, 2])) + end + + test "sqrt" do + Nx.sqrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.sqrt() + |> assert_equal(t([1.0, 1.4142135381698608, 1.7320507764816284])) + end + + test "rsqrt" do + Nx.rsqrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.rsqrt() + |> assert_equal(t([1.0, 0.7071067690849304, 0.5773502588272095])) + end + + test "argmax" do + Nx.argmax(4) + |> assert_equal(t(0)) + + # TODO: Support argmax without specific axis + # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + # |> Nx.argmax() + # |> assert_equal(t(10)) + + # t([2.0, 4.0]) + # |> Nx.argmax() + # |> assert_equal(t(1)) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + |> Nx.argmax(axis: 0) + |> assert_equal(t( + [ + [1, 0, 0], + [1, 1, 0] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmax(axis: :z) + |> assert_equal(t( + [ + [0, 2], + [0, 1] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmax(axis: :y, keep_axis: true) + |> assert_equal(t( + [ + [ + [0, 0, 0] + ], + [ + [0, 1, 0] + ] + ] + )) + end + + test "argmin" do + Nx.argmin(4) + |> assert_equal(t(0)) + + # TODO: Support argmin without specific axis + # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + # |> Nx.argmin() + # |> assert_equal(t(4)) + + # t([2.0, 4.0]) + # |> Nx.argmin() + # |> assert_equal(t(0)) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + |> Nx.argmin(axis: 0) + |> assert_equal(t( + [ + [0, 0, 0], + [0, 0, 0] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmin(axis: 1) + |> assert_equal(t( + [ + [1, 1, 0], + [1, 0, 0] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmin(axis: :z) + |> assert_equal(t( + [ + [1, 1], + [1, 2] + ] + )) + end + + test "acos" do + Nx.acos(0.10000000149011612) + |> assert_equal(t(1.4706288576126099)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.acos() + |> assert_equal(t([1.4706288576126099, 1.0471975803375244, 0.4510268568992615])) + end + + test "acosh" do + Nx.acosh(1.0) + |> assert_equal(t(0.0)) + + t([1.0, 2, 3]) + |> Nx.acosh() + |> assert_close(t([0.0, 1.316957950592041, 1.7627471685409546])) + end + + test "asin" do + Nx.asin(0.10000000149011612) + |> assert_equal(t(0.1001674234867096)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.asin() + |> assert_equal(t([0.1001674234867096, 0.5235987901687622, 1.1197694540023804])) + end + + test "asinh" do + Nx.asinh(1.0) + |> assert_close(t(0.8813735842704773)) + + t([1.0, 2, 3]) 
+ |> Nx.asinh() + |> assert_close(t([0.8813735842704773, 1.4436354637145996, 1.8184465169906616])) + end + + test "tan" do + Nx.tan(1.0) + |> assert_close(t(1.5574077367782593)) + + t([1.0, 2, 3]) + |> Nx.tan() + |> assert_close(t([1.5574077367782593, -2.185039758682251, -0.14254654943943024])) + end + + test "atan" do + Nx.atan(0.10000000149011612) + |> assert_close(t(0.09966865181922913)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.atan() + |> assert_close(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) + end + + test "atanh" do + Nx.atanh(0.10000000149011612) + |> assert_close(t(0.10033535212278366)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.atanh() + |> assert_close(t([0.10033535212278366, 0.5493061542510986, 1.4722193479537964])) + end + + test "ceil" do + t([-1, 0, 1]) + |> Nx.ceil() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.ceil() + |> assert_equal(t([-1.0, 0.0, 1.0, 2.0])) + end + + test "floor" do + t([-1, 0, 1]) + |> Nx.floor() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.floor() + |> assert_equal(t([-2.0, -1.0, 0.0, 1.0])) + end + + test "round" do + t([-1, 0, 1]) + |> Nx.round() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.round() + |> assert_equal(t([-2.0, -1.0, 1.0, 2.0])) + end + + test "cbrt" do + Nx.cbrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.cbrt() + |> assert_equal(t([1.0, 1.2599210739135742, 1.4422495365142822])) + end + + test "log1p" do + Nx.log1p(1.0) + |> assert_equal(t(0.6931471824645996)) + + t([1.0, 2, 3]) + |> Nx.log1p() + |> assert_equal(t([0.6931471824645996, 1.0986123085021973, 1.3862943649291992])) + end + + test "bitwise_and" do + Nx.bitwise_and(1, 0) + |> assert_equal(t(0)) + + t([0, 1, 2]) + |> Nx.bitwise_and(1) + |> assert_equal(t([0, 1, 0])) + + t([0, -1, -2]) + |> Nx.bitwise_and(-1) + |> assert_equal(t([0, -1, -2])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_and(t([0, 1, 0, 1])) + |> assert_equal(t([0, 0, 0, 1])) + end + + test "bitwise_or" do + Nx.bitwise_or(1, 0) + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.bitwise_or(1) + |> assert_equal(t([1, 1, 3])) + + t([0, -1, -2]) + |> Nx.bitwise_or(-1) + |> assert_equal(t([-1, -1, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_or(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 1])) + end + + test "bitwise_xor" do + Nx.bitwise_xor(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + + t([1, 2, 3], type: :u32) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + + t([-1, -2, -3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([-3, -4, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_xor(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 0])) + end + + test "bitwise_not" do + Nx.bitwise_not(1) + |> assert_equal(t(-2)) + + t([-1, 0, 1]) + |> Nx.bitwise_not() + |> assert_equal(t([0, -1, -2])) + + t([0, 1, 254, 255], type: :u8) + |> Nx.bitwise_not() + |> assert_equal(t([255, 254, 1, 0])) + end + + test "left_shift" do + Nx.left_shift(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, -1, -1]) + |> Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, -8, -16])) + + t([1, 2, 3], type: :u32) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 2, 3], type: :u32) + |> Nx.left_shift(t(2, type: :u8)) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, 0, 0], type: :u32) + |> Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, 0, 0])) + + t([1, 1, 0, 0], 
type: :u32)
+      |> Nx.left_shift(t([1, 2, 3, 4], type: :u8))
+      |> assert_equal(t([2, 4, 0, 0]))
+    end
+
+    test "right_shift" do
+      Nx.right_shift(1, 0)
+      |> assert_equal(t(1))
+
+      t([2, 4, 8])
+      |> Nx.right_shift(2)
+      |> assert_equal(t([0, 1, 2]))
+
+      t([16, 32, -64, -128])
+      |> Nx.right_shift(t([1, 2, 3, 4]))
+      |> assert_equal(t([8, 8, -8, -8]))
+
+      t([2, 4, 8], type: :u32)
+      |> Nx.right_shift(2)
+      |> assert_equal(t([0, 1, 2]))
+
+      t([16, 32, -64, -128], type: :u32)
+      |> Nx.right_shift(t([1, 2, 3, 4]))
+      |> assert_equal(t([8, 8, 536_870_904, 268_435_448]))
+    end
+
+    test "is_infinity" do
+      t([:infinity, :nan, :neg_infinity, 1, 0])
+      |> Nx.is_infinity()
+      |> assert_equal(t([1, 0, 1, 0, 0]))
+
+      t([:infinity, 1, :neg_infinity])
+      |> Nx.is_infinity()
+      |> assert_equal(t([1, 0, 1]))
+
+      # TODO: Not supported for :s64
+      # t([1, 0])
+      # |> Nx.is_infinity()
+      # |> assert_equal(t([0, 0]))
+    end
+
+    test "is_nan" do
+      t([:nan, 1.0, 0.0])
+      |> Nx.is_nan()
+      |> assert_equal(t([1, 0, 0]))
+
+      t([:nan, :infinity])
+      |> Nx.is_nan()
+      |> assert_equal(t([1, 0]))
+
+      # Complex not yet supported
+      # t(Complex.new(0, :nan))
+      # |> Nx.is_nan()
+      # |> assert_equal(t(1))
+
+      t([1.0, 0.0])
+      |> Nx.is_nan()
+      |> assert_equal(t([0, 0]))
+    end
+
+    test "logical_and" do
+      Nx.logical_and(1, t([-1, 0, 1]))
+      |> assert_equal(t([1, 0, 1]))
+
+      t([-1, 0, 1])
+      |> Nx.logical_and(t([[-1], [0], [1]]))
+      |> assert_equal(t(
+        [
+          [1, 0, 1],
+          [0, 0, 0],
+          [1, 0, 1]
+        ]
+      ))
+
+      t([-1.0, 0.0, 1.0])
+      |> Nx.logical_and(t([[-1], [0], [1]]))
+      |> assert_equal(t(
+        [
+          [1, 0, 1],
+          [0, 0, 0],
+          [1, 0, 1]
+        ]
+      ))
+    end
+
+    test "logical_or" do
+      Nx.logical_or(0, t([-1, 0, 1]))
+      |> assert_equal(t([1, 0, 1]))
+
+      t([-1, 0, 1])
+      |> Nx.logical_or(t([[-1], [0], [1]]))
+      |> assert_equal(t(
+        [
+          [1, 1, 1],
+          [1, 0, 1],
+          [1, 1, 1]
+        ]
+      ))
+
+      t([-1.0, 0.0, 1.0])
+      |> Nx.logical_or(t([[-1], [0], [1]]))
+      |> assert_equal(t(
+        [
+          [1, 1, 1],
+          [1, 0, 1],
+          [1, 1, 1]
+        ]
+      ))
+    end
+
+    test "logical_xor" do
+      0
+      |> Nx.logical_xor(t([-1, 0, 1]))
+      |> assert_equal(t([1, 0, 1]))
+
+      t([-1, 0, 1])
+      |> Nx.logical_xor(t([[-1], [0], [1]]))
+      |> assert_equal(t(
+        [
+          [0, 1, 0],
+          [1, 0, 1],
+          [0, 1, 0]
+        ]
+      ))
+
+      t([-1.0, 0.0, 1.0])
+      |> Nx.logical_xor(t([[-1], [0], [1]]))
+      |> assert_equal(t(
+        [
+          [0, 1, 0],
+          [1, 0, 1],
+          [0, 1, 0]
+        ]
+      ))
+    end
+
+    test "erf" do
+      Nx.erf(1.0)
+      |> assert_close(t(0.8427007794380188))
+
+      Nx.erf(t([1.0, 2, 3]))
+      |> assert_close(t([0.8427007794380188, 0.9953222870826721, 0.9999778866767883]))
+    end
+
+    test "erfc" do
+      Nx.erfc(1.0)
+      |> assert_close(t(0.15729920566082))
+
+      Nx.erfc(t([1.0, 2, 3]))
+      |> assert_close(t([0.15729920566082, 0.004677734803408384, 2.2090496713644825e-5]))
+    end
+
+    test "erf_inv" do
+      Nx.erf_inv(0.10000000149011612)
+      |> assert_close(t(0.08885598927736282))
+
+      t([0.10000000149011612, 0.5, 0.8999999761581421])
+      |> Nx.erf_inv()
+      |> assert_close(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606]))
+
+      t(0.10000000149011612, type: :f64)
+      |> Nx.erf_inv()
+      |> assert_close(t(0.08885598927736282, type: :f64))
+
+      t([0.10000000149011612, 0.5, 0.8999999761581421], type: :f64)
+      |> Nx.erf_inv()
+      |> assert_close(
+        t([0.0888559891358877, 0.47693629334671295, 1.1630870196442271], type: :f64)
+      )
+    end
+
+    test "sum/2" do
+      t(42)
+      |> Nx.sum()
+      |> assert_equal(t(42))
+
+      t([1, 2, 3])
+      |> Nx.sum()
+      |> assert_equal(t(6))
+
+      t([[1.0, 2.0], [3.0, 4.0]])
+      |> Nx.sum()
+      |> assert_equal(t(10.0))
+
+      t = Nx.iota({2, 2, 3}, names: [:x, :y, :z])
+      Nx.sum(t, axes: [:x])
+      |> 
assert_equal(t( + [ + [6, 8, 10], + [12, 14, 16] + ] + )) + + Nx.sum(t, axes: [:y]) + |> assert_equal(t( + [ + [3, 5, 7], + [15, 17, 19] + ] + )) + + Nx.sum(t, axes: [:z]) + |> assert_equal(t( + [ + [3, 12], + [21, 30] + ] + )) + + Nx.sum(t, axes: [:x, :z]) + |> assert_equal(t([24, 42])) + + Nx.sum(t, axes: [-3]) + |> assert_equal(t( + [ + [6, 8, 10], + [12, 14, 16] + ] + )) + + t([[1, 2], [3, 4]], names: [:x, :y]) + |> Nx.sum(axes: [:x], keep_axes: true) + |> assert_equal(t( + [ + [4, 6] + ] + )) + end + + test "to_batched/2" do + [first, second] = + Nx.iota({2, 2, 2}) + |> Nx.to_batched(1) + |> Enum.to_list() + + first + |> assert_equal(t( + [ + [ + [0, 1], + [2, 3] + ] + ] + )) + + second + |> assert_equal(t( + [ + [ + [4, 5], + [6, 7] + ] + ] + )) + + [first, second] = + Nx.iota({10}) + |> Nx.to_batched(5) + |> Enum.to_list() + + first + |> assert_equal(Nx.tensor([0, 1, 2, 3, 4])) + + second + |> assert_equal(Nx.tensor([5, 6, 7, 8, 9])) + + [first, second, third, fourth] = + Nx.iota({10}) + |> Nx.to_batched(3) + |> Enum.to_list() + + first + |> assert_equal(Nx.tensor([0, 1, 2])) + + second + |> assert_equal(Nx.tensor([3, 4, 5])) + + third + |> assert_equal(Nx.tensor([6, 7, 8])) + + fourth + |> assert_equal(Nx.tensor([9, 0, 1])) + + # TODO: Implement with discard + # [first, second] = + # Nx.iota({10}) + # |> Nx.to_batched(4, leftover: :discard) + # |> Enum.to_list() + + # first + # |> assert_equal(Nx.tensor([0, 1, 2, 3])) + + # second + # |> assert_equal(Nx.tensor([4, 5, 6, 7])) + end + + test "sigmoid/1" do + Nx.sigmoid(1.0) + |> assert_close(t(0.7310585975646973)) + + t([1.0, 2, 3]) + |> Nx.sigmoid() + |> assert_close(t([0.7310585975646973, 0.8807970881462097, 0.9525741338729858])) + end + + test "mean/1" do + t(42) + |> Nx.mean() + |> assert_equal(t(42.0)) + + t([1, 2, 3]) + |> Nx.mean() + |> assert_equal(t(2.0)) + + t([0.1, 0.2, 0.3]) + |> Nx.mean() + |> assert_equal(t(0.2)) + end + + test "pow" do + # Nx.pow(2, 4) + # |> assert_equal(t(16)) + + # t([1, 2, 3], type: :u32) + # |> Nx.pow(t(2, type: :u32)) + # |> assert_equal(t([1, 4, 9])) + + t([1.0, 2.0, 3.0]) + |> Nx.pow(2) + |> assert_equal(t([1.0, 4.0, 9.0])) + + 2 + |> Nx.pow(t([1.0, 2.0, 3.0])) + |> assert_equal(t([2.0, 4.0, 8.0])) + + # t([[2], [3]]) + # |> Nx.pow(t([[4, 5]])) + # |> assert_equal(t( + # [ + # [16, 32], + # [81, 243] + # ] + # )) + end + + test "conv" do + Nx.iota({9}) + |> Nx.reshape({1, 1, 3, 3}) + |> Nx.conv( + Nx.iota({4}) + |> Nx.reshape({4, 1, 1, 1}), + strides: [1, 1] + ) + |> assert_equal(t( + [ + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0] + ], + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0] + ], + [ + [0.0, 2.0, 4.0], + [6.0, 8.0, 10.0], + [12.0, 14.0, 16.0] + ], + [ + [0.0, 3.0, 6.0], + [9.0, 12.0, 15.0], + [18.0, 21.0, 24.0] + ] + ] + ] + )) + + # input/output permutation + + result = + Nx.iota({1, 3, 3, 6}) + |> Nx.conv( + 1 |> Nx.broadcast({2, 6, 1, 1}), + input_permutation: [0, 3, 1, 2], + output_permutation: [0, 3, 1, 2] + ) + + assert result.shape == {1, 3, 3, 2} + + result + |> assert_close(t( + [ + [ + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] + ] + ] + )) + + # Nx.iota({9}) + # |> Nx.reshape({1, 1, 3, 3}) + # |> Nx.conv( + # Nx.iota({8}) + # |> Nx.reshape({4, 1, 2, 1}), + # strides: 2, + # padding: :same, + # kernel_dilation: [2, 1] + # ) + # |> assert_equal(t( + # [ + # [ + # [ + # [3.0, 5.0], + # [0.0, 0.0] + # ], + # [ + # [9.0, 
15.0], + # [6.0, 10.0] + # ], + # [ + # [15.0, 25.0], + # [12.0, 20.0] + # ], + # [ + # [21.0, 35.0], + # [18.0, 30.0] + # ] + # ] + # ] + # )) + end + + test "reduce_max" do + t(42) + |> Nx.reduce_max() + |> assert_equal(t(42)) + + t(42.0) + |> Nx.reduce_max() + |> assert_equal(t(42.0)) + + t([1, 2, 3]) + |> Nx.reduce_max() + |> assert_equal(t(3)) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_max(axes: [:x]) + |> assert_equal(t([3, 1, 4])) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_max(axes: [:y]) + |> assert_equal(t([4, 2])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_max(axes: [:x, :z]) + # |> assert_equal(t([4, 8])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_max(axes: [:x, :z], keep_axes: true) + # |> assert_equal(t( + # [ + # [ + # [4], + # [8] + # ] + # ] + # )) + end + + test "reduce_min" do + Nx.reduce_min(t(42)) + |> assert_equal(t(42)) + + Nx.reduce_min(t(42.0)) + |> assert_equal(t(42.0)) + + Nx.reduce_min(t([1, 2, 3])) + |> assert_equal(t(1)) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_min(axes: [:x]) + |> assert_equal(t([2, 1, 1])) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_min(axes: [:y]) + |> assert_equal(t([1, 1])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_min(axes: [:x, :z]) + # |> assert_equal(t([1, 3])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_min(axes: [:x, :z], keep_axes: true) + # |> assert_equal(t( + # [ + # [ + # [1], + # [3] + # ] + # ] + # )) + end + + test "take_along_axis" do + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.take_along_axis( + t( + [ + [0, 0, 2, 2, 1, 1], + [2, 2, 1, 1, 0, 0] + ] + ), + axis: 1 + ) + |> assert_equal(t( + [ + [1, 1, 3, 3, 2, 2], + [6, 6, 5, 5, 4, 4] + ] + )) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.take_along_axis( + t( + [ + [0, 1, 1], + [1, 0, 0], + [0, 1, 0] + ] + ), + axis: 0 + ) + |> assert_equal(t( + [ + [1, 5, 6], + [4, 2, 3], + [1, 5, 3] + ] + )) + end + + test "gather" do + t([1, 2, 3, 4]) + |> Nx.gather(t([[3], [1], [2]])) + |> assert_equal(t([4, 2, 3])) + + # t([[1, 2], [3, 4]]) + # |> Nx.gather(t([[1, 1], [0, 1], [1, 0]])) + # |> assert_equal(t([4, 2, 3])) + + # t([[1, 2], [3, 4]]) + # |> Nx.gather(t([[[1, 1], [0, 0]], [[1, 0], [0, 1]]])) + # |> assert_equal(t( + # [ + # [4, 1], + # [3, 2] + # ] + # )) + + # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + # |> Nx.gather(t([[0, 0, 0], [0, 1, 1], [1, 1, 1]])) + # |> assert_equal(t([1, 12, 112])) + end + + test "indexed_add" do + t([1.0]) + |> Nx.indexed_add(t([[0], [0]]), t([1, 1])) + |> assert_equal(t([3.0])) + + t([1]) + |> Nx.indexed_add(t([[0], [0]]), t([1.0, 1.0])) + |> assert_equal(t([3.0])) + + t([1], type: :u8) + |> Nx.indexed_add(t([[0], [0]]), t([1, 1], type: :s64)) + |> assert_equal(t([3])) + + # Nx.iota({1, 2, 3}) + # |> Nx.indexed_add( + # t([[0, 0, 0], [0, 1, 1], [0, 0, 0], [0, 0, 2], [0, 1, 2]]), + # t([1, 3, 1, -2, 5]) + # ) + # |> assert_equal(t( + # [ + # [ + # [2, 1, 0], + # [3, 7, 10] + # ] + # ] + # )) + end + + test "transpose" do + t(1) + |> Nx.transpose() + |> assert_equal(t(1)) + + Nx.iota({2, 3, 4}, names: [:x, :y, :z]) + |> Nx.transpose() + |> assert_equal(t( + [ + [ + [0, 12], + [4, 16], + [8, 20] + ], + [ + [1, 13], + [5, 17], + [9, 21] + ], + [ + [2, 14], + [6, 18], + [10, 22] + ], + [ + [3, 15], + [7, 19], + [11, 23] + ] + ] + )) + + t(1) + |> Nx.transpose(axes: []) + |> assert_equal(t(1)) + + Nx.iota({2, 3, 4}, names: 
[:batch, :x, :y]) + |> Nx.transpose(axes: [2, 1, :batch]) + |> assert_equal(t( + [ + [ + [0, 12], + [4, 16], + [8, 20] + ], + [ + [1, 13], + [5, 17], + [9, 21] + ], + [ + [2, 14], + [6, 18], + [10, 22] + ], + [ + [3, 15], + [7, 19], + [11, 23] + ] + ] + )) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [:y, :batch, :x]) + |> assert_equal(t( + [ + [ + [0, 4, 8], + [12, 16, 20] + ], + [ + [1, 5, 9], + [13, 17, 21] + ], + [ + [2, 6, 10], + [14, 18, 22] + ], + [ + [3, 7, 11], + [15, 19, 23] + ] + ] + )) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [:batch, :y, :x]) + |> assert_equal(t( + [ + [ + [0, 4, 8], + [1, 5, 9], + [2, 6, 10], + [3, 7, 11] + ], + [ + [12, 16, 20], + [13, 17, 21], + [14, 18, 22], + [15, 19, 23] + ] + ] + )) + end + + test "put_slice" do + t([0, 1, 2, 3, 4]) + |> Nx.put_slice([2], Nx.tensor([5, 6])) + |> assert_equal(t([0, 1, 5, 6, 4])) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.put_slice([0, 0], t([[7, 8, 9], [10, 11, 12]])) + |> assert_equal(t( + [ + [7, 8, 9], + [10, 11, 12] + ] + )) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.put_slice([0, 1], t([[7, 8], [9, 10]])) + |> assert_equal(t( + [ + [1, 7, 8], + [4, 9, 10] + ] + )) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.put_slice([t(0), t(1)], t([[10.0, 11.0]])) + # |> assert_equal(t( + # [ + # [1.0, 10.0, 11.0], + # [4.0, 5.0, 6.0] + # ] + # )) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.put_slice([1, 1], t([[7, 8], [9, 10]])) + # |> assert_equal(t( + # [ + # [1, 7, 8], + # [4, 9, 10] + # ] + # )) + + t([ + [ + [1, 2], + [3, 4] + ], + [ + [4, 5], + [6, 7] + ] + ]) + |> Nx.put_slice([0, 0, 1], t([[[8], [9]], [[10], [11]]])) + |> assert_equal( + t([ + [ + [1, 8], + [3, 9] + ], + [ + [4, 10], + [6, 11] + ] + ]) + ) + end + + test "pad" do + t(1) + |> Nx.pad(0, []) + |> assert_equal(t(1)) + + t([1, 2, 3], names: [:data]) + |> Nx.pad(0, [{1, 1, 0}]) + |> assert_equal(t([0, 1, 2, 3, 0])) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.pad(0, [{0, 0, 1}, {0, 0, 1}]) + # |> assert_equal(t( + # [ + # [1, 0, 2, 0, 3], + # [0, 0, 0, 0, 0], + # [4, 0, 5, 0, 6] + # ] + # )) + + # Nx.pad(Nx.tensor([[1, 2, 3], [4, 5, 6]]), 0, [{1, 1, 0}, {1, 1, 0}]) + # [ + # [0, 0, 0, 0, 0], + # [0, 1, 2, 3, 0], + # [0, 4, 5, 6, 0], + # [0, 0, 0, 0, 0] + # ] + # > + + # tensor = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + # Nx.pad(tensor, 0, [{0, 2, 0}, {1, 1, 0}, {1, 0, 0}]) + # [ + # [ + # [0, 0, 0], + # [0, 1, 2], + # [0, 3, 4], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 5, 6], + # [0, 7, 8], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ] + # ] + + # tensor = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + # Nx.pad(tensor, 0, [{1, 0, 0}, {1, 1, 0}, {0, 1, 0}]) + # [ + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [1, 2, 0], + # [3, 4, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [5, 6, 0], + # [7, 8, 0], + # [0, 0, 0] + # ] + # ] + + # tensor = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + # Nx.pad(tensor, 0.0, [{1, 2, 0}, {1, 0, 0}, {0, 1, 0}]) + # [ + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [1.0, 2.0, 0.0], + # [3.0, 4.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [5.0, 6.0, 0.0], + # [7.0, 8.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # 
]
+      # ]
+
+      # Nx.pad(Nx.tensor([0, 1, 2, 3, 0]), 0, [{-1, -1, 0}])
+      # [1, 2, 3]
+
+      # tensor = Nx.tensor([
+      #   [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
+      #   [[0, 0, 0], [1, 2, 0], [3, 4, 0], [0, 0, 0]],
+      #   [[0, 0, 0], [5, 6, 0], [7, 8, 0], [0, 0, 0]]
+      # ])
+      # Nx.pad(tensor, 0, [{-1, 0, 0}, {-1, -1, 0}, {0, -1, 0}])
+      # [
+      #   [
+      #     [1, 2],
+      #     [3, 4]
+      #   ],
+      #   [
+      #     [5, 6],
+      #     [7, 8]
+      #   ]
+      # ]
+
+      # t([[0, 1, 2, 3], [0, 4, 5, 6]])
+      # |> Nx.pad(0, [{0, 0, 0}, {-1, 1, 0}])
+      # |> assert_equal(t(
+      #   [
+      #     [1, 2, 3, 0],
+      #     [4, 5, 6, 0]
+      #   ]
+      # ))
+
+      # t([[0, 1, 2], [3, 4, 5]], type: :f32)
+      # |> Nx.pad(0, [{-1, 2, 0}, {1, -1, 0}])
+      # |> assert_equal(t(
+      #   [
+      #     [0.0, 3.0, 4.0],
+      #     [0.0, 0.0, 0.0],
+      #     [0.0, 0.0, 0.0]
+      #   ]
+      # ))
+    end
+
+    test "take" do
+      t([[1, 2], [3, 4]])
+      |> Nx.take(t([1, 0, 1]))
+      |> assert_equal(t(
+        [
+          [3, 4],
+          [1, 2],
+          [3, 4]
+        ]
+      ))
+
+      t([[1, 2], [3, 4]])
+      |> Nx.take(t([1, 0, 1]), axis: 1)
+      |> assert_equal(t(
+        [
+          [2, 1, 2],
+          [4, 3, 4]
+        ]
+      ))
+
+      t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]])
+      |> Nx.take(t([1, 0, 1]), axis: 1)
+      |> assert_equal(t(
+        [
+          [
+            [11, 12],
+            [1, 2],
+            [11, 12]
+          ],
+          [
+            [111, 112],
+            [101, 102],
+            [111, 112]
+          ]
+        ]
+      ))
+
+      # t([[1, 2], [11, 12]])
+      # |> Nx.take(t([[0, 0], [1, 1], [0, 0]]), axis: 1)
+      # |> assert_equal(t(
+      #   [
+      #     [
+      #       [1, 1],
+      #       [2, 2],
+      #       [1, 1]
+      #     ],
+      #     [
+      #       [11, 11],
+      #       [12, 12],
+      #       [11, 11]
+      #     ]
+      #   ]
+      # ))
+
+      # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]])
+      # |> Nx.take(t([[0, 0, 0], [1, 1, 1], [0, 0, 0]]), axis: 1)
+      # |> assert_equal(t(
+      #   [
+      #     [
+      #       [
+      #         [1, 2],
+      #         [1, 2],
+      #         [1, 2]
+      #       ],
+      #       [
+      #         [11, 12],
+      #         [11, 12],
+      #         [11, 12]
+      #       ],
+      #       [
+      #         [1, 2],
+      #         [1, 2],
+      #         [1, 2]
+      #       ]
+      #     ],
+      #     [
+      #       [
+      #         [101, 102],
+      #         [101, 102],
+      #         [101, 102]
+      #       ],
+      #       [
+      #         [111, 112],
+      #         [111, 112],
+      #         [111, 112]
+      #       ],
+      #       [
+      #         [101, 102],
+      #         [101, 102],
+      #         [101, 102]
+      #       ]
+      #     ]
+      #   ]
+      # ))
+    end
+
+    test "clip" do
+      t([[1, 2, 3], [4, 5, 6]])
+      |> Nx.clip(2, 4)
+      |> assert_equal(t(
+        [
+          [2, 2, 3],
+          [4, 4, 4]
+        ]
+      ))
+
+      t([[1, 2, 3], [4, 5, 6]])
+      |> Nx.clip(2.0, 3)
+      |> assert_equal(t(
+        [
+          [2.0, 2.0, 3.0],
+          [3.0, 3.0, 3.0]
+        ]
+      ))
+
+      t([[1, 2, 3], [4, 5, 6]])
+      |> Nx.clip(t(2.0), Nx.max(1.0, 3.0))
+      |> assert_equal(t(
+        [
+          [2.0, 2.0, 3.0],
+          [3.0, 3.0, 3.0]
+        ]
+      ))
+
+      t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+      |> Nx.clip(2, 6.0)
+      |> assert_equal(t(
+        [
+          [2.0, 2.0, 3.0],
+          [4.0, 5.0, 6.0]
+        ]
+      ))
+
+      t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], type: :f32)
+      |> Nx.clip(1, 4)
+      |> assert_equal(t(
+        [
+          [1.0, 2.0, 3.0],
+          [4.0, 4.0, 4.0]
+        ]
+      ))
+    end
+
+    test "not_equal" do
+      Nx.not_equal(1, 2)
+      |> assert_equal(t(1))
+
+      t([1, 2, 3])
+      |> Nx.not_equal(t(1))
+      |> assert_equal(t([0, 1, 1]))
+
+      t([1, 1, 2])
+      |> Nx.not_equal(t([1, 2, 3]))
+      |> assert_equal(t([0, 1, 1]))
+
+      t([[1, 4, 2], [4, 5, 6]])
+      |> Nx.not_equal(t([[1, 3, 2], [4, 2, 1]]))
+      |> assert_equal(t(
+        [
+          [0, 1, 0],
+          [0, 1, 1]
+        ]
+      ))
+    end
+
+    if Candlex.Backend.cuda_available? 
do + test "different devices" do + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cuda})) + |> assert_equal(t([11, 22, 33])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cuda}) + |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cpu})) + |> assert_equal(t([11, 22, 33])) + end + end + + test "backend_transfer" do + t([1, 2, 3], backend: Nx.BinaryBackend) + |> Nx.backend_transfer({Candlex.Backend, device: :cpu}) + |> assert_equal(t([1, 2, 3])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> assert_equal(t([1, 2, 3])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.backend_transfer({Candlex.Backend, device: :cpu}) + |> assert_equal(t([1, 2, 3])) + end + end + + defp t(values, opts \\ []) do + opts = + [backend: Candlex.Backend] + |> Keyword.merge(opts) + + Nx.tensor(values, opts) + end + + defp check(value, opts) do + tensor = t(value, opts) + + tensor + # |> IO.inspect() + |> Nx.to_binary() + # |> IO.inspect() + + opts = + [backend: Nx.BinaryBackend] + |> Keyword.merge(opts) + + assert Nx.backend_copy(tensor) == t(value, opts) + assert Nx.backend_transfer(tensor) == t(value, opts) end end diff --git a/test/support/nx_case.ex b/test/support/nx_case.ex new file mode 100644 index 0000000..949a549 --- /dev/null +++ b/test/support/nx_case.ex @@ -0,0 +1,45 @@ +defmodule Nx.Case do + @moduledoc """ + Test case for tensor assertions + """ + + use ExUnit.CaseTemplate + + using do + quote do + import Nx.Case + end + end + + def assert_equal(left, right) do + equals = + left + |> Nx.equal(right) + # |> Nx.logical_or(Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))) + |> Nx.all() + |> Nx.to_number() + + if equals != 1 || Nx.shape(left) != Nx.shape(right) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + def assert_close(left, right) do + equals = + left + |> Nx.all_close(right, atol: 1.0e-4, rtol: 1.0e-4) + |> Nx.backend_transfer(Nx.BinaryBackend) + + if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end +end diff --git a/test/test_helper.exs b/test/test_helper.exs index 869559e..13695bf 100644 --- a/test/test_helper.exs +++ b/test/test_helper.exs @@ -1 +1,2 @@ +Application.put_env(:nx, :default_backend, Candlex.Backend) ExUnit.start()