diff --git a/Project.toml b/Project.toml
index 843f313f..d8eb2e0d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "LogExpFunctions"
 uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 authors = ["StatsFun.jl contributors, Tamas K. Papp <tkpapp@gmail.com>"]
-version = "0.3.14"
+version = "0.3.15"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
diff --git a/docs/src/index.md b/docs/src/index.md
index eb31ecfe..90522dec 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -26,6 +26,8 @@ logaddexp
 logsubexp
 logsumexp
 logsumexp!
+logprod
+logabsprod
 softmax!
 softmax
 ```
diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl
index e7c3afab..6fca657e 100644
--- a/src/LogExpFunctions.jl
+++ b/src/LogExpFunctions.jl
@@ -11,12 +11,14 @@ import LinearAlgebra
 export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp,
     log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1,
     logaddexp, logsubexp, logsumexp, logsumexp!, softmax,
-    softmax!, logcosh
+    softmax!, logcosh, logprod, logabsprod, sumlog
 
 include("basicfuns.jl")
 include("logsumexp.jl")
 include("chainrules.jl")
 include("inverse.jl")
 include("with_logabsdet_jacobian.jl")
+include("logprod.jl")
+include("sumlog.jl")
 
 end # module
diff --git a/src/logprod.jl b/src/logprod.jl
new file mode 100644
index 00000000..10750293
--- /dev/null
+++ b/src/logprod.jl
@@ -0,0 +1,83 @@
+"""
+    logprod(X::AbstractArray)
+
+Compute `log(prod(X))` efficiently.
+
+This is faster than computing `sum(log, X)`, especially for large `X`.
+It works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``,
+allowing us to write
+```math
+\\log \\prod_j x_j = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j.
+```
+"""
+function logprod(x)
+    y, s = logabsprod(x)
+    y isa Real && s < 0 && throw(DomainError(x, "`prod(x)` must be non-negative"))
+    return y
+end
+
+export logabsprod
+
+function logabsprod(x::AbstractArray{T}) where {T}
+    sig, ex = mapreduce(_logabsprod_op, x; init=frexp(one(float(T)))) do xj
+        float_xj = float(xj)
+        frexp(float_xj)
+    end
+    return (log(abs(sig)) + IrrationalConstants.logtwo * oftype(sig, ex), sign(sig))
+end
+
+@inline function _logabsprod_op((sig1, ex1), (sig2, ex2))
+    sig = sig1 * sig2
+    ex = ex1 + ex2
+
+    # The significand from `frexp` has magnitude in the range [0.5, 1),
+    # so repeated multiplication will eventually underflow
+    may_underflow(sig::T) where {T} = abs(sig) < sqrt(floatmin(T))
+    if may_underflow(sig)
+        (sig, Δex) = frexp(sig)
+        ex += Δex
+    end
+    return sig, ex
+end
+
+"""
+    logprod(x)
+    logprod(f, x, ys...)
+
+For any iterator which produces `AbstractFloat` elements,
+this can use `logprod`'s fast reduction strategy.
+
+The signature with `f` is equivalent to `log(prod(map(f, x, ys...)))`,
+but avoids intermediate allocations.
+
+Does not accept a `dims` keyword.
+"""
+logprod(f, x) = logprod(Iterators.map(f, x))
+logprod(f, x, ys...) = logprod(f(xy...) for xy in zip(x, ys...))
+
+# Iterator version, uses the same `_logabsprod_op`, should be the same speed.
+function logabsprod(x)
+    iter = iterate(x)
+    if isnothing(iter)
+        y = prod(x)
+        return log(abs(y)), sign(y)
+    end
+    x1 = float(iter[1])
+    if !(x1 isa AbstractFloat)
+        y = prod(x)
+        return log(abs(y)), sign(y)
+    end
+    sig, ex = frexp(x1)
+    nonfloat = zero(x1)
+    iter = iterate(x, iter[2])
+    while iter !== nothing
+        xj = float(iter[1])
+        if xj isa AbstractFloat
+            sig, ex = _logabsprod_op((sig, ex), frexp(xj))
+        else
+            nonfloat += log(xj)
+        end
+        iter = iterate(x, iter[2])
+    end
+    return (log(abs(sig)) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat, sign(sig))
+end
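A quick usage sketch of the `logprod`/`logabsprod` API added in `src/logprod.jl` above (illustration only, not part of the patch; the values assume the definitions as written):

```julia
using LogExpFunctions  # this branch, with `logprod` and `logabsprod` exported

x = [1e-300, 2e-300, 3e-300]
log(prod(x))   # -Inf, because prod(x) underflows to 0.0
logprod(x)     # ≈ -2070.5; underflow is avoided by tracking exponents separately
sum(log, x)    # same value, but one `log` call per element

y = [-1.0, -2.0, 3.0]
logabsprod(y)  # ≈ (log(6), 1.0): log|prod(y)| together with the sign of the product
logprod([-1.0, 2.0, 3.0])  # throws DomainError, since the product itself is negative
```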
diff --git a/src/sumlog.jl b/src/sumlog.jl
new file mode 100644
index 00000000..56d28a42
--- /dev/null
+++ b/src/sumlog.jl
@@ -0,0 +1,94 @@
+"""
+    sumlog(X::AbstractArray{T}; dims)
+
+Compute `sum(log.(X))` with a single `log` evaluation,
+provided `float(T) <: AbstractFloat`.
+
+This is faster than computing `sum(log, X)`, especially for large `X`.
+It works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``,
+allowing us to write
+```math
+\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j
+```
+"""
+sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x)
+
+function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat}
+    sig, ex = mapreduce(_sumlog_op, x; init=(one(T), 0)) do xj
+        xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
+        float_xj = float(xj)
+        significand(float_xj), _exponent(float_xj)
+    end
+    return log(sig) + IrrationalConstants.logtwo * T(ex)
+end
+
+function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat}
+    sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), 0)) do xj
+        xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
+        float_xj = float(xj)
+        significand(float_xj), _exponent(float_xj)
+    end
+    map(sig_ex) do (sig, ex)
+        log(sig) + IrrationalConstants.logtwo * T(ex)
+    end
+end
+
+# Fallback: `float(T)` is not always `<: AbstractFloat`, e.g. complex, dual numbers or symbolics
+_sumlog(::Type, dims, x) = sum(log, x; dims)
+
+@inline function _sumlog_op((sig1, ex1), (sig2, ex2))
+    sig = sig1 * sig2
+    ex = ex1 + ex2
+    # Significands are in the range [1,2), so repeated multiplication will eventually overflow
+    if sig > floatmax(typeof(sig)) / 2
+        ex += _exponent(sig)
+        sig = significand(sig)
+    end
+    return sig, ex
+end
+
+# The exported `exponent(x)` checks for `NaN` etc.; this function doesn't, which is fine as `sig` keeps track.
+_exponent(x::Base.IEEEFloat) = Base.Math._exponent_finite_nonzero(x)
+Base.@assume_effects :nothrow _exponent(x::AbstractFloat) = Int(exponent(x))  # e.g. for BigFloat
+
+"""
+    sumlog(x)
+    sumlog(f, x, ys...)
+
+For any iterator which produces `AbstractFloat` elements,
+this can use `sumlog`'s fast reduction strategy.
+
+The signature with `f` is equivalent to `sum(log, map(f, x, ys...))`
+or `mapreduce(log∘f, +, x, ys...)`, but without intermediate allocations.
+
+Does not accept a `dims` keyword.
+"""
+sumlog(f, x) = sumlog(Iterators.map(f, x))
+sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...))
+
+# Iterator version, uses the same `_sumlog_op`, should be the same speed.
+function sumlog(x)
+    iter = iterate(x)
+    if isnothing(iter)
+        T = Base._return_type(first, Tuple{typeof(x)})
+        return T <: Number ? zero(float(T)) : 0.0
+    end
+    x1 = float(iter[1])
+    x1 isa AbstractFloat || return sum(log, x)
+    x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1)
+    sig, ex = significand(x1), _exponent(x1)
+    nonfloat = zero(x1)
+    iter = iterate(x, iter[2])
+    while iter !== nothing
+        xj = float(iter[1])
+        if xj isa AbstractFloat
+            xj < 0 && Base.Math.throw_complex_domainerror(:log, xj)
+            sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj)))
+        else
+            nonfloat += log(xj)
+        end
+        iter = iterate(x, iter[2])
+    end
+    return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat
+end
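Similarly, a rough sketch of how `sumlog` above is intended to be called (illustration only, not part of the patch):

```julia
using LogExpFunctions  # this branch, with `sumlog` wired into the module

A = 1 .+ rand(1000)
sumlog(A) ≈ sum(log, A)   # true; one `log` call instead of 1000

M = 1 .+ rand(10, 10)
sumlog(M; dims=1)         # 1×10 matrix, matches sum(log, M; dims=1)

# Mapping / iterator forms avoid materialising an intermediate array:
sumlog(abs2, A)           # ≈ sum(log, abs2.(A))
sumlog(a^2 for a in A)    # same result via a generator
```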
diff --git a/test/runtests.jl b/test/runtests.jl
index f4c9ea3c..91e9fa92 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,3 +14,4 @@ include("basicfuns.jl")
 include("chainrules.jl")
 include("inverse.jl")
 include("with_logabsdet_jacobian.jl")
+include("sumlog.jl")
\ No newline at end of file
diff --git a/test/sumlog.jl b/test/sumlog.jl
new file mode 100644
index 00000000..bb321e15
--- /dev/null
+++ b/test/sumlog.jl
@@ -0,0 +1,63 @@
+@testset "sumlog" begin
+    @testset for T in [Float16, Float32, Float64, BigFloat]
+        for x in (
+                T[1, 2, 3],
+                10 .* rand(T, 1000),
+                fill(nextfloat(T(1.0)), 1000),
+                fill(prevfloat(T(2.0)), 1000),
+            )
+            @test sumlog(x) isa T
+
+            @test (@inferred sumlog(x)) ≈ sum(log, x)
+
+            y = @view x[1:min(end, 100)]
+            @test (@inferred sumlog(y')) ≈ sum(log, y)
+
+            tup = tuple(y...)
+            @test (@inferred sumlog(tup)) ≈ sum(log, tup)
+
+            # gen = (sqrt(a) for a in y)
+            # # `eltype` of a `Base.Generator` returns `Any`
+            # @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen)
+
+            # nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup)
+            # @test (@inferred sumlog(nt)) ≈ sum(log, nt)
+
+            z = x .+ im .* Random.shuffle(x)
+            @test (@inferred sumlog(z)) ≈ sum(log, z)
+        end
+
+        # With dims
+        m = 1 .+ rand(T, 10, 10)
+        @test sumlog(m; dims=1) ≈ sum(log, m; dims=1)
+        @test sumlog(m; dims=2) ≈ sum(log, m; dims=2)
+
+        # Iterator
+        @test sumlog(x^2 for x in m) ≈ sumlog(abs2, m) ≈ sumlog(*, m, m) ≈ sum(log.(m .^ 2))
+        @test sumlog(x for x in Any[1, 2, 3 + im, 4]) ≈ sum(log, Any[1, 2, 3 + im, 4])
+
+        # NaN, Inf
+        if T != BigFloat  # exponent fails here
+            @test isnan(sumlog(T[1, 2, NaN]))
+            @test isinf(sumlog(T[1, 2, Inf]))
+            @test sumlog(T[1, 2, 0.0]) == -Inf
+            @test sumlog(T[1, 2, -0.0]) == -Inf
+        end
+
+        # Empty
+        @test sumlog(T[]) isa T
+        @test eltype(sumlog(T[]; dims=1)) == T
+        @test sumlog(x for x in T[]) isa T
+
+        # Negative
+        @test_throws DomainError sumlog(T[1, -2, 3])  # easy
+        @test_throws DomainError sumlog(T[1, -2, -3]) # harder
+    end
+    @testset "Int" begin
+        @test sumlog([1, 2, 3]) isa Float64
+        @test sumlog([1, 2, 3]) ≈ sum(log, [1, 2, 3])
+        @test sumlog([1 2; 3 4]; dims=1) ≈ sum(log, [1 2; 3 4]; dims=1)
+        @test sumlog(Int(x) for x in Float64[1, 2, 3]) ≈ sum(log, [1, 2, 3])
+    end
+end
\ No newline at end of file
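The test suite above only exercises `sumlog`; a minimal sketch of a matching `test/logprod.jl` (hypothetical, not included in this patch) could look like:

```julia
using Test, LogExpFunctions

@testset "logprod (sketch)" begin
    x = 10 .* rand(100)
    @test logprod(x) ≈ sum(log, x)
    @test logprod(abs2, x) ≈ sum(log, abs2.(x))

    # logabsprod also reports the sign of the product
    lp, s = logabsprod([-1.0, -2.0, 3.0])
    @test lp ≈ log(6)
    @test s == 1.0

    # an overall negative product is a DomainError for logprod
    @test_throws DomainError logprod([-1.0, 2.0, 3.0])
end
```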