From 2d4d5ed5dc557bda5c0476329abf008f7621b09f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 28 Aug 2024 23:00:53 -0400 Subject: [PATCH] feat: default call for wrapper layers --- src/LuxCore.jl | 7 +++++++ test/runtests.jl | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/LuxCore.jl b/src/LuxCore.jl index 5f0a3f2..4e90827 100644 --- a/src/LuxCore.jl +++ b/src/LuxCore.jl @@ -239,6 +239,9 @@ layer to be wrapped in a container. Additionally, on calling [`initialparameters`](@ref) and [`initialstates`](@ref), the parameters and states are **not** wrapped in a `NamedTuple` with the same name as the field. + +As a convenience, we define the fallback call `(::AbstractLuxWrapperLayer)(x, ps, st)`, +which calls `getfield(x, layer)(x, ps, st)`. """ abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer end @@ -259,6 +262,10 @@ function statelength(l::AbstractLuxWrapperLayer{layer}) where {layer} return statelength(getfield(l, layer)) end +function (l::AbstractLuxWrapperLayer{layer})(x, ps, st) where {layer} + return apply(getfield(l, layer), x, ps, st) +end + # Test Mode """ testmode(st::NamedTuple) diff --git a/test/runtests.jl b/test/runtests.jl index 82c3439..f55dba7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,6 +31,17 @@ end (::Dense)(x, ps, st) = x, st # Dummy Forward Pass +struct DenseWrapper{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +# For checking ambiguities in the dispatch +struct DenseWrapper2{L} <: AbstractLuxWrapperLayer{:layer} + layer::L +end + +(d::DenseWrapper2)(x::AbstractArray, ps, st) = d.layer(x, ps, st) + struct Chain{L} <: AbstractLuxContainerLayer{(:layers,)} layers::L end @@ -78,6 +89,18 @@ end first(LuxCore.apply(model, x, ps, NamedTuple())) @test_nowarn println(model) + + @testset for wrapper in (DenseWrapper, DenseWrapper2) + model2 = DenseWrapper(model) + ps, st = LuxCore.setup(rng, model2) + + @test LuxCore.parameterlength(ps) == LuxCore.parameterlength(model2) + @test LuxCore.statelength(st) == LuxCore.statelength(model2) + + @test model2(x, ps, st)[1] == model(x, ps, st)[1] + + @test_nowarn println(model2) + end end @testset "Default Fallbacks" begin