
API for user code to detect if it's being differentiated #66

Open
ToucheSir opened this issue Nov 24, 2022 · 11 comments
Labels: feature (New feature or request), question (Inquiries and discussions)


@ToucheSir

The most recent iteration of this discussion was FluxML/NNlib.jl#434, but it goes back to JuliaDiff/ChainRulesCore.jl#547 and much further. Given that not all ADs are ChainRules compatible but this package seeks to support (/eventually be supported by) all ADs, it feels like a much better home for such functionality than a domain-specific package like NNlib.

@dfdx

dfdx commented Jan 8, 2023

As we briefly discussed in FluxML/NNlib.jl#434, there's another way to think about the problem. If I understand the motivation correctly, there are functions that behave differently during training (i.e. being differentiated) and normal execution, e.g. dropout or batch normalization. Let's use this generic example for simplicity:

function func(x, training=false)
    y = nothing
    if training
        y = do_this(x)
    else
        y = do_that(x)
    end
    return y
end

Understanding branches is a hard problem for AD, so many systems assume that execution always takes the same path. For example, Yota first traces the code with the provided inputs and assumes the path never changes. After tracing, it sees the equivalent of either this:

# when training == true
function func(x, not_used_training_flag)
    return do_this(x)
end

or this:

# when training == false
function func(x, not_used_training_flag)
    return do_that(x)
end
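To make the tracing behaviour concrete, here is a minimal, dependency-free sketch of a toy tracer. The `trace` helper, the `tape` keyword and the bodies of `do_this`/`do_that` are all hypothetical (this is not Yota's API). It runs `func` once with example inputs, records which branch executed, and replays that recording later, so the flag passed at replay time has no effect.

```julia
# Hypothetical stand-ins for the two branches.
do_this(x) = 2x
do_that(x) = x + 1

# The function from above, instrumented so the toy tracer can record which primitive ran.
function func(x, training=false; tape=nothing)
    if training
        tape === nothing || push!(tape, do_this)
        return do_this(x)
    else
        tape === nothing || push!(tape, do_that)
        return do_that(x)
    end
end

# Toy tracer: run once with example inputs, keep the recorded ops, replay them later.
function trace(f, args...)
    tape = Function[]
    f(args...; tape=tape)
    # The replayed function accepts a flag but ignores it: the branch was frozen at trace time.
    return (x, _flag) -> foldl((acc, op) -> op(acc), tape; init=x)
end

traced = trace(func, 3.0, true)
traced(3.0, false)  # 6.0, i.e. do_this(3.0), even though the flag is now false
```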

This leads to a lot of funny things. For example, if we hardcode training=true into rrule() (which sounds reasonable, because rrule is part of differentiation, right?), then the branch taken will depend on whether we first trace and then transform rrule() or first transform and then trace. Ironically, this is exactly the implementation detail I changed in Yota in the most recent major release, and didn't even know it breaks the behavior!

within_gradient() attempts to solve this problem by letting the function author take the appropriate branch depending on the context. But in a multistage execution (e.g. tracing -> differentiation -> some other transformations), analyzing corner cases becomes quite hard.
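The ordering problem can be reproduced without any AD package. In this sketch a global `Ref` stands in for whatever mechanism an AD uses to answer "are we being differentiated?". `DIFFERENTIATING`, `within_gradient_sketch` and `differentiating` are hypothetical names, and this is not necessarily how any real `within_gradient` works. The same `layer` freezes a different branch depending on whether tracing happens before, or inside, the differentiation context.

```julia
# Hypothetical global flag standing in for "are we being differentiated?".
const DIFFERENTIATING = Ref(false)
within_gradient_sketch() = DIFFERENTIATING[]

# Run f with the flag set, restoring the previous value afterwards.
function differentiating(f)
    old = DIFFERENTIATING[]
    DIFFERENTIATING[] = true
    try
        return f()
    finally
        DIFFERENTIATING[] = old
    end
end

# A dropout-like function whose branch depends on the context.
layer(x) = within_gradient_sketch() ? :training_branch : :inference_branch

# Order A: trace first (outside the context), then transform the trace.
trace_first = layer(1.0)                            # :inference_branch is frozen
# Order B: enter the differentiation context first, then trace.
transform_first = differentiating(() -> layer(1.0)) # :training_branch is frozen
```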

Another approach is to provide a simple AD-friendly way to include branches, a kind of cond function. Just to illustrate the idea:

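using ChainRulesCore

# opaque conditional: runs exactly one of two zero-argument branches
cond(flag::Bool, fn1, fn2) = flag ? fn1() : fn2()

function func_with_cond(x, flag)
    y = cond(flag, () -> do_this(x), () -> do_that(x))
    return y
end

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cond), flag::Bool, fn1, fn2)
    # differentiate only the branch that is actually taken
    res, pb = rrule_via_ad(config, flag ? fn1 : fn2)
    function cond_pullback(dy)
        dfn, = pb(dy)  # tangent of the closure, i.e. of its captured variables such as x
        # tangents for (cond, flag, fn1, fn2); the untaken branch gets ZeroTangent()
        return flag ? (NoTangent(), NoTangent(), dfn, ZeroTangent()) :
                      (NoTangent(), NoTangent(), ZeroTangent(), dfn)
    end
    return res, cond_pullback
end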

The code is just an illustration, but I hope you get the idea. An AD system can then overload cond() (e.g. Tracker.jl), rewrite it to rrule(::typeof(cond), ...) or do something more sophisticated. But there's no longer any need to treat the differentiation context in a special way: it's a general-purpose if, just written in an unusual way.
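A minimal sketch of the idea, assuming a tracer that treats `cond` as opaque: instead of stepping into it, the tracer records a node holding both branch functions, so replaying the recording re-reads the flag each time. `CondCall`, `replay` and the branch bodies are hypothetical names.

```julia
# Opaque conditional with the zero-argument-branch signature used above.
cond(flag::Bool, fn1, fn2) = flag ? fn1() : fn2()

# Hypothetical tape entry: the tracer does not step into cond, it stores both branches.
struct CondCall
    true_fn::Function
    false_fn::Function
end

# Replaying re-reads the flag every time instead of freezing one branch.
replay(c::CondCall, x, flag::Bool) = cond(flag, () -> c.true_fn(x), () -> c.false_fn(x))

do_this(x) = 2x      # hypothetical branch bodies
do_that(x) = x + 1
recorded = CondCall(do_this, do_that)  # what tracing func_with_cond(x, flag) might capture

replay(recorded, 3.0, true)   # 6.0
replay(recorded, 3.0, false)  # 4.0
```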

Another advantage of a cond-like function is that it maps onto conditionals in model exchange formats, such as the If operator in ONNX, which should help bring more pre-trained models to Julia.

@ToucheSir
Author

ToucheSir commented Jan 8, 2023

I think my worries with a cond-like construct are twofold:

  1. To avoid the within_gradient problem, it seems like you'd need specialized conds for every kind of branch that might change under AD. That or duplicate code for any additional conditionals, i.e. within_gradient(...) && othercondition now needs to become othercondition ? cond(...) : cond(...) or cond(flag, () -> othercondition ? ..., () -> othercondition ? ...).
  2. Chains of ... rrule -> rrule_via_ad -> ... rrule -> rrule_via_ad -> ... tend to trip up the recursion limit heuristic. This leads to additional dynamic dispatch, compilation and other overhead. It's not clear to me whether we can rely on work like RFC: Less aggressive recursion limiting JuliaLang/julia#48059 landing any time soon to alleviate this.

Otherwise I don't have any objections.

@dfdx

dfdx commented Jan 9, 2023

  1. Wouldn't it be just cond(somecondition && othercondition, () -> ..., () -> ...)? Also, if you worry about compilation time due to multiple specializations, I believe @nospecialize should help a lot here.
  2. It's a valid point and actually one of the reasons I'm currently experimenting with a very restrictive approach to AD in Remix.jl. Sometimes it looks like we spend too much time fighting the compiler instead of focusing on more mundane things, so I'm trying to find a less compile-intensive approach. In particular, my target vision for cond, as for any other allowed operation, is to be traced just once, transformed at runtime, and compiled to the configured backend only at the very end. But it's also a huge experiment, and right now I have neither a prototype nor an exact design to share.
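For the first point, `@nospecialize` can be applied to the branch arguments of a `cond`-like function, as in this sketch (`cond_nospec` is a hypothetical name). It only hints that the compiler should not compile a separate method instance per closure type; the calls inside then go through dynamic dispatch, and whether this is a net win for compile time would need benchmarking.

```julia
# Ask the compiler not to specialize on the concrete closure types of the branches.
function cond_nospec(flag::Bool, @nospecialize(true_fn), @nospecialize(false_fn), args...)
    return flag ? true_fn(args...) : false_fn(args...)
end

cond_nospec(true, x -> 2x, x -> x + 1, 3)   # 6
cond_nospec(false, x -> 2x, x -> x + 1, 3)  # 4
```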

@ToucheSir
Author

  1. Wouldn't it be just cond(somecondition && othercondition, () -> ..., () -> ...)? Also, if you worry about compilation time due to multiple specializations, I believe @nospecialize should help a lot here.

Well, that works for my simple case above, but then you have ||, elseif, etc. @nospecialize might help with cond itself, but my bigger worry would be redundancy across the branch functions. Maybe cond could pass flag to each branch so that one could feasibly use a single callback with an if statement for both? Not sure if that's compatible with tracing.

2. Sometimes it looks like we spend too much time fighting the compiler instead of focusing on more mundane things...

Agreed. Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs. Given such a mechanism does not exist at present, this seems like the pragmatic solution.
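Plain Julia offers only a limited form of this staging: lifting the flag into the type domain with `Val` lets dispatch pick the branch before the branch body runs, so neither method body contains an `if`. This is a sketch with hypothetical names (`staged`, `do_this`, `do_that`), not a proposal from the thread. With a flag that is only known at run time, `Val(flag)` is itself a dynamic dispatch, so any benefit depends on the flag being constant at compile or trace time.

```julia
do_this(x) = 2x      # hypothetical branch bodies
do_that(x) = x + 1

# Each branch is its own method; no `if` remains inside the bodies.
staged(x, ::Val{true}) = do_this(x)
staged(x, ::Val{false}) = do_that(x)

# Entry point: lift the Bool into the type domain once, up front.
staged(x, flag::Bool) = staged(x, Val(flag))

staged(3.0, true)   # 6.0
staged(3.0, false)  # 4.0
```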

@dfdx

dfdx commented Jan 10, 2023

Well that works for my simple case above, but then you have ||, elseif, etc.

Hm, perhaps we mean different things by somecondition and othercondition. I assume both are just ordinary booleans, so their combination is an ordinary boolean too. cond() behaves just like ifelse(), but maybe with a slightly more complex implementation or special methods. Do you assume some other setup?

elseif is similar to a nested ifelse(..., ifelse(...)) and is harder to analyze, but I guess that shouldn't happen too often in neural networks. In fact, I know only two major functions with conditionals, dropout and batchnorm, and both have exactly one conditional branch.
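The `elseif`/nested `ifelse` correspondence can be checked directly. Note that `Base.ifelse` is an ordinary function, so both of its value arguments are evaluated eagerly, unlike the branches of `if`/`elseif`; `classify_branches` and `classify_select` are illustrative names.

```julia
# Control-flow version: only the taken branch is evaluated.
function classify_branches(x)
    if x < 0
        return -1
    elseif x == 0
        return 0
    else
        return 1
    end
end

# Nested ifelse: no branches, every argument is evaluated before the call.
classify_select(x) = ifelse(x < 0, -1, ifelse(x == 0, 0, 1))

[classify_select(x) for x in (-2, 0, 3)]  # [-1, 0, 1]
```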

Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs

If you mean an operator that can have 2 branches, one actually taken and the other unevaluated, then I have a design for such a feature in Umlaut. Something like this:

struct IfElse
    cond::Union{Variable, Any}
    true_branch::Union{Tape, Unevaluated}
    false_branch::Union{Tape, Unevaluated}
end

where Unevaluated is a structure holding the whole execution context: an instance of IRCode, the execution frame, mappings between variables, etc. This way we analyze one branch during the initial tracing and postpone analyzing the other one until it's actually taken.
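A toy model of the lazy-branch idea, assuming nothing about Umlaut's internals: `LazyBranch` plays the role of `Unevaluated`, a `Vector{Symbol}` stands in for a `Tape`, and each branch is traced only the first time it is taken. All names here are hypothetical.

```julia
# Hypothetical stand-in for Unevaluated: holds what is needed to trace a branch later.
mutable struct LazyBranch
    trace::Function                       # called at most once to produce the branch's tape
    tape::Union{Nothing, Vector{Symbol}}  # stand-in for a Tape
end
LazyBranch(trace::Function) = LazyBranch(trace, nothing)

is_traced(b::LazyBranch) = b.tape !== nothing

# Trace the branch on first use and cache the result.
function materialize!(b::LazyBranch)
    if b.tape === nothing
        b.tape = b.trace()
    end
    return b.tape
end

# Hypothetical stand-in for the IfElse node above.
struct IfElseNode
    true_branch::LazyBranch
    false_branch::LazyBranch
end

take_branch!(n::IfElseNode, flag::Bool) = materialize!(flag ? n.true_branch : n.false_branch)

node = IfElseNode(LazyBranch(() -> [:create_mask, :broadcast_mul]),
                  LazyBranch(() -> [:identity]))
take_branch!(node, true)      # [:create_mask, :broadcast_mul]
is_traced(node.false_branch)  # false: the other branch has not been traced yet
```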

But of course it will only work for Umlaut, maybe for Tracker, but unlikely in Zygote or Diffractor since they don't have a user-level execution graph.

@ToucheSir
Author

Do you assume some other setup?

My understanding of cond(flag, true_fn, false_fn) is that it obeys the following truth table:

differentiating?  flag  branch taken
T                 T     true_fn
T                 F     false_fn
F                 T     false_fn
F                 F     false_fn

In other words, there is an implicit "are we differentiating?" variable which is and-ed with flag. This variable needs to be implicit and not passed into cond, because otherwise we run into the same issue within_gradient has with tracing. Assuming true_fn and false_fn are zero-arg functions, notice how this doesn't provide fine-grained control over half of the possible cases. In fact, I don't think taking flag as an argument actually adds much here over just capturing it in one or both branches as desired. Passing it through as an arg to each branch (and not implicitly and-ing it) might be more optimization-friendly, but I imagine that breaks the mental symmetry with ifelse.
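The truth table can be generated from a one-line predicate. Here the implicit "are we differentiating?" variable is made an explicit argument only so the table can be printed; `branch_taken` is a hypothetical name.

```julia
# Branch selection under the described semantics: the context is and-ed with flag.
branch_taken(differentiating::Bool, flag::Bool) = (differentiating && flag) ? :true_fn : :false_fn

for d in (true, false), f in (true, false)
    println(d ? "T" : "F", " ", f ? "T" : "F", " ", branch_taken(d, f))
end
```

Running it prints the four rows of the table above, in the same order.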

@dfdx

dfdx commented Jan 11, 2023

This variable needs to be implicit and not passed into cond, because otherwise we run into the same issue within_gradient has with tracing.

I think this is where we disagree: the point of cond is exactly to avoid the problem with tracing in any context. Let's take it step by step. Imagine a function like this (roughly equivalent to the dropout implementation):

function dropout(x, active::Bool)
    if active
        mask = create_mask(x)
        return x .* mask
    else
        return x
    end
end

What we want to do with this code is to:

  1. Trace it into a computational graph.
  2. Transform the graph for differentiating.

The problem arises during the first stage, when the tracer has to choose only one branch and ignore the other. Depending on the value of active, the graph will be either:

%1 = create_mask(x)
%2 = x .* %1
ret %2

or

ret x

The information about the other branch is simply lost. All further transformations thus have to make assumptions about the tracer's behavior or pass some implicit flags.

The idea behind cond is to preserve the information about both branches:

%1 = true_fn
%2 = false_fn
%3 = cond(active, %1, %2, x)
ret %3

(here I replicate the API of JAX's cond, which seems to be a better option than what I posted previously)
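Applied to the dropout example, the JAX-style signature (`cond(pred, true_fn, false_fn, operands...)`, mirroring `jax.lax.cond`) looks like the sketch below. `create_mask` is not defined in the thread; the inverted-dropout mask used here (keep probability `1 - p`, rescaled by `1/(1 - p)`) is an assumption for illustration.

```julia
using Random

# JAX-style opaque conditional: only one branch runs, on the given operands.
cond(pred::Bool, true_fn, false_fn, operands...) = pred ? true_fn(operands...) : false_fn(operands...)

# Hypothetical create_mask: keep each entry with probability 1 - p and rescale.
create_mask(x; p=0.5, rng=Random.default_rng()) = (rand(rng, size(x)...) .>= p) ./ (1 - p)

# dropout from above, with the `if` replaced by the opaque cond primitive.
dropout(x, active::Bool) = cond(active, x -> x .* create_mask(x), identity, x)

x = ones(4)
dropout(x, false) === x  # true: the false branch is identity
y = dropout(x, true)     # each entry is 0.0 or 2.0 when p = 0.5
```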

Further transformations now have all the information to behave correctly in all cases. The ChainRules-based AD transformation, for example, can produce graph like this:

%1 = true_fn
%2 = false_fn
%3, %4 = rrule(cond, active, %1, %2, x)   # %3 == val, %4 == pullback
%5 = 1                 # seed, i.e. initial dy
%6 = %4(%5)            # dxs = pullback(seed)

rrule(cond, ...) has access to the flag active, functions of both branches and their arguments, so something like this should work:

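function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cond), flag::Bool, true_fn, false_fn, args...)
    y, pb = rrule_via_ad(config, flag ? true_fn : false_fn, args...)  # differentiate only the taken branch
    # pb returns (∂branch, ∂args...); cond's pullback must return (∂cond, ∂flag, ∂true_fn, ∂false_fn, ∂args...)
    pad(g) = flag ? (NoTangent(), NoTangent(), g[1], ZeroTangent(), Base.tail(g)...) : (NoTangent(), NoTangent(), ZeroTangent(), g[1], Base.tail(g)...)
    return y, dy -> pad(pb(dy))
end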
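The same dispatch-on-`cond` idea can be run without ChainRulesCore using a hypothetical mini rule protocol (`my_rrule`), where each rule returns the value and a pullback `dy -> dx`. This is deliberately simpler than ChainRules, whose pullbacks also return tangents for the function and every argument; the point is only that the rule for `cond` sees the flag and both branches, and differentiates the one that is taken.

```julia
# JAX-style opaque conditional with a single operand.
cond(pred::Bool, true_fn, false_fn, x) = pred ? true_fn(x) : false_fn(x)

# Hypothetical mini rule protocol (NOT ChainRulesCore): (value, pullback), pullback maps dy -> dx.
my_rrule(::typeof(sin), x) = (sin(x), dy -> dy * cos(x))
my_rrule(::typeof(abs2), x) = (abs2(x), dy -> dy * 2x)

# The rule for cond has the flag and both branch functions; it differentiates the taken one.
my_rrule(::typeof(cond), pred::Bool, true_fn, false_fn, x) =
    pred ? my_rrule(true_fn, x) : my_rrule(false_fn, x)

y, pb = my_rrule(cond, false, sin, abs2, 3.0)
y        # 9.0
pb(1.0)  # 6.0
```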

A more sophisticated AD system can also choose to treat cond in a special way, and instead of

rrule(cond, active, true_fn, false_fn, x)

record

cond(active, rrule_via_ad(true_fn, x), rrule_via_ad(false_fn, x))

or even something more complicated and efficient.


So there's no need for a special differentiating flag. There may still be several conditions, but I don't see any issues with them either. E.g. for ||:

%1 = flag1 || flag2
%2 = cond(%1, true_fn(x), false_fn(x))

for elseif:

# main graph
%1 = cond(flag1, true_fn(x), false_fn(x))

# false_fn subgraph
%1 = cond(flag2, true_fn2(x), false_fn2(x))

elseif after the ChainRules transformation:

# main graph
%1, %2 = rrule(cond, flag1, true_fn(x), false_fn(x))

# false_fn subgraph - invoked inside of `rrule_via_ad(false_fn, x)`
%1, %2 = rrule(cond, flag2, true_fn2(x), false_fn2(x))
...
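An elseif chain maps onto nested `cond` calls, with the inner `cond` living in the outer false branch, matching the main-graph/subgraph split above. A runnable sketch using the JAX-style signature and a hypothetical `piecewise` function:

```julia
# JAX-style opaque conditional.
cond(pred::Bool, true_fn, false_fn, operands...) = pred ? true_fn(operands...) : false_fn(operands...)

# if x < 0 ... elseif x < 1 ... else ... end, as a cond whose false branch is another cond.
piecewise(x) = cond(x < 0,
                    x -> -x,                                     # main graph, true branch
                    x -> cond(x < 1, identity, x -> 2x - 1, x),  # false_fn subgraph: the elseif
                    x)

piecewise(-2.0)  # 2.0
piecewise(0.5)   # 0.5
piecewise(3.0)   # 5.0
```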

@ToucheSir
Author

I'm not sure I understand how this tracing works then. To get a value for active, you ultimately need to call something which returns true when not differentiating and false when differentiating. Doing any branching on the value of active (including && and ||, which lower to control flow/branches) will lead to the within_gradient issue. cond works because it's special-cased as an opaque function to the tracer, but restricting the code between getting the value of active and passing it to cond to non-short-circuiting bool operations seems quite limiting (though still workable for Flux).
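The point about && and || can be seen directly: they short-circuit, so they are control flow rather than function calls, while `&` and `|` on `Bool`s are ordinary calls that evaluate both operands. (`Meta.@lower a && b` shows the conditional jump.) `calls` and `check` are illustrative helpers.

```julia
calls = Ref(0)
check() = (calls[] += 1; true)  # counts every time it runs

false && check()  # check() is skipped: && is a branch
false & check()   # check() runs: & is an ordinary function call on Bools

calls[]  # 1
```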

In fact, I don't even know if cond belongs in a diff-related library—it seems general-purpose enough to warrant inclusion in some hypothetical TracingPrimitives.jl or OpaqueConditionals.jl. Alternatively, could there be a way to mark active as an opaque value for the tracer such that an IfElse is always generated for conditionals branching on it?

@dfdx

dfdx commented Jan 11, 2023

To get a value for active, you ultimately need to call something which returns true when not differentiating and false when differentiating.

That's the point - active and differentiation are independent. All four combinations are valid:

active  differentiating  example
F       F                dropout(x, false)
T       F                dropout(x, true)
F       T                rrule(dropout, x, false)
T       T                rrule(dropout, x, true)

Usually, people set active = true during training and differentiate code during training, but strictly speaking nobody forbids you from setting active = true during inference. active is a flag passed from the top level (e.g. via trainmode!()). Differentiation is a transformation that can work on any valid primitive, and cond is one such primitive. All three can be mixed or used independently.

In fact, I don't even know if cond belongs in a diff-related library—it seems general-purpose enough to warrant inclusion in some hypothetical TracingPrimitives.jl or OpaqueConditionals.jl.

Absolutely! At the moment, cond itself is a hypothetical op :) But once it gets shaped, something like TracingPrimitives sounds reasonable.

@ToucheSir
Author

That's the point - active and differentiation are independent.

Usually, people set active = true while training and differentiate code while training, but strictly speaking nobody forbids you to set active = true during inference. active is a flag passed from the top-level (e.g. via trainmode!()).

It doesn't affect the outcome of this issue discussion so I think this can be continued elsewhere, but the impetus for this whole thread was the default case of when people don't set active at all. Then active and differentiation become tightly coupled, and this of course upsets tracing because it expects them not to be. The proximal solution seems to be creating a primitive like cond that tracing can't pierce and being very, very careful about not getting troublesome values like active caught in conditionals, but it would also be nice to step back a bit and brainstorm if there are more generic solutions. For example, TorchDynamo has a concept of "graph breaks", which allows it to avoid the pitfall of only capturing the traced branch.
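The coupling can be made explicit with a tri-state flag, where `nothing` means "decide from the context". This is a hypothetical sketch loosely following the trainmode!() idea mentioned above; `MyDropout`, `trainmode_sketch!`, `testmode_sketch!` and `is_active` are made-up names, and the context is passed as an explicit argument instead of being detected.

```julia
# Hypothetical layer: active = nothing means "infer from context", true/false means "user decided".
mutable struct MyDropout
    p::Float64
    active::Union{Nothing, Bool}
end
MyDropout(p) = MyDropout(p, nothing)

trainmode_sketch!(d::MyDropout) = (d.active = true; d)   # stand-in for trainmode!()
testmode_sketch!(d::MyDropout) = (d.active = false; d)

# Only the `nothing` default ties activity to differentiation; explicit flags are independent.
is_active(d::MyDropout, differentiating::Bool) = something(d.active, differentiating)

d = MyDropout(0.5)
is_active(d, true), is_active(d, false)  # (true, false): coupled to the context
```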

@dfdx

dfdx commented Jan 11, 2023

Yeah, it looks like we went quite off topic :) The whole discussion also makes me think about how hard it is to design a common API for very different AD systems. I think in terms of computational graphs and transformations on them, much like JAX, and you provide examples from PyTorch-like systems. Maybe such discussions are easier when we have a working prototype in at least one system.

gdalle added the feature (New feature or request) and question (Inquiries and discussions) labels on Oct 5, 2023