Skip to content

Commit

Permalink
test: batchnorm layers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 17, 2025
1 parent 5014ee4 commit f074b76
Showing 1 changed file with 63 additions and 1 deletion.
64 changes: 63 additions & 1 deletion test/reactant/layer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,69 @@ end
end
end

@testitem "BatchNorm Layer" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin
@testitem "BatchNorm Layer" tags=[:reactant] setup=[
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux, Random

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

dev = reactant_device(; force=true)

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for track_stats in (true, false), affine in (true, false),
act in (identity, tanh)

model = Chain(
Dense(2 => 3, tanh),
BatchNorm(3, act; track_stats, affine, init_bias=rand32, init_scale=rand32),
Dense(3 => 2)
)

x = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), model)

x_ra = x |> dev
ps_ra = ps |> dev
st_ra = st |> dev

y, st2 = model(x, ps, st)
y_ra, st2_ra = @jit model(x_ra, ps_ra, st_ra)

@test yy_ra rtol=1e-3 atol=1e-3
if track_stats
@test st2.layer_2.running_meanst2_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
@test st2.layer_2.running_varst2_ra.layer_2.running_var rtol=1e-3 atol=1e-3
end

# TODO: Check for stablehlo.batch_norm_training once we emit it in LuxLib

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra∂x atol=1e-2 rtol=1e-2
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
end

y2, st3 = model(x, ps, Lux.testmode(st2))
y2_ra, st3_ra = @jit model(x_ra, ps_ra, Lux.testmode(st2_ra))

@test y2y2_ra rtol=1e-3 atol=1e-3
if track_stats
@test st3.layer_2.running_meanst3_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
@test st3.layer_2.running_varst3_ra.layer_2.running_var rtol=1e-3 atol=1e-3
end

hlo = @code_hlo model(x_ra, ps_ra, Lux.testmode(st_ra))
@test contains(repr(hlo), "stablehlo.batch_norm_inference")
end
end
end

0 comments on commit f074b76

Please sign in to comment.