NaN in gradient when branch not taken has Inf shadow #2285

Open
danielwe opened this issue Jan 28, 2025 · 1 comment

danielwe commented Jan 28, 2025

When a branch is not taken in the forward pass but the shadow of the code inside that branch involves Inf, the resulting gradient contains NaN. Setting Enzyme.API.strong_zero!(true) makes the problem disappear, but according to @wsmoses on Slack this should not be necessary.

julia> using Enzyme

julia> Enzyme.API.printall!(true)

julia> function f(a, x)
           if !isinf(a)
               x += a * x
           end
           return x
       end;

julia> autodiff(Reverse, f, Active, Const(Inf), Active(1.0))
after simplification :
; Function Attrs: mustprogress nofree readonly willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_462(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %1) local_unnamed_addr #4 !dbg !22 {
top:
  %2 = call {}*** @julia.get_pgcstack() #5
  %ptls_field3 = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %3, align 8, !tbaa !8
  %4 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %4, align 8, !tbaa !12
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #5, !dbg !23
  fence syncscope("singlethread") seq_cst
  %5 = call double @llvm.fabs.f64(double %0) #5, !dbg !24
  %6 = bitcast double %5 to i64, !dbg !25
  %7 = icmp eq i64 %6, 9218868437227405312, !dbg !25
  %8 = fmul double %0, %1, !dbg !26
  %9 = select i1 %7, double -0.000000e+00, double %8, !dbg !26
  %value_phi = fadd double %9, %1, !dbg !26
  ret double %value_phi, !dbg !27
}

; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_f_462(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %1, double %differeturn) local_unnamed_addr #5 !dbg !28 {
top:
  %"value_phi'de" = alloca double, align 8
  %2 = getelementptr double, double* %"value_phi'de", i64 0
  store double 0.000000e+00, double* %2, align 8
  %"'de" = alloca double, align 8
  %3 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %3, align 8
  %"'de1" = alloca double, align 8
  %4 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %4, align 8
  %"'de2" = alloca double, align 8
  %5 = getelementptr double, double* %"'de2", i64 0
  store double 0.000000e+00, double* %5, align 8
  %6 = call {}*** @julia.get_pgcstack() #6
  %ptls_field3 = getelementptr inbounds {}**, {}*** %6, i64 2
  %7 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %7, align 8, !tbaa !8, !alias.scope !29, !noalias !32
  %8 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %8, align 8, !tbaa !12, !alias.scope !34, !noalias !37
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #6, !dbg !39
  fence syncscope("singlethread") seq_cst
  %9 = call double @llvm.fabs.f64(double %0) #6, !dbg !40
  %10 = bitcast double %9 to i64, !dbg !41
  %11 = icmp eq i64 %10, 9218868437227405312, !dbg !41
  br label %inverttop, !dbg !43

inverttop:                                        ; preds = %top
  store double %differeturn, double* %"value_phi'de", align 8
  %12 = load double, double* %"value_phi'de", align 8, !dbg !42
  store double 0.000000e+00, double* %"value_phi'de", align 8, !dbg !42
  %13 = load double, double* %"'de", align 8, !dbg !42
  %14 = fadd fast double %13, %12, !dbg !42
  store double %14, double* %"'de", align 8, !dbg !42
  %15 = load double, double* %"'de1", align 8, !dbg !42
  %16 = fadd fast double %15, %12, !dbg !42
  store double %16, double* %"'de1", align 8, !dbg !42
  %17 = load double, double* %"'de", align 8, !dbg !42
  %diffe = select fast i1 %11, double 0.000000e+00, double %17, !dbg !42
  store double 0.000000e+00, double* %"'de", align 8, !dbg !42
  %18 = load double, double* %"'de2", align 8, !dbg !42
  %19 = fadd fast double %18, %17, !dbg !42
  %20 = select fast i1 %11, double %18, double %19, !dbg !42
  store double %20, double* %"'de2", align 8, !dbg !42
  %21 = load double, double* %"'de2", align 8, !dbg !42
  store double 0.000000e+00, double* %"'de2", align 8, !dbg !42
  %22 = fmul fast double %21, %0, !dbg !42
  %23 = load double, double* %"'de1", align 8, !dbg !42
  %24 = fadd fast double %23, %22, !dbg !42
  store double %24, double* %"'de1", align 8, !dbg !42
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  %25 = load double, double* %"'de1", align 8
  %26 = insertvalue { double } undef, double %25, 0
  ret { double } %26
}

((nothing, NaN),)
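
To make it easier to see where the NaN comes from, here is a hand transcription of the adjoint arithmetic in the `inverttop` block of `diffejulia_f_462` into plain Julia. This is only a sketch of my reading of the IR, not something Enzyme produced; the function names `reverse_pass_sketch` and `reverse_pass_guarded` are made up for illustration. The guarded variant is one way to picture a rule that treats a zero adjoint as absorbing; whether that matches what Enzyme.API.strong_zero!(true) does internally is an assumption.

```julia
# Hand transcription (an assumption, not Enzyme output) of the adjoint
# arithmetic in `inverttop` of diffejulia_f_462.
function reverse_pass_sketch(a, seed)
    branch_skipped = isinf(a)              # %11: true when the `if` body is not taken
    d_x = seed                             # direct contribution of x to the final fadd
    d_sel = seed                           # adjoint of the select result (%9 in the primal)
    d_mul = branch_skipped ? 0.0 : d_sel   # adjoint of a * x, zero when the branch is skipped
    d_x += d_mul * a                       # %22: 0.0 * Inf is NaN under IEEE 754
    return d_x
end

# Hypothetical guarded variant: skip the product when the adjoint is exactly zero.
function reverse_pass_guarded(a, seed)
    d_x = seed
    d_mul = isinf(a) ? 0.0 : seed
    d_x += iszero(d_mul) ? 0.0 : d_mul * a
    return d_x
end

println(reverse_pass_sketch(Inf, 1.0))     # NaN, as in the report
println(reverse_pass_guarded(Inf, 1.0))    # 1.0, the expected gradient
```

With a finite `a` both versions agree (for example `a = 2.0` gives 3.0), so the zero adjoint being multiplied by Inf is the only place the two differ.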

danielwe commented

Here's an alternative MWE in which the branch condition is independent of the value of a, which shows that the issue is not caused by the branch condition selecting a measure-zero subset of floats:

julia> function f(cond, a, x)
           if cond
               x += a * x
           end
           return x
       end;

julia> autodiff(Reverse, f, Active, Const(false), Const(Inf), Active(1.0))
after simplification :
; Function Attrs: mustprogress nofree readonly willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_f_545(i8 zeroext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="4773885328" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %2) local_unnamed_addr #3 !dbg !17 {
top:
  %3 = call {}*** @julia.get_pgcstack() #4
  %ptls_field3 = getelementptr inbounds {}**, {}*** %3, i64 2
  %4 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !8
  %5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %5, align 8, !tbaa !12
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #4, !dbg !18
  fence syncscope("singlethread") seq_cst
  %6 = and i8 %0, 1, !dbg !19
  %.not = icmp eq i8 %6, 0, !dbg !19
  %7 = fmul double %1, %2, !dbg !19
  %8 = select i1 %.not, double -0.000000e+00, double %7, !dbg !19
  %value_phi = fadd double %8, %2, !dbg !19
  ret double %value_phi, !dbg !20
}

; Function Attrs: mustprogress nofree
define internal { double } @diffejulia_f_545(i8 zeroext "enzyme_inactive" "enzyme_type"="{[-1]:Integer}" "enzymejl_parmtype"="4773885328" "enzymejl_parmtype_ref"="0" %0, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %1, double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="4773884592" "enzymejl_parmtype_ref"="0" %2, double %differeturn) local_unnamed_addr #4 !dbg !21 {
top:
  %"value_phi'de" = alloca double, align 8
  %3 = getelementptr double, double* %"value_phi'de", i64 0
  store double 0.000000e+00, double* %3, align 8
  %"'de" = alloca double, align 8
  %4 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %4, align 8
  %"'de1" = alloca double, align 8
  %5 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %5, align 8
  %"'de2" = alloca double, align 8
  %6 = getelementptr double, double* %"'de2", i64 0
  store double 0.000000e+00, double* %6, align 8
  %7 = call {}*** @julia.get_pgcstack() #5
  %ptls_field3 = getelementptr inbounds {}**, {}*** %7, i64 2
  %8 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %8, align 8, !tbaa !8, !alias.scope !22, !noalias !25
  %9 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %9, align 8, !tbaa !12, !alias.scope !27, !noalias !30
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #5, !dbg !32
  fence syncscope("singlethread") seq_cst
  %10 = and i8 %0, 1, !dbg !33
  %.not = icmp eq i8 %10, 0, !dbg !33
  br label %inverttop, !dbg !34

inverttop:                                        ; preds = %top
  store double %differeturn, double* %"value_phi'de", align 8
  %11 = load double, double* %"value_phi'de", align 8, !dbg !33
  store double 0.000000e+00, double* %"value_phi'de", align 8, !dbg !33
  %12 = load double, double* %"'de", align 8, !dbg !33
  %13 = fadd fast double %12, %11, !dbg !33
  store double %13, double* %"'de", align 8, !dbg !33
  %14 = load double, double* %"'de1", align 8, !dbg !33
  %15 = fadd fast double %14, %11, !dbg !33
  store double %15, double* %"'de1", align 8, !dbg !33
  %16 = load double, double* %"'de", align 8, !dbg !33
  %diffe = select fast i1 %.not, double 0.000000e+00, double %16, !dbg !33
  store double 0.000000e+00, double* %"'de", align 8, !dbg !33
  %17 = load double, double* %"'de2", align 8, !dbg !33
  %18 = fadd fast double %17, %16, !dbg !33
  %19 = select fast i1 %.not, double %17, double %18, !dbg !33
  store double %19, double* %"'de2", align 8, !dbg !33
  %20 = load double, double* %"'de2", align 8, !dbg !33
  store double 0.000000e+00, double* %"'de2", align 8, !dbg !33
  %21 = fmul fast double %20, %1, !dbg !33
  %22 = load double, double* %"'de1", align 8, !dbg !33
  %23 = fadd fast double %22, %21, !dbg !33
  store double %23, double* %"'de1", align 8, !dbg !33
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  %24 = load double, double* %"'de1", align 8
  %25 = insertvalue { double } undef, double %24, 0
  ret { double } %25
}

((nothing, nothing, NaN),)
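
For reference, the gradient one would expect here can be checked without Enzyme using a central finite difference. This is a rough sanity check rather than a definitive test; `central_diff` and the step size `h` are hypothetical helpers introduced only for this sketch.

```julia
# Same function as the MWE above.
function f(cond, a, x)
    if cond
        x += a * x
    end
    return x
end

# Central finite difference in x; `h` is an illustrative step size.
central_diff(g, x; h=1e-6) = (g(x + h) - g(x - h)) / (2h)

# Branch not taken: f reduces to the identity in x, so the slope should be close to 1.
fd_not_taken = central_diff(x -> f(false, Inf, x), 1.0)

# Branch taken with a finite a: f(x) = x + a * x, so the slope should be close to 1 + a.
fd_taken = central_diff(x -> f(true, 2.0, x), 1.0)

println((fd_not_taken, fd_taken))
```

Since `a` never enters the computation when `cond` is false, the finite-difference estimate stays finite, which is what the reverse-mode result should match.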
