-
Notifications
You must be signed in to change notification settings - Fork 70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
NaN in gradient when branch not taken has Inf shadow #2285
Comments
Here's an alternative MWE in which the branch condition is a separate argument (`cond`), distinct from the values that flow through the branch (`a` and `x`):
julia> function f(cond, a, x)
if cond
x += a * x
end
return x
end;
julia> autodiff(Reverse, f, Active, Const(false), Const(Inf), Active(1.0))
after simplification :
; Function Attrs: mustprogress nofree readonly willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_545(i8 zeroext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="4773885328" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %2) local_unnamed_addr #3 !dbg !17 {
top:
%3 = call {}*** @julia.get_pgcstack() #4
%ptls_field3 = getelementptr inbounds {}**, {}*** %3, i64 2
%4 = bitcast {}*** %ptls_field3 to i64***
%ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !8
%5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
%safepoint = load i64*, i64** %5, align 8, !tbaa !12
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #4, !dbg !18
fence syncscope("singlethread") seq_cst
%6 = and i8 %0, 1, !dbg !19
%.not = icmp eq i8 %6, 0, !dbg !19
%7 = fmul double %1, %2, !dbg !19
%8 = select i1 %.not, double -0.000000e+00, double %7, !dbg !19
%value_phi = fadd double %8, %2, !dbg !19
ret double %value_phi, !dbg !20
}
; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_f_545(i8 zeroext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="4773885328" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %2, double %differeturn) local_unnamed_addr #4 !dbg !21 {
top:
%"value_phi'de" = alloca double, align 8
%3 = getelementptr double, double* %"value_phi'de", i64 0
store double 0.000000e+00, double* %3, align 8
%"'de" = alloca double, align 8
%4 = getelementptr double, double* %"'de", i64 0
store double 0.000000e+00, double* %4, align 8
%"'de1" = alloca double, align 8
%5 = getelementptr double, double* %"'de1", i64 0
store double 0.000000e+00, double* %5, align 8
%"'de2" = alloca double, align 8
%6 = getelementptr double, double* %"'de2", i64 0
store double 0.000000e+00, double* %6, align 8
%7 = call {}*** @julia.get_pgcstack() #5
%ptls_field3 = getelementptr inbounds {}**, {}*** %7, i64 2
%8 = bitcast {}*** %ptls_field3 to i64***
%ptls_load45 = load i64**, i64*** %8, align 8, !tbaa !8, !alias.scope !22, !noalias !25
%9 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
%safepoint = load i64*, i64** %9, align 8, !tbaa !12, !alias.scope !27, !noalias !30
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #5, !dbg !32
fence syncscope("singlethread") seq_cst
%10 = and i8 %0, 1, !dbg !33
%.not = icmp eq i8 %10, 0, !dbg !33
br label %inverttop, !dbg !34
inverttop: ; preds = %top
store double %differeturn, double* %"value_phi'de", align 8
%11 = load double, double* %"value_phi'de", align 8, !dbg !33
store double 0.000000e+00, double* %"value_phi'de", align 8, !dbg !33
%12 = load double, double* %"'de", align 8, !dbg !33
%13 = fadd fast double %12, %11, !dbg !33
store double %13, double* %"'de", align 8, !dbg !33
%14 = load double, double* %"'de1", align 8, !dbg !33
%15 = fadd fast double %14, %11, !dbg !33
store double %15, double* %"'de1", align 8, !dbg !33
%16 = load double, double* %"'de", align 8, !dbg !33
%diffe = select fast i1 %.not, double 0.000000e+00, double %16, !dbg !33
store double 0.000000e+00, double* %"'de", align 8, !dbg !33
%17 = load double, double* %"'de2", align 8, !dbg !33
%18 = fadd fast double %17, %16, !dbg !33
%19 = select fast i1 %.not, double %17, double %18, !dbg !33
store double %19, double* %"'de2", align 8, !dbg !33
%20 = load double, double* %"'de2", align 8, !dbg !33
store double 0.000000e+00, double* %"'de2", align 8, !dbg !33
%21 = fmul fast double %20, %1, !dbg !33
%22 = load double, double* %"'de1", align 8, !dbg !33
%23 = fadd fast double %22, %21, !dbg !33
store double %23, double* %"'de1", align 8, !dbg !33
fence syncscope("singlethread") seq_cst
fence syncscope("singlethread") seq_cst
%24 = load double, double* %"'de1", align 8
%25 = insertvalue { double } undef, double %24, 0
ret { double } %25
}
((nothing, nothing, NaN),)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
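If a branch is not taken in the forward pass, but the shadow of the code in the branch contains `Inf`, the resulting gradient contains `NaN`s. In the IR above, the branch has been lowered to a `select`, so the reverse pass still unconditionally computes `%21 = fmul fast double %20, %1`. That is (zero adjoint) × `a` = 0 × Inf = NaN, and the result is added to the gradient of `x`. The problem disappears with `Enzyme.API.strong_zero!(true)`, but according to @wsmoses on Slack this should not be needed.
The text was updated successfully, but these errors were encountered: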