diff --git a/docs/src/index.md b/docs/src/index.md index 5ad110b..3c0a7f3 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -18,6 +18,8 @@ logcosh logabssinh log1psq log1pexp +softplus +invsoftplus log1mexp log2mexp logexpm1 diff --git a/ext/LogExpFunctionsChangesOfVariablesExt.jl b/ext/LogExpFunctionsChangesOfVariablesExt.jl index 52a105b..cf7c960 100644 --- a/ext/LogExpFunctionsChangesOfVariablesExt.jl +++ b/ext/LogExpFunctionsChangesOfVariablesExt.jl @@ -8,11 +8,25 @@ function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1pexp), x::Real) y = log1pexp(x) return y, x - y end +function ChangesOfVariables.with_logabsdet_jacobian(::typeof(softplus), x::Real) + return ChangesOfVariables.with_logabsdet_jacobian(log1pexp, x) +end +function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(softplus),<:Real}, x::Real) + y = f(x) + return y, f.x * (x - y) +end function ChangesOfVariables.with_logabsdet_jacobian(::typeof(logexpm1), x::Real) y = logexpm1(x) return y, x - y end +function ChangesOfVariables.with_logabsdet_jacobian(::typeof(invsoftplus), x::Real) + return ChangesOfVariables.with_logabsdet_jacobian(logexpm1, x) +end +function ChangesOfVariables.with_logabsdet_jacobian(f::Base.Fix2{typeof(invsoftplus),<:Real}, x::Real) + y = f(x) + return y, f.x * (x - y) +end function ChangesOfVariables.with_logabsdet_jacobian(::typeof(log1mexp), x::Real) y = log1mexp(x) diff --git a/ext/LogExpFunctionsInverseFunctionsExt.jl b/ext/LogExpFunctionsInverseFunctionsExt.jl index 1981493..6903303 100644 --- a/ext/LogExpFunctionsInverseFunctionsExt.jl +++ b/ext/LogExpFunctionsInverseFunctionsExt.jl @@ -22,4 +22,14 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic +InverseFunctions.inverse(::typeof(softplus)) = invsoftplus +function InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real}) + Base.Fix2(invsoftplus, f.x) +end + +InverseFunctions.inverse(::typeof(invsoftplus)) = softplus +function InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real}) + Base.Fix2(softplus, f.x) +end + end # module diff --git a/src/basicfuns.jl b/src/basicfuns.jl index e013adf..d561fc0 100644 --- a/src/basicfuns.jl +++ b/src/basicfuns.jl @@ -165,6 +165,9 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`. This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) transformation, being a smooth approximation to `max(0,x)`. Its inverse is [`logexpm1`](@ref). +This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) +transformation (in its default parametrization, see [`softplus`](@ref)), being a smooth approximation to `max(0,x)`. + See: * Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf) """ @@ -257,8 +260,27 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function. It is the inverse o logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x) logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x) -const softplus = log1pexp -const invsoftplus = logexpm1 +""" +$(SIGNATURES) + +The generalized `softplus` function (Wiemann et al., 2024) takes an additional optional parameter `a` that control +the approximation error with respect to the linear spline. It defaults to `a=1.0`, in which case the softplus is +equivalent to [`log1pexp`](@ref). + +See: + * Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180. +""" +softplus(x::Real) = log1pexp(x) +softplus(x::Real, a::Real) = log1pexp(a * x) / a + +""" +$(SIGNATURES) + +The inverse generalized `softplus` function (Wiemann et al., 2024). See [`softplus`](@ref). +""" +invsoftplus(y::Real) = logexpm1(y) +invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a + """ $(SIGNATURES) diff --git a/test/basicfuns.jl b/test/basicfuns.jl index 1e615c0..72d0e44 100644 --- a/test/basicfuns.jl +++ b/test/basicfuns.jl @@ -161,6 +161,16 @@ end end end +@testset "softplus" begin + for T in (Int, Float64, Float32, Float16) + @test @inferred(softplus(T(2))) === log1pexp(T(2)) + @test @inferred(softplus(T(2), 1)) isa float(T) + @test @inferred(softplus(T(2), 1)) ≈ softplus(T(2)) + @test @inferred(softplus(T(2), 5)) ≈ softplus(5 * T(2)) / 5 + @test @inferred(softplus(T(2), 10)) ≈ softplus(10 * T(2)) / 10 + end +end + @testset "log1mexp" begin for T in (Float64, Float32, Float16) @test @inferred(log1mexp(-T(1))) isa T @@ -186,6 +196,16 @@ end end end +@testset "invsoftplus" begin + for T in (Int, Float64, Float32, Float16) + @test @inferred(invsoftplus(T(2))) === logexpm1(T(2)) + @test @inferred(invsoftplus(T(2), 1)) isa float(T) + @test @inferred(invsoftplus(T(2), 1)) ≈ invsoftplus(T(2)) + @test @inferred(invsoftplus(T(2), 5)) ≈ invsoftplus(5 * T(2)) / 5 + @test @inferred(invsoftplus(T(2), 10)) ≈ invsoftplus(10 * T(2)) / 10 + end +end + @testset "log1pmx" begin @test iszero(log1pmx(0.0)) @test log1pmx(1.0) ≈ log(2.0) - 1.0 diff --git a/test/inverse.jl b/test/inverse.jl index e489db5..630e2f7 100644 --- a/test/inverse.jl +++ b/test/inverse.jl @@ -1,6 +1,11 @@ @testset "inverse.jl" begin InverseFunctions.test_inverse(log1pexp, randn()) + InverseFunctions.test_inverse(softplus, randn()) + InverseFunctions.test_inverse(Base.Fix2(softplus, randexp()), randn()) + InverseFunctions.test_inverse(logexpm1, randexp()) + InverseFunctions.test_inverse(invsoftplus, randexp()) + InverseFunctions.test_inverse(Base.Fix2(invsoftplus, randexp()), randexp()) InverseFunctions.test_inverse(log1mexp, -randexp()) diff --git a/test/with_logabsdet_jacobian.jl b/test/with_logabsdet_jacobian.jl index d5f4484..3de5e9f 100644 --- a/test/with_logabsdet_jacobian.jl +++ b/test/with_logabsdet_jacobian.jl @@ -1,12 +1,23 @@ @testset "with_logabsdet_jacobian" begin derivative(f, x) = ChainRulesTestUtils.frule((ChainRulesTestUtils.NoTangent(), 1), f, x)[2] + derivative(::typeof(softplus), x) = derivative(log1pexp, x) + derivative(f::Base.Fix2{typeof(softplus),<:Real}, x) = derivative(log1pexp, f.x * x) + derivative(::typeof(invsoftplus), x) = derivative(logexpm1, x) + derivative(f::Base.Fix2{typeof(invsoftplus),<:Real}, x) = derivative(logexpm1, f.x * x) x = randexp() + y = randexp() ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(log1pexp, -x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(softplus, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(softplus, -x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(softplus, y), -x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(logexpm1, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(invsoftplus, x, derivative) + ChangesOfVariables.test_with_logabsdet_jacobian(Base.Fix2(invsoftplus, y), x, derivative) ChangesOfVariables.test_with_logabsdet_jacobian(log1mexp, -x, derivative)