Skip to content

Commit

Permalink
feat: dot/2 supports dot product of n-D tensor with 1-D tensor (vector)
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Nov 1, 2023
1 parent 1ed960f commit 7225ecd
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 15 deletions.
19 changes: 17 additions & 2 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,27 @@ defmodule Candlex.Backend do
)
when tuple_size(left_shape) == 1 and tuple_size(right_shape) == 1 do
from_nx(left)
|> Native.vec_dot(from_nx(right))
|> Native.dot(from_nx(right))
|> unwrap!()
|> to_nx(out)
end

def dot(
%T{type: _out_type} = out,
%T{shape: left_shape, type: _left_type} = left,
[left_axis] = _left_axes,
[] = _left_batched_axes,
%T{shape: right_shape, type: _right_type} = right,
[0] = _right_axes,
[] = _right_batched_axes
)
when tuple_size(left_shape) > 1 and tuple_size(right_shape) == 1 and left_axis == tuple_size(left_shape) - 1 do
from_nx(left)
|> Native.dot(from_nx(right))
|> unwrap!()
|> to_nx(out)
end

@impl true
def dot(
%T{type: _out_type} = out,
%T{shape: left_shape, type: _left_type} = left,
Expand Down
4 changes: 2 additions & 2 deletions lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ defmodule Candlex.Native do
:bitwise_or,
:bitwise_xor,
:divide,
:dot,
:equal,
:greater,
:greater_equal,
Expand All @@ -107,8 +108,7 @@ defmodule Candlex.Native do
:quotient,
:remainder,
:right_shift,
:subtract,
:vec_dot
:subtract
] do
def unquote(op)(_left, _right), do: error()
end
Expand Down
2 changes: 1 addition & 1 deletion native/candlex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ rustler::init! {
tensors::permute,
tensors::slice_scatter,
tensors::pad_with_zeros,
tensors::vec_dot,
tensors::dot,
tensors::matmul,
tensors::abs,
tensors::acos,
Expand Down
4 changes: 2 additions & 2 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ pub fn divide(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError>
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn vec_dot(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
Ok(ExTensor::new(left.mul(right.deref())?.sum_all()?))
pub fn dot(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
Ok(ExTensor::new(left.mul(&right.broadcast_as(left.shape())?)?.sum(left.rank() - 1)?))
}

macro_rules! unary_nif {
Expand Down
20 changes: 12 additions & 8 deletions test/candlex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -542,14 +542,18 @@ defmodule CandlexTest do

# Dot product of vector and n-D tensor

# t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k])
# |> Nx.dot(t([5.0, 10], names: [:x]))
# |> assert_equal(t(
# [
# [25, 55],
# [85, 115]
# ]
# ))
t([[0.0]])
|> Nx.dot(t([55.0]))
|> assert_equal(t([0.0]))

t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]])
|> Nx.dot(t([5.0, 10]))
|> assert_equal(t(
[
[25, 55],
[85, 115]
]
))

# t([5.0, 10], names: [:x])
# |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j]))
Expand Down

0 comments on commit 7225ecd

Please sign in to comment.