
Tensordot adjoints #287

Merged · 3 commits into HIPS:dev-1.2 · Sep 6, 2017

Conversation

@j-towns (Collaborator) commented on Sep 6, 2017

I've given tensordot the adjoint treatment, and this seems to have improved performance (though not as dramatically as for dot):

    before     after       ratio
  [4e517a01] [82068d0e]
-  366.45μs   241.00μs      0.66  bench_numpy_vjps.time_tensordot_1_0
-  380.09μs   238.82μs      0.63  bench_numpy_vjps.time_tensordot_1_2
-  379.86μs   236.12μs      0.62  bench_numpy_vjps.time_tensordot_1_1

There should also be a good knock-on effect for primitives like inner and matmul, which use these adjoints as their derivatives.
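
For reference, here's a rough sketch of what the adjoint of tensordot with respect to its first argument looks like for an integer axes argument (this is just illustrative, not the PR's actual implementation, and the name tensordot_vjp_0 is made up):

```python
import numpy as np

def tensordot_vjp_0(g, A, B, axes):
    """Illustrative adjoint of C = np.tensordot(A, B, axes) w.r.t. A,
    for integer `axes` only (contract the last `axes` dims of A with
    the first `axes` dims of B)."""
    a_free = A.ndim - axes  # number of uncontracted axes of A
    b_free = B.ndim - axes  # number of uncontracted axes of B
    # g has shape A_free + B_free; contracting its B_free axes against
    # B's free axes leaves A_free + contracted axes, i.e. A.shape.
    g_axes = list(range(a_free, a_free + b_free))
    b_axes = list(range(axes, axes + b_free))
    return np.tensordot(g, B, axes=(g_axes, b_axes))

A, B = np.random.randn(3, 4, 5), np.random.randn(5, 6)
g = np.ones((3, 4, 6))                       # cotangent of the output
assert tensordot_vjp_0(g, A, B, 1).shape == A.shape
```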

By the way, I just noticed that from NumPy 1.14 einsum is going to use the parallelized BLAS dot routine when possible. Maybe then we should have one efficient einsum adjoint routine which is used by all of the linear operators; that ought to have pretty good performance for everything and would be less code. @mattjj what do you reckon?

Edit: there are some more details in numpy/numpy#9425.
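
To make the einsum idea concrete, here's a hedged sketch of how the adjoint of a linear op can be read off its einsum formula, by swapping the differentiated operand's subscripts with the output subscripts. It only covers formulas where the differentiated operand has no summed-only indices, and the helper name is hypothetical:

```python
import numpy as np

def einsum_vjp_arg0(formula, g, B):
    """Adjoint of C = np.einsum(formula, A, B) w.r.t. A: swap A's
    subscripts with the output subscripts and contract g against B."""
    in_subs, out_subs = formula.split('->')
    a_subs, b_subs = in_subs.split(',')
    return np.einsum('{},{}->{}'.format(out_subs, b_subs, a_subs), g, B)

A, B = np.random.randn(3, 4), np.random.randn(4, 5)
g = np.ones((3, 5))                              # cotangent of A @ B
assert np.allclose(einsum_vjp_arg0('ij,jk->ik', g, B), g @ B.T)
```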

@mattjj (Contributor) commented on Sep 6, 2017

Awesome improvements on tensordot! Those performance wins are huge, and (as with dot) I wouldn't have guessed such big improvements were possible. Your point about knock-on effects makes sense too.

Wow, it sounds like einsum is getting a lot of attention. You're right that we might be able to use it everywhere! If it ends up being a performance win and less code then the decision would be easy, but it would be good to do some tests first.

@mattjj merged commit 0c72e22 into HIPS:dev-1.2 on Sep 6, 2017
@j-towns (Collaborator, Author) commented on Oct 14, 2017

Maybe then we should have one efficient Einsum adjoint routine which is used by all of the linear operators, and that ought to have pretty good performance for everything and would be less code.

I've been thinking a bit more about this today. I've noticed there's a (fairly obvious) relationship between the Jacobian operators (JOs, i.e. jvps) and the Jacobian-transpose operators (JTOs, i.e. vjps), which can be described something like this...

The JO of a primitive fun, with one argument (denoted x), will always have the form

JO(fun)(x)(u) = bilinear_operator(deriv(x), u)

where deriv is some function of x (and sometimes, in practice, also of ans = fun(x)), and bilinear_operator is a function of two arguments which is linear in its first (second) argument holding its second (first) argument constant. For example, for unary ufuncs bilinear_operator is often just lambda x, v: x * v, but sometimes it's more complicated and is expressed as a dot, tensordot or einsum.
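
As a concrete (hypothetical) illustration with np.sin as the primitive: deriv is np.cos and the bilinear operator is elementwise multiplication:

```python
import numpy as np

deriv = np.cos                           # deriv(x), here independent of ans
bilinear_operator = lambda d, u: d * u   # linear in each argument separately

def jo_sin(x):                           # JO(np.sin)(x)
    return lambda u: bilinear_operator(deriv(x), u)

x, u = np.random.randn(5), np.random.randn(5)
assert np.allclose(jo_sin(x)(u), np.cos(x) * u)
```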

Using the above notation, the JTO of fun is (I think) always of the form

JTO(fun)(x)(v) = transpose(bilinear_operator)(deriv(x), v)

where transpose is a function which (somehow) transposes the bilinear operator w.r.t. its second argument; for example, it would map tensordot to tensordot_adjoint_1 from this PR. Note that deriv(x) is unchanged, so perhaps for the primitive JO/JTO definitions we could (and should) separate the definition of deriv from the specification of bilinear_operator. That way deriv could be shared between the JO and the JTO. Furthermore, we could use einsum notation to specify bilinear_operator, which we know how to transpose, since we already do it in the vjp of einsum.
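
Here's a small numerical check of that transpose relation, <bilinear_operator(d, u), v> == <u, transpose(bilinear_operator)(d, v)>, using a tensordot-style bilinear operator. The lambdas below just stand in for tensordot and for the role tensordot_adjoint_1 plays; they're not the PR's code:

```python
import numpy as np

d = np.random.randn(4, 3)                      # plays the role of deriv(x)
u, v = np.random.randn(3), np.random.randn(4)
bilinear = lambda d, u: np.tensordot(d, u, axes=1)             # d @ u
bilinear_T = lambda d, v: np.tensordot(v, d, axes=([0], [0]))  # d.T @ v
assert np.allclose(np.dot(bilinear(d, u), v), np.dot(u, bilinear_T(d, v)))
```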

I think we could save ourselves from separately specifying jvps and vjps in this way. We would need to think about how to do transpose efficiently. This could potentially make implementing #280 a little more straightforward too, by isolating the changes needed in bilinear_operator and transpose. There might be issues with this approach and/or subtleties that I'm missing, and I haven't thought yet about whether it would work OK for complex derivatives.
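
For what it's worth, a minimal sketch of that single-specification idea, under the assumptions above (the names make_jo/make_jto and the subscript manipulation are made up, and it only handles formulas where no index of the second operand is summed away entirely):

```python
import numpy as np

def make_jo(formula, deriv):
    # JO(fun)(x)(u) = einsum(formula, deriv(x), u)
    return lambda x: lambda u: np.einsum(formula, deriv(x), u)

def make_jto(formula, deriv):
    # Transpose the bilinear operator w.r.t. its second argument by
    # swapping that argument's subscripts with the output subscripts.
    in_subs, out_subs = formula.split('->')
    d_subs, u_subs = in_subs.split(',')
    t_formula = '{},{}->{}'.format(d_subs, out_subs, u_subs)
    return lambda x: lambda v: np.einsum(t_formula, deriv(x), v)

# Example: for fun = np.sin, deriv = np.cos and the bilinear operator
# is elementwise, so the JO and JTO coincide.
jo_sin, jto_sin = make_jo('i,i->i', np.cos), make_jto('i,i->i', np.cos)
```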
