Skip to content

Commit

Permalink
Test + Fix Debug Mode (#356)
Browse files Browse the repository at this point in the history
* Improve comment

* Improve error handling in LazyDerivedRule

* Test debug mode properly

* Fix debug mode

* Bump patch version

* Tidy up slightly

* Tidy up a bit more

* Tweak performance criteria

* Fix test warning
  • Loading branch information
willtebbutt authored Nov 8, 2024
1 parent c8cd702 commit e9d5c81
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.38"
version = "0.4.39"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ function benchmark_derived_rrules!!(rng_ctor)
tags = fill(nothing, length(test_cases))
return map(x -> x[4:end], test_cases), memory, ranges, tags
end
return benchmark_rules!!(test_case_data, (lb=1e-3, ub=150), false)
return benchmark_rules!!(test_case_data, (lb=1e-3, ub=200), false)
end

function benchmark_inter_framework_rules()
Expand All @@ -299,7 +299,7 @@ function benchmark_inter_framework_rules()
test_cases = map(last, test_case_data)
memory = []
ranges = fill(nothing, length(test_cases))
return benchmark_rules!!([(test_cases, memory, ranges, tags)], (lb=0.1, ub=150), true)
return benchmark_rules!!([(test_cases, memory, ranges, tags)], (lb=0.1, ub=200), true)
end

function flag_concerning_performance(ratios)
Expand Down
2 changes: 1 addition & 1 deletion src/debug_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ verify_args(_, x) = nothing
try
# Check that the input types are correct. If this check is not present, the passing
# in arguments of the wrong type can result in a segfault.
verify_args(rule, map(primal, x))
verify_args(rule, x)

# Use for-loop to keep the stack trace as simple as possible.
for _x in x
Expand Down
28 changes: 16 additions & 12 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ __get_primal(x::CoDual) = primal(x)
__get_primal(x) = x

# Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`.
@inline function __run_rvs_pass!(::Type{P}, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...) where {P, sig}
@inline function __run_rvs_pass!(P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...) where {sig}
tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[]))
set_ret_ref_to_zero!!(P, ret_rev_data_ref)
return nothing
Expand Down Expand Up @@ -734,10 +734,10 @@ function DerivedRule(Tprimal, fwds_oc::T, pb::U, isva::V, nargs::W) where {T, U,
end

# Extends functionality defined for debug_mode.
function verify_args(::DerivedRule{sig}, ::Tx) where {sig, Tx}
sig === Tx && return nothing
msg = "Arguments with sig $Tx do not match signature expected by rule, $sig"
throw(ArgumentError(msg))
function verify_args(r::DerivedRule{sig}, x) where {sig}
Tx = Tuple{map(_typeof primal, __unflatten_codual_varargs(r.isva, x, r.nargs))...}
Tx <: sig && return nothing
throw(ArgumentError("Arguments with sig $Tx do not subtype rule signature, $sig"))
end

_copy(::Nothing) = nothing
Expand Down Expand Up @@ -1530,8 +1530,7 @@ end
_copy(x::P) where {P<:LazyDerivedRule} = P(x.mi, x.debug_mode)

@inline function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N}
isdefined(rule, :rule) || _build_rule!(rule, args)
return rule.rule(args...)
return isdefined(rule, :rule) ? rule.rule(args...) : _build_rule!(rule, args)
end

struct BadRuleTypeException <: Exception
Expand Down Expand Up @@ -1561,20 +1560,25 @@ function Base.showerror(io::IO, err::BadRuleTypeException)
println(io, msg)
end

_rtype(::Type{<:DebugRRule}) = Tuple{CoDual, DebugPullback}
_rtype(T::Type{<:MistyClosure}) = _rtype(fieldtype(T, :oc))
_rtype(::Type{<:OpaqueClosure{<:Any, <:R}}) where {R} = (@isdefined R) ? R : CoDual
_rtype(T::Type{<:DerivedRule}) = Tuple{_rtype(fieldtype(T, :fwds_oc)), fieldtype(T, :pb)}

@noinline function _build_rule!(rule::LazyDerivedRule{sig, Trule}, args) where {sig, Trule}
derived_rule = build_rrule(get_interpreter(), rule.mi; debug_mode=rule.debug_mode)
if derived_rule isa Trule
rule.rule = derived_rule
result = derived_rule(args...)
else
@warn "Unable to put rule in rule field. A `BadRuleTypeException` should be thrown."
@warn "Unable to put rule in rule field. A `BadRuleTypeException` might be thrown."
err = BadRuleTypeException(rule.mi, sig, typeof(derived_rule), Trule)
try
result = try
derived_rule(args...)
catch
throw(err)
end
@warn "`BadRuleTypException was _not_ thrown. Throwing now."
throw(err)
@warn "`BadRuleTypException was _not_ thrown. Expect an error at some point."
end
return nothing
return result::_rtype(Trule)
end
2 changes: 1 addition & 1 deletion src/test_resources.jl
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ end

# Copied over from https://github.com/TuringLang/Turing.jl/issues/1140
function _sum(x)
z = 0
z = 0 # this intentionally causes a type instability -- do not make this type stable.
for i in eachindex(x)
z += x[i]
end
Expand Down
10 changes: 9 additions & 1 deletion test/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,11 @@ end
@testset "LazyDerivedRule" begin
fargs = (S2SGlobals.baz, 5.0)
rule = build_rrule(fargs...)
@test_throws Mooncake.BadRuleTypeException rule(map(zero_fcodual, fargs)...)
msg = "Unable to put rule in rule field. A `BadRuleTypeException` might be thrown."
@test_logs(
(:warn, msg),
(@test_throws Mooncake.BadRuleTypeException rule(map(zero_fcodual, fargs)...)),
)
end
@testset "MooncakeRuleCompilationError" begin
@test_throws(Mooncake.MooncakeRuleCompilationError, Mooncake.build_rrule(sin))
Expand All @@ -270,6 +274,10 @@ end
TestUtils.test_rule(
Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false
)
TestUtils.test_rule(
Xoshiro(123456), f, x...;
perf_flag=:none, interface_only, is_primitive=false, debug_mode=true,
)

# interp = Mooncake.get_interpreter()
# codual_args = map(zero_codual, (f, x...))
Expand Down

2 comments on commit e9d5c81

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/118983

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.39 -m "<description of version>" e9d5c81809e790797a9b0701eb0af7e96d29c4b7
git push origin v0.4.39

Please sign in to comment.