Add "a" parameter to softplus() #83 #85

Merged Dec 11, 2024 (13 commits)
2 changes: 2 additions & 0 deletions docs/src/index.md
@@ -18,6 +18,8 @@ logcosh
logabssinh
log1psq
log1pexp
softplus
invsoftplus
log1mexp
log2mexp
logexpm1
3 changes: 3 additions & 0 deletions ext/LogExpFunctionsInverseFunctionsExt.jl
@@ -22,4 +22,7 @@ InverseFunctions.inverse(::typeof(logitexp)) = loglogistic
InverseFunctions.inverse(::typeof(log1mlogistic)) = logit1mexp
InverseFunctions.inverse(::typeof(logit1mexp)) = log1mlogistic

InverseFunctions.inverse(::typeof(softplus)) = invsoftplus
InverseFunctions.inverse(::typeof(invsoftplus)) = softplus

end # module
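Reviewer note: the extension above registers only the bare one-argument functions, so `inverse` knows about the default `a = 1` parametrization and nothing else. The sketch below shows what that gives a caller. It assumes a LogExpFunctions release that includes this PR and that InverseFunctions is loaded alongside it. The `Base.Fix2` registrations are hypothetical and not part of this PR; they only illustrate one way a fixed `a` could be paired with its inverse.

```julia
using LogExpFunctions, InverseFunctions

# Registered by the extension above: the one-argument forms are mutual inverses.
inv_sp = InverseFunctions.inverse(softplus)
@assert inv_sp === invsoftplus
@assert InverseFunctions.inverse(invsoftplus) === softplus
@assert inv_sp(softplus(0.3)) ≈ 0.3

# Hypothetical, NOT part of this PR: fix `a` with Base.Fix2 and register that pair too.
# Base.Fix2(softplus, a)(x) calls softplus(x, a); the fixed value is stored in field `x`.
InverseFunctions.inverse(f::Base.Fix2{typeof(softplus),<:Real}) = Base.Fix2(invsoftplus, f.x)
InverseFunctions.inverse(f::Base.Fix2{typeof(invsoftplus),<:Real}) = Base.Fix2(softplus, f.x)

sp10 = Base.Fix2(softplus, 10)
@assert InverseFunctions.inverse(sp10)(sp10(2.0)) ≈ 2.0
```

Defining these methods outside the package is type piracy, so in practice they would belong in the extension module itself rather than in user code.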
28 changes: 26 additions & 2 deletions src/basicfuns.jl
@@ -165,6 +165,9 @@ Return `log(1+exp(x))` evaluated carefully for largish `x`.
This is also called the ["softplus"](https://en.wikipedia.org/wiki/Rectifier_(neural_networks))
transformation (in its default parametrization, see [`softplus`](@ref)), being a smooth approximation to `max(0,x)`.
Its inverse is [`logexpm1`](@ref).

See:
* Martin Maechler (2012) [“Accurately Computing log(1 − exp(− |a|))”](http://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf)
"""
@@ -257,8 +260,29 @@ Return `log(exp(x) - 1)` or the “invsoftplus” function.  It is the inverse o
logexpm1(x::Real) = x <= 18.0 ? log(_expm1(x)) : x <= 33.3 ? x - exp(-x) : oftype(exp(-x), x)
logexpm1(x::Float32) = x <= 9f0 ? log(expm1(x)) : x <= 16f0 ? x - exp(-x) : oftype(exp(-x), x)

-const softplus = log1pexp
-const invsoftplus = logexpm1
"""
$(SIGNATURES)

The generalized `softplus` function (Wiemann et al., 2024) takes an optional parameter `a` that controls
the approximation error with respect to the linear spline `max(0, x)`: the larger `a`, the closer the approximation.
It defaults to `a=1`, in which case `softplus` is equivalent to [`log1pexp`](@ref).

See:
* Wiemann, P. F., Kneib, T., & Hambuckers, J. (2024). Using the softplus function to construct alternative link functions in generalized linear models and beyond. Statistical Papers, 65(5), 3155-3180.
"""
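softplus(x::Real, a::Real) = log1pexp(a * x) / a
# Keyword arguments do not take part in dispatch, so this method also serves `softplus(x)`;
# a separate `softplus(x::Real)` definition would be overwritten by it.
softplus(x::Real; a::Real=1) = softplus(x, a)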

"""
$(SIGNATURES)

The inverse of the generalized `softplus` function (Wiemann et al., 2024), so that `invsoftplus(softplus(x, a), a) ≈ x`. See [`softplus`](@ref).
"""
invsoftplus(y::Real, a::Real) = logexpm1(a * y) / a
# Keyword arguments do not take part in dispatch, so this method also serves `invsoftplus(y)`.
invsoftplus(y::Real; a::Real=1) = invsoftplus(y, a)


"""
$(SIGNATURES)
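Reviewer note: the docstring says `a` controls the approximation error relative to `max(0, x)`. A minimal, package-free sketch of that claim follows. The names `naive_softplus` and `naive_invsoftplus` are hypothetical and are not the package's implementation, which goes through `log1pexp` and `logexpm1` to stay accurate for large arguments. The bound used is `softplus(x, a) - max(0, x) = log1p(exp(-a*|x|)) / a`, which is at most `log(2) / a`.

```julia
# Naive reference versions (hypothetical names, not the package's implementation).
# They overflow once a*x exceeds roughly 709 in Float64, which the package avoids.
naive_softplus(x::Real, a::Real=1) = log1p(exp(a * x)) / a
naive_invsoftplus(y::Real, a::Real=1) = log(expm1(a * y)) / a

x = 2.0
for a in (1.0, 10.0, 100.0)
    gap = naive_softplus(x, a) - max(0, x)
    # Mathematically the gap lies in (0, log(2)/a]; in floating point it may round to 0.
    @assert 0 <= gap <= log(2) / a
    # The inverse must be called with the same `a` to recover x.
    @assert naive_invsoftplus(naive_softplus(x, a), a) ≈ x
end
```

The shrinking bound is also why the test `softplus(2, 10) < log1pexp(2)` in this PR holds: a larger `a` pulls the curve toward `max(0, x)`.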
10 changes: 10 additions & 0 deletions test/basicfuns.jl
@@ -161,6 +161,16 @@ end
end
end

@testset "softplus" begin
    @test softplus(2) ≈ log1pexp(2)
    @test softplus(2, 1) ≈ log1pexp(2)
    @test softplus(2, a=1) ≈ log1pexp(2)
    @test softplus(2, 10) < log1pexp(2)
    @test invsoftplus(softplus(2), 1) ≈ 2
    @test invsoftplus(softplus(2, 10), a=10) ≈ 2
end


@testset "log1mexp" begin
    for T in (Float64, Float32, Float16)
        @test @inferred(log1mexp(-T(1))) isa T
3 changes: 3 additions & 0 deletions test/inverse.jl
@@ -17,4 +17,7 @@

InverseFunctions.test_inverse(log1mlogistic, randexp())
InverseFunctions.test_inverse(logit1mexp, -randexp())

InverseFunctions.test_inverse(softplus, randn())
InverseFunctions.test_inverse(invsoftplus, randexp())
end