Nx.dot: complete implementation #14

Closed
7 tasks done
grzuy opened this issue Oct 31, 2023 · 0 comments

grzuy commented Oct 31, 2023

Nx.dot/2:

Returns the dot product of two tensors.

Given a and b, computes the dot product according to the following rules:

  • If both a and b are scalars, it is equivalent to a * b.
  • If a is a scalar and b is a tensor, it is equivalent to Nx.multiply(a, b).
  • If a is a tensor and b is a scalar, it is equivalent to Nx.multiply(a, b).
  • If both a and b are 1-D tensors (vectors), it is the sum of the element-wise product between a and b. The lengths of a and b must be equal.
  • If both a and b are 2-D tensors (matrices), it is equivalent to matrix-multiplication.
  • If either a or b is a 1-D tensor, and the other is an n-D tensor, it is the sum of the element-wise product along the last axis of a or b. The length of the 1-D tensor must match the last dimension of the n-D tensor.
  • If a is an n-D tensor and b is an m-D tensor, it is the sum of the element-wise product along the last axis of a and the second-to-last axis of b. The last dimension of a must match the second-to-last dimension of b.
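The rules above also determine the shape of the result. A minimal sketch of that shape logic in plain Elixir follows. The module `DotShapeSketch` and its `output_shape/2` function are hypothetical helpers for illustration, not part of Nx or candle, and shapes are given as lists of dimensions. It assumes that when `a` is 1-D and `b` has two or more dimensions, the contraction runs over the second-to-last axis of `b`, which is what the commented-out vector-by-matrix test in this issue suggests.

```elixir
defmodule DotShapeSketch do
  # Hypothetical helper (not an Nx API): predicts the output shape of dot(a, b)
  # from the two input shapes, following the rules listed above.

  # Scalar with anything: element-wise multiply, so the other shape is kept.
  def output_shape([], shape_b), do: shape_b
  def output_shape(shape_a, []), do: shape_a

  # n-D with 1-D (including vector with vector):
  # contract the last axis of a with the only axis of b.
  def output_shape(shape_a, [len_b]) do
    {lead_a, [k]} = Enum.split(shape_a, -1)
    check!(k, len_b)
    lead_a
  end

  # 1-D or n-D with m-D (m >= 2):
  # contract the last axis of a with the second-to-last axis of b.
  def output_shape(shape_a, shape_b) do
    {lead_a, [k_a]} = Enum.split(shape_a, -1)
    {lead_b, [k_b, last_b]} = Enum.split(shape_b, -2)
    check!(k_a, k_b)
    lead_a ++ lead_b ++ [last_b]
  end

  # The contracted dimensions must have the same length.
  defp check!(k, k), do: :ok

  defp check!(k_a, k_b) do
    raise ArgumentError, "dot/2 expects matching contracted dimensions, got #{k_a} and #{k_b}"
  end
end
```

For example, `DotShapeSketch.output_shape([2, 3, 3], [1, 3, 3])` gives `[2, 3, 1, 3]`, the shape of the expected result in the commented-out n-D by m-D test.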

Cases:

  • If both a and b are scalars, it is equivalent to a * b.
  • If a is a scalar and b is a tensor, it is equivalent to Nx.multiply(a, b).
  • If a is a tensor and b is a scalar, it is equivalent to Nx.multiply(a, b).
  • If both a and b are 1-D tensors (vectors), it is the sum of the element-wise product between a and b. The lengths of a and b must be equal.
  • If both a and b are 2-D tensors (matrices), it is equivalent to matrix-multiplication.
  • If either a or b is a 1-D tensor, and the other is an n-D tensor, it is the sum of the element-wise product along the last axis of a or b. The length of the 1-D tensor must match the last dimension of the n-D tensor. (feat: dot/2 supports receiving 1-D tensors (vectors) #17)
  • If a is an n-D tensor and b is an m-D tensor, it is the sum of the element-wise product along the last axis of a and the second-to-last axis of b. The last dimension of a must match the second-to-last dimension of b. (feat: Nx.dot, support n x m #51)

Right now it is only partially implemented, given the limitation of having only matmul available in candle-core.
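One possible way to cover the vector cases with only a 2-D matmul is to view the vector as a one-column (or one-row) matrix, multiply, and then drop the extra axis. The sketch below illustrates the idea on plain nested lists. `MatmulOnlySketch` and all of its functions are hypothetical stand-ins written for this example. They are not candle-core or Nx calls, and this is not necessarily how the linked PRs solved it.

```elixir
defmodule MatmulOnlySketch do
  # Plain-list stand-in for a backend that only offers 2-D matrix multiplication.
  def matmul(a, b) do
    cols_b = transpose(b)

    for row <- a do
      for col <- cols_b do
        row |> Enum.zip(col) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
      end
    end
  end

  # Swaps rows and columns of a list-of-lists matrix.
  def transpose(matrix), do: matrix |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  # matrix . vector: treat the vector as a {k, 1} column matrix,
  # multiply, then drop the trailing axis of length 1.
  def dot_matrix_vector(matrix, vector) do
    column = Enum.map(vector, &[&1])
    matrix |> matmul(column) |> Enum.map(fn [x] -> x end)
  end

  # vector . vector: a {1, k} row times a {k, 1} column, then unwrap the 1x1 result.
  def dot_vectors(a, b) do
    [[result]] = matmul([a], Enum.map(b, &[&1]))
    result
  end
end
```

A tensor with more than two dimensions could be handled the same way by first flattening its leading axes into rows and restoring them after the multiplication.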

See the partial test coverage we currently have:

test "dot/2" do
  # Dot product of scalars
  Nx.dot(5, 5)
  |> assert_equal(t(25))

  Nx.dot(-2.0, 5.0)
  |> assert_equal(t(-10.0))

  Nx.dot(2, 2.0)
  |> assert_equal(t(4.0))

  # Dot product of vectors
  t([1, 2, 3])
  |> Nx.dot(t([4, 5, 6]))
  |> assert_equal(t(32))

  t([1.0, 2, 3])
  |> Nx.dot(t([1, 2, 3]))
  |> assert_equal(t(14.0))

  # Dot product of matrices (2-D tensors)
  # TODO: Candle matmul doesn't support integers yet
  # t([[1, 2, 3], [4, 5, 6]])
  # |> Nx.dot(t([[7, 8], [9, 10], [11, 12]]))
  # |> assert_equal(t(
  #   [
  #     [58, 64],
  #     [139, 154]
  #   ]
  # ))
  t([[1.0, 2, 3], [4, 5, 6]])
  |> Nx.dot(t([[7, 8], [9, 10], [11, 12]]))
  |> assert_equal(
    t([
      [58.0, 64],
      [139, 154]
    ])
  )

  # Dot product of vector and n-D tensor
  t([[0.0]])
  |> Nx.dot(t([55.0]))
  |> assert_equal(t([0.0]))

  t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]])
  |> Nx.dot(t([5, 10]))
  |> assert_equal(
    t([
      [25.0, 55],
      [85, 115]
    ])
  )

  # t([5.0, 10], names: [:x])
  # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j]))
  # |> assert_equal(t(
  #   [45, 60, 75]
  # ))

  # t([[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]], names: [:shard, :batch, :x, :y, :z])
  # |> Nx.dot(t([2.0, 2.0], names: [:data]))
  # |> assert_equal(t(
  #   [
  #     [
  #       [
  #         [6.0, 14.0],
  #         [22.0, 30.0]
  #       ]
  #     ]
  #   ]
  # ))

  # Dot product of n-D and m-D tensors
  # t([[[1.0, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:x, :y, :z])
  # |> Nx.dot(t([[[1.0, 2, 3], [3, 4, 5], [5, 6, 7]]], names: [:i, :j, :k]))
  # |> assert_equal(t(
  #   [
  #     [
  #       [
  #         [22, 28, 34]
  #       ],
  #       [
  #         [49, 64, 79]
  #       ],
  #       [
  #         [76, 100, 124]
  #       ]
  #     ],
  #     [
  #       [
  #         [22, 28, 34]
  #       ],
  #       [
  #         [49, 64, 79]
  #       ],
  #       [
  #         [76, 100, 124]
  #       ]
  #     ]
  #   ]
  # ))
end

test "dot/6" do
  # Contracting along axes
  t1 = t([[1.0, 2], [3, 4]], names: [:x, :y])
  t2 = t([[10.0, 20], [30, 40]], names: [:height, :width])

  t1
  |> Nx.dot([0], [], t2, [0], [])
  |> assert_equal(
    t([
      [100, 140],
      [140, 200]
    ])
  )

  # TODO:
  t1
  |> Nx.dot([0], [], t2, [1], [])
  |> assert_equal(
    t([
      [70, 150],
      [100, 220]
    ])
  )

  t1
  |> Nx.dot([1], [], t2, [0], [])
  |> assert_equal(
    t([
      [70, 100],
      [150, 220]
    ])
  )

  # t1
  # |> Nx.dot([1], [], t2, [1], [])
  # |> assert_equal(t(
  #   [
  #     [50, 110],
  #     [110, 250]
  #   ]
  # ))

  # t1
  # |> Nx.dot([0, 1], [], t2, [0, 1], [])
  # |> assert_equal(t(300))
end
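For two matrices, each single-axis contraction exercised in the "dot/6" test can be expressed as a plain matmul after transposing whichever operand is contracted along the "wrong" axis. The sketch below shows this on nested lists. `ContractSketch`, `contract/4`, `matmul/2` and `transpose/1` are hypothetical names for this example only, not Nx or candle-core functions, and batch axes are not handled.

```elixir
defmodule ContractSketch do
  # Contracts axis_a of matrix a with axis_b of matrix b using only matmul.
  # matmul itself contracts axis 1 of the left operand with axis 0 of the right one,
  # so any operand contracted along its other axis is transposed first.
  def contract(a, axis_a, b, axis_b) when axis_a in [0, 1] and axis_b in [0, 1] do
    left = if axis_a == 1, do: a, else: transpose(a)
    right = if axis_b == 0, do: b, else: transpose(b)
    matmul(left, right)
  end

  # Plain-list 2-D matrix multiplication.
  def matmul(a, b) do
    cols_b = transpose(b)

    for row <- a do
      for col <- cols_b do
        row |> Enum.zip(col) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
      end
    end
  end

  # Swaps rows and columns of a list-of-lists matrix.
  def transpose(matrix), do: matrix |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
end
```

With `t1 = [[1, 2], [3, 4]]` and `t2 = [[10, 20], [30, 40]]`, `ContractSketch.contract(t1, 0, t2, 0)` yields `[[100, 140], [140, 200]]`, the same values as the first assertion in the test.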

grzuy changed the title from "Complete Nx.dot" to "Nx.dot: more complete implementation" on Nov 21, 2023
grzuy changed the title from "Nx.dot: more complete implementation" to "Nx.dot: complete implementation" on Nov 22, 2023
grzuy closed this as completed on Nov 29, 2023