Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up the code for debug mode #674

Merged
merged 1 commit into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/contrib/contrib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore
using ConcreteStructs: @concrete
using FastClosures: @closure
using Functors: Functors, KeyPath, fmap, fmap_with_path, fmapstructure, functor
using Functors: Functors, KeyPath, fmap_with_path, fmapstructure, functor
using Markdown: @doc_str
using Random: AbstractRNG, Random
using Setfield: Setfield
Expand Down
152 changes: 68 additions & 84 deletions src/contrib/debug.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both,
error_check::Bool=true, location::String="")
error_check::Bool=true, location::KeyPath=KeyPath())

!!! danger

Expand Down Expand Up @@ -46,12 +46,20 @@ See [`Lux.Experimental.@debug_mode`](@ref) to construct this layer.
@concrete struct DebugLayer{NaNCheck, ErrorCheck} <:
AbstractExplicitContainerLayer{(:layer,)}
layer
location::String
location::KeyPath
end

function DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both,
error_check::Bool=true, location::String="")
error_check::Bool=true, location::Union{KeyPath, String}=KeyPath())
@argcheck nan_check in (:both, :forward, :backward, :none)

if location isa String
Base.depwarn(
"Using a String for location in DebugLayer is deprecated. Use \
`Functors.KeyPath` instead.", :DebugLayer)
location = KeyPath(Symbol.(split(location, "."))...)
end

return DebugLayer{nan_check, error_check}(layer, location)
end

Expand All @@ -60,92 +68,80 @@ function error_check(::DebugLayer{NaNCheck, ErrorCheck}) where {NaNCheck, ErrorC
return Val(ErrorCheck)
end

function (d::DebugLayer)(x, ps, st)
return __debug_layer(nan_check(d), error_check(d), d.layer, x, ps, st, d.location)
end

function __any_nan(x)
has_nan = Ref(false)
function nan_check(x)
x isa AbstractArray && (has_nan[] = has_nan[] || any(isnan, x))
applicable(isnan, x) && (has_nan[] = has_nan[] || isnan(x))
return x
end
fmap(nan_check, x)
return has_nan[]
end

CRC.@non_differentiable __any_nan(::Any)

function __debug_layer(
::Val{NC}, ::Val{EC}, layer, x, ps, st, location::String) where {NC, EC}
function (d::DebugLayer{NaNCheck, ErrorCheck})(x, ps, st) where {NaNCheck, ErrorCheck}
CRC.ignore_derivatives() do
@info lazy"Input Type: $(typeof(x)) | Input Structure: $(fmapstructure(Lux.__size, x))"
@info lazy"Running Layer: $(layer) at location $(location)!"
end
if NC ∈ (:both, :forward)
__any_nan(x) && throw(DomainError(
x, lazy"NaNs detected in input to layer $(layer) at location $(location)"))
__any_nan(ps) && throw(DomainError(ps,
lazy"NaNs detected in parameters of layer $(layer) at location $(location)"))
__any_nan(st) && throw(DomainError(st,
lazy"NaNs detected in states of layer $(layer) at location $(location)"))
@info lazy"Running Layer: $(d.layer) at location $(d.location)!"
if NaNCheck ∈ (:both, :forward)
__check_nan_and_throw(x, "input", d.layer, d.location)
__check_nan_and_throw(ps, "parameters", d.layer, d.location)
__check_nan_and_throw(st, "states", d.layer, d.location)
end
end
y, st_ = __debug_layer_internal(layer, x, ps, st, location, EC, NC ∈ (:both, :backward))
y, st_ = __debug_layer_internal(
d.layer, x, ps, st, d.location, ErrorCheck, NaNCheck ∈ (:both, :backward))
CRC.ignore_derivatives() do
if NaNCheck ∈ (:both, :forward)
__check_nan_and_throw(y, "output", d.layer, d.location)
__check_nan_and_throw(st_, "states", d.layer, d.location)
end
@info lazy"Output Type: $(typeof(y)) | Output Structure: $(fmapstructure(Lux.__size, y))"
end
return y, st_
end

function __check_nan_and_throw(x, str::AbstractString, layer, location::KeyPath)
function err(kp, x)
loc_str = kp == KeyPath() ? " " : " (@ $(kp)) "
return DomainError(
x, "NaNs detected in $(str)$(loc_str)of layer $(layer) at location $(location)")
end

function nan_check(kp, x)
x isa AbstractArray && any(isnan, x) && throw(err(kp, x))
applicable(isnan, x) && isnan(x) && throw(err(kp, x))
return x
end

fmap_with_path(nan_check, x)
end

function __debug_layer_internal(layer, x, ps, st, location, EC, NC)
if EC
try
y, st_ = apply(layer, x, ps, st)
return y, st_
catch e
y, st_ = try
apply(layer, x, ps, st)
catch
EC &&
@error lazy"Layer $(layer) failed!! This layer is present at location $(location)"
rethrow()
end
else
return apply(layer, x, ps, st)
rethrow()
end
return y, st_
end

function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode},
::typeof(__debug_layer_internal), layer, x, ps, st, location, EC, NC)
result, ∇__debug_layer_internal = CRC.rrule_via_ad(cfg, apply, layer, x, ps, st)
function ∇__debug_layer_internal_with_checks(Δ)
if NC
__any_nan(Δ) && throw(DomainError(Δ,
lazy"NaNs detected in pullback input for $(layer) at location $(location)!"))
end
if EC
try
gs = ∇__debug_layer_internal(Δ)
catch e
result, ∇debug_layer_internal = CRC.rrule_via_ad(cfg, apply, layer, x, ps, st)
syms = ["LuxCore.apply", "layer", "x", "ps", "st"]
function ∇debug_layer_internal_with_checks(Δ)
NC && __check_nan_and_throw(Δ, "pullback input", layer, location)

gs = try
∇debug_layer_internal(Δ)
catch
EC &&
@error lazy"Backward Pass for Layer $(layer) failed!! This layer is present at location $(location)"
rethrow()
end
if NC
for g in gs
__any_nan(g) && throw(DomainError(g,
lazy"NaNs detected in pullback output for $(layer) at location $(location)!"))
end
end
return (gs..., CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
else
gs = ∇__debug_layer_internal(Δ)
if NC
for g in gs
__any_nan(g) && throw(DomainError(g,
lazy"NaNs detected in pullback output for $(layer) at location $(location)!"))
end
rethrow()
end

if NC
for (i, g) in enumerate(gs)
__check_nan_and_throw(
g, lazy"pullback output ($(syms[i]))", layer, location)
end
return (gs..., CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
end

return (gs..., CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
end
return result, ∇__debug_layer_internal_with_checks
return result, ∇debug_layer_internal_with_checks
end

"""
Expand All @@ -158,19 +154,7 @@ See [`Lux.Experimental.DebugLayer`](@ref) for details about the Keyword Argument
"""
macro debug_mode(layer, kwargs...)
kws = esc.(kwargs)
return :(__debug_mode($(esc(layer)), $(string(layer)); $(kws...)))
end

function __debug_mode(layer, name::String; kwargs...)
l_c, l_re = functor(layer)

length(l_c) == 0 && return DebugLayer(layer; location=name, kwargs...)

l_c_new = []
for k in keys(l_c)
l_c_new_ = __debug_mode(getproperty(l_c, k), join((name, k), "."); kwargs...)
push!(l_c_new, k => l_c_new_)
end

return l_re((; l_c_new...))
return :($(fmap_with_path)(
(kp, l) -> DebugLayer(l; location=$(KeyPath)($(Meta.quot(layer)), kp), $(kws...)),
$(esc(layer))))
end
2 changes: 1 addition & 1 deletion test/contrib/debug_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@test_throws DimensionMismatch model_debug(x, ps, st)
@test_logs (:info,) (:error,
"Layer Dense(1 => 1) failed!! This layer is present at location model.layers.layer_2.layers.layer_2") match_mode=:any try
"Layer Dense(1 => 1) failed!! This layer is present at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2)") match_mode=:any try
model_debug(x, ps, st)
catch
end
Expand Down
Loading