From 7225ecdff0e404ffeeff7eb8ea4de93cedb5ee27 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 1 Nov 2023 11:33:30 -0300 Subject: [PATCH] feat: dot/2 supports dot product of n-D tensor with 1-D tensor (vector) --- lib/candlex/backend.ex | 19 +++++++++++++++++-- lib/candlex/native.ex | 4 ++-- native/candlex/src/lib.rs | 2 +- native/candlex/src/tensors.rs | 4 ++-- test/candlex_test.exs | 20 ++++++++++++-------- 5 files changed, 34 insertions(+), 15 deletions(-) diff --git a/lib/candlex/backend.ex b/lib/candlex/backend.ex index b9b3d7f..ed7e178 100644 --- a/lib/candlex/backend.ex +++ b/lib/candlex/backend.ex @@ -517,12 +517,27 @@ defmodule Candlex.Backend do ) when tuple_size(left_shape) == 1 and tuple_size(right_shape) == 1 do from_nx(left) - |> Native.vec_dot(from_nx(right)) + |> Native.dot(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + + def dot( + %T{type: _out_type} = out, + %T{shape: left_shape, type: _left_type} = left, + [left_axis] = _left_axes, + [] = _left_batched_axes, + %T{shape: right_shape, type: _right_type} = right, + [0] = _right_axes, + [] = _right_batched_axes + ) + when tuple_size(left_shape) > 1 and tuple_size(right_shape) == 1 and left_axis == tuple_size(left_shape) - 1 do + from_nx(left) + |> Native.dot(from_nx(right)) |> unwrap!() |> to_nx(out) end - @impl true def dot( %T{type: _out_type} = out, %T{shape: left_shape, type: _left_type} = left, diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index 13f2375..8402b49 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -89,6 +89,7 @@ defmodule Candlex.Native do :bitwise_or, :bitwise_xor, :divide, + :dot, :equal, :greater, :greater_equal, @@ -107,8 +108,7 @@ defmodule Candlex.Native do :quotient, :remainder, :right_shift, - :subtract, - :vec_dot + :subtract ] do def unquote(op)(_left, _right), do: error() end diff --git a/native/candlex/src/lib.rs b/native/candlex/src/lib.rs index 243f42c..1ad3e56 100644 --- a/native/candlex/src/lib.rs +++ b/native/candlex/src/lib.rs @@ -68,7 +68,7 @@ rustler::init! { tensors::permute, tensors::slice_scatter, tensors::pad_with_zeros, - tensors::vec_dot, + tensors::dot, tensors::matmul, tensors::abs, tensors::acos, diff --git a/native/candlex/src/tensors.rs b/native/candlex/src/tensors.rs index 42730cd..7cbec18 100644 --- a/native/candlex/src/tensors.rs +++ b/native/candlex/src/tensors.rs @@ -347,8 +347,8 @@ pub fn divide(left: ExTensor, right: ExTensor) -> Result } #[rustler::nif(schedule = "DirtyCpu")] -pub fn vec_dot(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.mul(right.deref())?.sum_all()?)) +pub fn dot(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.mul(&right.broadcast_as(left.shape())?)?.sum(left.rank() - 1)?)) } macro_rules! unary_nif { diff --git a/test/candlex_test.exs b/test/candlex_test.exs index f5e3c3e..cc8efdf 100644 --- a/test/candlex_test.exs +++ b/test/candlex_test.exs @@ -542,14 +542,18 @@ defmodule CandlexTest do # Dot product of vector and n-D tensor - # t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k]) - # |> Nx.dot(t([5.0, 10], names: [:x])) - # |> assert_equal(t( - # [ - # [25, 55], - # [85, 115] - # ] - # )) + t([[0.0]]) + |> Nx.dot(t([55.0])) + |> assert_equal(t([0.0])) + + t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) + |> Nx.dot(t([5.0, 10])) + |> assert_equal(t( + [ + [25, 55], + [85, 115] + ] + )) # t([5.0, 10], names: [:x]) # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j]))