Skip to content

Commit

Permalink
Add tensordot benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
j-towns committed Sep 6, 2017
1 parent d5b45ef commit 82068d0
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions benchmarks/bench_numpy_vjps.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
B = npr.randn(2, 3, 5, 4)
g = npr.randn(2, 3, 4, 2, 3, 4)

def time_dot():
np.dot(A, B)

def time_dot_0():
dot_0(A, B, g)

Expand All @@ -44,3 +41,43 @@ def time_dot_1_1():

def time_dot_1_2():
dot_1_2(A, B, g)

tensordot_0 = lambda A, B, G: make_vjp(np.tensordot)(A, B, 2)[0](G)
tensordot_1 = lambda A, B, G: make_vjp(np.tensordot, argnum=1)(A, B, 2)[0](G)

tensordot_0_0 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)
tensordot_0_1 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)
tensordot_0_2 = lambda A, B, G: make_vjp(tensordot_0)(A, B, G)[0](A)

tensordot_1_0 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)
tensordot_1_1 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)
tensordot_1_2 = lambda A, B, G: make_vjp(tensordot_1)(A, B, G)[0](B)

A = npr.randn(2, 3, 5, 4)
B = npr.randn(5, 4, 2, 3)
G = npr.randn(2, 3, 2, 3)

def time_tensordot_0():
tensordot_0(A, B, G)

def time_tensordot_1():
tensordot_1(A, B, G)

def time_tensordot_0_0():
tensordot_0_0(A, B, G)

def time_tensordot_0_1():
tensordot_0_1(A, B, G)

def time_tensordot_0_2():
tensordot_0_2(A, B, G)

def time_tensordot_1_0():
tensordot_1_0(A, B, G)

def time_tensordot_1_1():
tensordot_1_1(A, B, G)

def time_tensordot_1_2():
tensordot_1_2(A, B, G)

0 comments on commit 82068d0

Please sign in to comment.