From e9d5c81809e790797a9b0701eb0af7e96d29c4b7 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 8 Nov 2024 14:46:46 +0000 Subject: [PATCH] Test + Fix Debug Mode (#356) * Improve comment * Improve error handling in LazyDerivedRule * Test debug mode properly * Fix debug mode * Bump patch version * Tidy up slightly * Tidy up a bit more * Tweak performance criteria * Fix test warning --- Project.toml | 2 +- bench/run_benchmarks.jl | 4 ++-- src/debug_mode.jl | 2 +- src/interpreter/s2s_reverse_mode_ad.jl | 28 ++++++++++++++----------- src/test_resources.jl | 2 +- test/interpreter/s2s_reverse_mode_ad.jl | 10 ++++++++- 6 files changed, 30 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index dd88e30df..12e2f750b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.38" +version = "0.4.39" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/bench/run_benchmarks.jl b/bench/run_benchmarks.jl index a3cc8f56c..095ff3dd9 100644 --- a/bench/run_benchmarks.jl +++ b/bench/run_benchmarks.jl @@ -290,7 +290,7 @@ function benchmark_derived_rrules!!(rng_ctor) tags = fill(nothing, length(test_cases)) return map(x -> x[4:end], test_cases), memory, ranges, tags end - return benchmark_rules!!(test_case_data, (lb=1e-3, ub=150), false) + return benchmark_rules!!(test_case_data, (lb=1e-3, ub=200), false) end function benchmark_inter_framework_rules() @@ -299,7 +299,7 @@ function benchmark_inter_framework_rules() test_cases = map(last, test_case_data) memory = [] ranges = fill(nothing, length(test_cases)) - return benchmark_rules!!([(test_cases, memory, ranges, tags)], (lb=0.1, ub=150), true) + return benchmark_rules!!([(test_cases, memory, ranges, tags)], (lb=0.1, ub=200), true) end function flag_concerning_performance(ratios) diff --git a/src/debug_mode.jl b/src/debug_mode.jl index f96e23553..25342afac 100644 --- a/src/debug_mode.jl +++ b/src/debug_mode.jl @@ -98,7 +98,7 @@ verify_args(_, x) = nothing try # Check that the input types are correct. If this check is not present, the passing # in arguments of the wrong type can result in a segfault. - verify_args(rule, map(primal, x)) + verify_args(rule, x) # Use for-loop to keep the stack trace as simple as possible. for _x in x diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index d9f38ec9a..a1fec0c5e 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -682,7 +682,7 @@ __get_primal(x::CoDual) = primal(x) __get_primal(x) = x # Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. -@inline function __run_rvs_pass!(::Type{P}, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...) where {P, sig} +@inline function __run_rvs_pass!(P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...) where {sig} tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[])) set_ret_ref_to_zero!!(P, ret_rev_data_ref) return nothing @@ -734,10 +734,10 @@ function DerivedRule(Tprimal, fwds_oc::T, pb::U, isva::V, nargs::W) where {T, U, end # Extends functionality defined for debug_mode. -function verify_args(::DerivedRule{sig}, ::Tx) where {sig, Tx} - sig === Tx && return nothing - msg = "Arguments with sig $Tx do not match signature expected by rule, $sig" - throw(ArgumentError(msg)) +function verify_args(r::DerivedRule{sig}, x) where {sig} + Tx = Tuple{map(_typeof ∘ primal, __unflatten_codual_varargs(r.isva, x, r.nargs))...} + Tx <: sig && return nothing + throw(ArgumentError("Arguments with sig $Tx do not subtype rule signature, $sig")) end _copy(::Nothing) = nothing @@ -1530,8 +1530,7 @@ end _copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode) @inline function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N} - isdefined(rule, :rule) || _build_rule!(rule, args) - return rule.rule(args...) + return isdefined(rule, :rule) ? rule.rule(args...) : _build_rule!(rule, args) end struct BadRuleTypeException <: Exception @@ -1561,20 +1560,25 @@ function Base.showerror(io::IO, err::BadRuleTypeException) println(io, msg) end +_rtype(::Type{<:DebugRRule}) = Tuple{CoDual, DebugPullback} +_rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc)) +_rtype(::Type{<:OpaqueClosure{<:Any, <:R}}) where {R} = (@isdefined R) ? R : CoDual +_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)), fieldtype(T, :pb)} + @noinline function _build_rule!(rule::LazyDerivedRule{sig, Trule}, args) where {sig, Trule} derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode) if derived_rule isa Trule rule.rule = derived_rule + result = derived_rule(args...) else - @warn "Unable to put rule in rule field. A `BadRuleTypeException` should be thrown." + @warn "Unable to put rule in rule field. A `BadRuleTypeException` might be thrown." err = BadRuleTypeException(rule.mi, sig, typeof(derived_rule), Trule) - try + result = try derived_rule(args...) catch throw(err) end - @warn "`BadRuleTypException was _not_ thrown. Throwing now." - throw(err) + @warn "`BadRuleTypException was _not_ thrown. Expect an error at some point." end - return nothing + return result::_rtype(Trule) end diff --git a/src/test_resources.jl b/src/test_resources.jl index 5075646fb..5b32a0a3e 100644 --- a/src/test_resources.jl +++ b/src/test_resources.jl @@ -525,7 +525,7 @@ end # Copied over from https://github.com/TuringLang/Turing.jl/issues/1140 function _sum(x) - z = 0 + z = 0 # this intentionally causes a type instability -- do not make this type stable. for i in eachindex(x) z += x[i] end diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 097119a40..bdb8bc7b3 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -257,7 +257,11 @@ end @testset "LazyDerivedRule" begin fargs = (S2SGlobals.baz, 5.0) rule = build_rrule(fargs...) - @test_throws Mooncake.BadRuleTypeException rule(map(zero_fcodual, fargs)...) + msg = "Unable to put rule in rule field. A `BadRuleTypeException` might be thrown." + @test_logs( + (:warn, msg), + (@test_throws Mooncake.BadRuleTypeException rule(map(zero_fcodual, fargs)...)), + ) end @testset "MooncakeRuleCompilationError" begin @test_throws(Mooncake.MooncakeRuleCompilationError, Mooncake.build_rrule(sin)) @@ -270,6 +274,10 @@ end TestUtils.test_rule( Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false ) + TestUtils.test_rule( + Xoshiro(123456), f, x...; + perf_flag=:none, interface_only, is_primitive=false, debug_mode=true, + ) # interp = Mooncake.get_interpreter() # codual_args = map(zero_codual, (f, x...))