
Use multiple arguments instead of a tuple for pushforward and pullback function? #53

Open
devmotion opened this issue Feb 9, 2022 · 6 comments
Labels
design Package structure and correctness help wanted Extra attention is needed question Inquiries and discussions


@devmotion
Member

It seems annoying that the pushforward and pullback functions accept tuples of co-tangents instead of multiple arguments. Is there a compelling reason for doing so, or was this a design decision that could be changed? In my opinion, the main annoyance is that one has to handle tuples of length 1 in a special way (e.g., in #51). (It also makes it impossible to work with actual single-argument functions that take a tuple as their only argument, but maybe this is not needed anyway.) Arguably, it is also cleaner to provide multiple arguments as, well, multiple arguments instead of a tuple.
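To make the two calling conventions concrete, here is a minimal sketch. Python is used purely for illustration, and every name below is hypothetical (none of them are AbstractDifferentiation.jl, ChainRules, or Zygote functions). It assumes a single-output function `g(x) = x^2` and a two-output function `h(x) = (2x, x^2)`, and writes a hand-coded pullback for each under the tuple convention and under the multiple-argument convention.

```python
# Hypothetical illustration only; not the AbstractDifferentiation.jl API.

def g(x):
    return x * x            # one output


def h(x):
    return (2 * x, x * x)   # two outputs, returned as a tuple


# Tuple convention: the pullback takes ONE argument holding a tuple of
# cotangents, one entry per output. Even g's pullback needs a 1-tuple.
def pullback_g_tuple(x):
    return lambda dys: 2 * x * dys[0]


def pullback_h_tuple(x):
    return lambda dys: 2 * dys[0] + 2 * x * dys[1]


# Multiple-argument convention: one positional cotangent per output,
# so the single-output case needs no wrapping at all.
def pullback_g_args(x):
    return lambda dy: 2 * x * dy


def pullback_h_args(x):
    return lambda dy1, dy2: 2 * dy1 + 2 * x * dy2


# Calls at x = 3 with unit cotangents:
pullback_g_tuple(3.0)((1.0,))      # the 1-tuple special case
pullback_g_args(3.0)(1.0)
pullback_h_tuple(3.0)((1.0, 1.0))
pullback_h_args(3.0)(1.0, 1.0)
```

Under the tuple convention, generic code that forwards cotangents has to wrap a lone cotangent into `(dy,)`. Under the multiple-argument convention, a caller holding a tuple `dys` can simply splat it (`pb(*dys)` here, `pb(dys...)` in Julia).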

@mohamed82008
Member

Yes, I think this can be considered along with #35.

@sethaxen
Member

sethaxen commented Feb 9, 2022

> It seems annoying that the pushforward and pullback functions accept tuples of co-tangents instead of multiple arguments. Is there a compelling reason for doing so or was this a design decision that could be changed?

While Julia functions may take multiple inputs, no Julia function returns multiple outputs. Instead, they might return a tuple of outputs, and the (co)tangent of a tuple is itself tuple-like. FWIW, this is consistent with how ChainRules behaves: there, such a (co)tangent would be represented as a `Tangent`, whereas here it is represented as a tuple. One could argue that, since AD.jl supports only functions whose inputs and outputs are arrays, a tuple returned by such a function can only be interpreted as multiple outputs, but that would be inconsistent with at least ChainRules and Zygote.

> It also makes it impossible to work with actual single-argument functions that take a tuple as their only argument, but maybe this is not needed anyway

I don't think functions with tuple inputs would be supported anyway.

@devmotion
Member Author

> While Julia functions may take multiple inputs, no Julia function returns multiple outputs. Instead, they might return a tuple of outputs. [...] FWIW, this is consistent with how ChainRules behaves [...]

Sure, multiple outputs are in fact just a tuple of outputs. But that does not mean we have to pass a tuple as input to the pullback and pushforward functions.

The current design is also not completely consistent with ChainRules. In ChainRules, one never has to deal with tuples of co-tangents of length 1: the pullback of a function with a single output just takes a single co-tangent, without wrapping it in a tuple. Not supporting tuples of length 1 would already remove the special case in #51, even if we stick with tuples for multiple outputs.

@sethaxen
Member

I think that, in general, AD.jl has a funny relationship with inputs and outputs. For example, `gradient` returns a tuple even for a single input, and `hessian` supports only a single input yet still returns a tuple. IMO this should be changed.

The pushforward of a function (talking about the actual pushforward, not the fusion of the pushforward and primal that frule encodes) should be structured the same as the primal in terms of inputs and outputs. The pullback is the adjoint of the pushforward and vice versa, so a useful check of consistency is whether the rules we choose are symmetric.

For example, the following rules would maintain this symmetry, and perhaps they make sense:

  • The adjoint of a single-argument function returns a single output (not a tuple)
  • The adjoint of a multi-argument function returns a tuple of outputs
  • The adjoint of a single-output (non-tuple) function takes a single input
  • The adjoint of a multi-output (tuple) function takes multiple inputs
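The four rules above can be sketched with hand-written scalar derivatives. Python is used only for illustration, and all names are hypothetical rather than part of any Julia package. The sketch assumes a two-input, two-output function `f(x, y) = (x + y, x*y)` and a one-input, one-output function `g(x) = x^2`. Tangents and cotangents are passed as separate positional arguments, and a tuple is returned only when there are several inputs or outputs. The helper `adjoint_gap_f` carries out the consistency check suggested above: for a correct pullback/pushforward pair, `<w, J v> - <J^T w, v>` is zero.

```python
# Hypothetical sketch of the four conventions; not an AD.jl, ChainRules,
# or Zygote API. All derivatives are written out by hand.

def f(x, y):                     # two inputs, two outputs (tuple)
    return (x + y, x * y)


def pushforward_f(x, y):
    # One tangent argument per input; returns a tuple (one per output).
    # Jacobian of f is [[1, 1], [y, x]].
    def pf(dx, dy):
        return (dx + dy, y * dx + x * dy)
    return pf


def pullback_f(x, y):
    # Adjoint: one cotangent argument per output; returns a tuple
    # (one per input), i.e. J^T applied to (du, dv).
    def pb(du, dv):
        return (du + y * dv, du + x * dv)
    return pb


def g(x):                        # one input, one output (not a tuple)
    return x * x


def pushforward_g(x):
    return lambda dx: 2 * x * dx     # single tangent in, single value out


def pullback_g(x):
    return lambda du: 2 * x * du     # single cotangent in, single value out


def adjoint_gap_f(x, y, dx, dy, du, dv):
    """Return <(du, dv), J (dx, dy)> - <J^T (du, dv), (dx, dy)>."""
    ou, ov = pushforward_f(x, y)(dx, dy)
    bx, by = pullback_f(x, y)(du, dv)
    return (du * ou + dv * ov) - (bx * dx + by * dy)
```

This sketch sidesteps the second ChainRules caveat mentioned below. Here `f` literally returns a tuple, and nothing in the language distinguishes "two outputs" from "one tuple-valued output". The sketch simply assumes the former.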

This is almost consistent with ChainRules. The key differences are that 1) in ChainRules, the function itself is treated as an argument, so there are no single-argument functions (or at least, I don't know of any examples where a rule is defined for a 0-argument function), hence all pullbacks return tuples; and 2) a function might actually return a tuple directly, so it's not safe to interpret a tuple return value as multiple outputs.

@gdalle
Member

gdalle commented Aug 9, 2023

I could try to give this a shot once #93 is merged

This was referenced Sep 19, 2023
@gdalle gdalle added design Package structure and correctness question Inquiries and discussions help wanted Extra attention is needed labels Oct 5, 2023
@gdalle
Member

gdalle commented Dec 21, 2023

I'm starting to work on this, and I'm wondering what to do with the lazy derivatives. Should we only allow them for a single input/output? Applying matrix multiplication to a tuple is a bit counterintuitive anyway.

4 participants