API for user code to detect if it's being differentiated #66
As we briefly discussed in FluxML/NNlib.jl#434, there's another way to think about the problem. If I understand the motivation correctly, there are functions that behave differently during training (i.e. being differentiated) and normal execution, e.g. dropout or batch normalization. Let's use this generic example for simplicity:

```julia
function func(x, training=false)
    y = nothing
    if training
        y = do_this(x)
    else
        y = do_that(x)
    end
    return y
end
```

Understanding branches is a hard problem for AD, so many systems assume code execution always takes the same path. For example, Yota would first trace the code with the provided inputs and assume the path never changes. E.g. after tracing it will see the equivalent of either this:

```julia
# when training == true
function func(x, not_used_training_flag)
    return do_this(x)
end
```

or this:

```julia
# when training == false
function func(x, not_used_training_flag)
    return do_that(x)
end
```

This leads to a lot of funny things. For example, if we hardcode ...
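To make the kind of surprise concrete, here is a toy sketch (not any real AD package; `trace_func` and the baked-in behaviour are purely illustrative):

```julia
# Toy illustration: tracing records only the branch taken for the flag value seen
# at trace time, so the flag has no effect on later executions of the trace.
do_this(x) = 2x
do_that(x) = x

"Pretend tracer: returns a 'tape' that contains only the branch actually taken."
trace_func(training::Bool) = training ? do_this : do_that

tape = trace_func(true)   # traced once with training = true
tape(1.0)                 # 2.0
# The caller later wants training = false, but the trace is already fixed:
tape(1.0)                 # still 2.0
```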
Another approach is to provide a simple AD-friendly way to include branches, a kind of `cond(flag, fn1, fn2) = ...`:

```julia
function func_with_cond(x, flag)
    y = cond(flag, () -> do_this(x), () -> do_that(x))
    return y
end
```

and an rrule for `cond` along these lines:
```julia
function rrule(::typeof(cond), flag, fn1, fn2)
    res, pb = flag ? rrule_via_ad(fn1) : rrule_via_ad(fn2)  # a kind of...
    function cond_pullback(dy)
        dxs = pb(dy)
        return NoTangent(), NoTangent(), dxs...
    end
    return res, cond_pullback
end
```

The code is just an illustration, but I hope you get the idea. An AD system can then overload `cond`. Another advantage of having a `cond` is ...
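For ordinary (non-differentiated) execution, a plain fallback definition of such a `cond` could be as simple as this sketch:

```julia
# Plain fallback: outside of AD, just call whichever thunk the flag selects.
# `func_with_cond` then behaves exactly like the original if/else version,
# and an AD system only needs to intercept `cond` itself.
cond(flag::Bool, fn1, fn2) = flag ? fn1() : fn2()
```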
I think my worries with a ... Otherwise I don't have any objections.
Well, that works for my simple case above, but then you have ...
Agreed. Fundamentally it feels like what we want is some kind of staged programming mechanism where one can partially evaluate conditionals like this before the rest of the pipeline runs. Given such a mechanism does not exist at present, this seems like the pragmatic solution.
Hm, perhaps we put different meaning into ...
If you mean an operator that can have 2 branches - one actually taken and another unevaluated - then I have a design for such a feature in Umlaut. Something like this:

```julia
struct IfElse
    cond::Union{Variable, Any}
    true_branch::Union{Tape, Unevaluated}
    false_branch::Union{Tape, Unevaluated}
end
```

where ... But of course it will only work for Umlaut, maybe for Tracker, but it's unlikely to work in Zygote or Diffractor since they don't have a user-level execution graph.
My understanding of ... In other words, there is an implicit "are we differentiating?" variable which is and-ed with ...
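In pseudocode, that reading would be roughly the following (all names here are illustrative, not an existing API):

```julia
# Illustrative only: `is_differentiating()` stands in for a query the AD system
# would provide; stubbed out here so the sketch runs on its own.
is_differentiating() = false

# The user's flag only takes effect while we are actually differentiating.
effective_active(user_active::Bool) = user_active && is_differentiating()
```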
I think this is where we disagree - the point of `cond` ...

```julia
function dropout(x, active::Bool)
    if active
        mask = create_mask(x)
        return x .* mask
    else
        return x
    end
end
```

What we want to do with this code is to:

1. trace it into a computational graph, and
2. apply further transformations (such as differentiation) to that graph.
The problem arises during the first stage, when the tracer has to choose only one branch and ignore the other. Depending on the value of `active`, the recorded graph contains either the masking branch or the identity branch; the other part of the information is just lost. All further transformations thus have to make assumptions about the tracer behavior or pass some implicit flags. The idea behind `cond` is to make both branches of the computation explicit, replicating the API of JAX's cond (which seems to be a better option than what I posted previously).
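For the dropout example above, that might look roughly like the sketch below (`dropout_cond` is an illustrative name, and the `cond(flag, true_fn, false_fn, args...)` signature matches the rrule that follows):

```julia
# Sketch: both branches are passed to `cond` explicitly, so the recorded graph
# keeps them both regardless of the runtime value of `active`.
function dropout_cond(x, active::Bool)
    return cond(active,
                x -> x .* create_mask(x),  # branch taken when active
                identity,                  # branch taken otherwise
                x)
end
```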
Further transformations now have all the information to behave correctly in all cases. The ChainRules-based AD transformation, for example, can produce a graph like this:
```julia
function rrule(::typeof(cond), flag::Bool, true_fn, false_fn, args...)
    return flag ? rrule_via_ad(true_fn, args...) : rrule_via_ad(false_fn, args...)
end
```

A more sophisticated AD system can also choose to treat `cond` specially, e.g. instead of recording `rrule(cond, active, true_fn, false_fn, x)` it can record `cond(active, rrule_via_ad(true_fn, x), rrule_via_ad(false_fn, x))`, or even something more complicated and efficient. So there's no need for a special transformation that substitutes

```
%1 = flag1 || flag2
%2 = cond(%1, true_fn(x), false_fn(x))
```

for

```
# main graph
%1 = cond(flag1, true_fn(x), false_fn(x))

# false_fn subgraph
%1 = cond(flag2, true_fn2(x), false_fn2(x))
```

- the nested form can be differentiated directly into

```
# main graph
%1, %2 = rrule(cond, flag1, true_fn(x), false_fn(x))

# false_fn subgraph - invoked inside of `rrule_via_ad(false_fn, x)`
%1, %2 = rrule(cond, flag2, true_fn2(x), false_fn2(x))
...
```
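For reference, a fuller ChainRules-style rule for such a `cond` might look like the sketch below; this is illustrative only (not from any package) and assumes the `cond(flag, true_fn, false_fn, args...)` signature used above:

```julia
using ChainRulesCore

# Sketch of a `cond` rrule with the RuleConfig plumbing included: only the branch
# that is actually taken gets differentiated, via `rrule_via_ad`.
function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(cond), flag::Bool, true_fn, false_fn, args...)
    y, branch_pb = flag ? rrule_via_ad(config, true_fn, args...) :
                          rrule_via_ad(config, false_fn, args...)
    function cond_pullback(dy)
        Δ = branch_pb(dy)                 # tangents for (taken branch, args...)
        dbranch, dargs = Δ[1], Δ[2:end]
        # Tangents for: cond itself, the Bool flag, true_fn, false_fn, args...
        return (NoTangent(), NoTangent(),
                flag ? dbranch : NoTangent(),
                flag ? NoTangent() : dbranch,
                dargs...)
    end
    return y, cond_pullback
end
```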
I'm not sure I understand how this tracing works then. To get a value for ... In fact, I don't even know if ...
That's the point - ...

Usually, people set ...
Absolutely! At the moment, ...
It doesn't affect the outcome of this issue's discussion, so I think this can be continued elsewhere, but the impetus for this whole thread was the default case, when people don't set ...
Yeah, it looks like we went quite off topic :) The whole discussion also makes me think how hard it is to design a common API for very different AD systems. I think in terms of computational graphs and transformations on them, much like JAX, while you provide examples from PyTorch-like systems. Maybe such discussions are easier when we have a working prototype in at least one system.
The most recent iteration of this discussion was FluxML/NNlib.jl#434, but it goes back to JuliaDiff/ChainRulesCore.jl#547 and much further. Given that not all ADs are ChainRules-compatible, but this package seeks to support (and eventually be supported by) all ADs, it feels like a much better home for such functionality than a domain-specific package like NNlib.