From 581b9df9d073814516c5a5043f3d4d5caa4c6e10 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 10:04:34 -0700 Subject: [PATCH 01/38] sumlog --- src/LogExpFunctions.jl | 3 ++- src/sumlog.jl | 36 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + test/sumlog.jl | 5 +++++ 4 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 src/sumlog.jl create mode 100644 test/sumlog.jl diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index e7c3afab..8512e0ab 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -11,12 +11,13 @@ import LinearAlgebra export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax, - softmax!, logcosh + softmax!, logcosh, sumlog include("basicfuns.jl") include("logsumexp.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") +include("sumlog.jl") end # module diff --git a/src/sumlog.jl b/src/sumlog.jl new file mode 100644 index 00000000..137360bc --- /dev/null +++ b/src/sumlog.jl @@ -0,0 +1,36 @@ +using IrrationalConstants: logtwo + +""" +$(SIGNATURES) + +Compute `sum(log.(X))`. + +`sum(log.(X))` can be evaluated much more quickly as `sum(log, X)`. However, +this still requires computing `log` for each element of `X`. + +`sumlog(X)` can be faster still, especially an the length of `X` increases. + +This works by representing the `j`th element of `X` as `xⱼ = aⱼ * 2 ^ bⱼ`, +allowing us to write + + ∑ⱼ log(xⱼ) = log(∏ⱼ aⱼ) + log(2) * ∑ⱼ bⱼ + +Since `log(2)` is constant, `sumlog` only requires a single `log` evaluation. +""" +function sumlog(x::AbstractArray{T}) where {T} + sig = one(T) + ex = zero(exponent(one(T))) + bound = floatmax(T) / 2 + for xj in x + sig *= significand(xj) + ex += exponent(xj) + + # Significands are in the rang [1,2), so multiplication will eventually overflow + if sig > bound + (a, b) = (significand(sig), exponent(sig)) + sig = a + ex += b + end + end + log(sig) + logtwo * ex +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index f4c9ea3c..91e9fa92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,3 +14,4 @@ include("basicfuns.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") +include("sumlog.jl") \ No newline at end of file diff --git a/test/sumlog.jl b/test/sumlog.jl new file mode 100644 index 00000000..724d4ed0 --- /dev/null +++ b/test/sumlog.jl @@ -0,0 +1,5 @@ +@testset "sumlog" begin + for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)] + @test (@inferred sumlog(x)) ≈ sum(log, x) + end +end \ No newline at end of file From 5725aa91932cbfa422f403c86cf4c56183713219 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 10:40:45 -0700 Subject: [PATCH 02/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 137360bc..b653219e 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -3,12 +3,9 @@ using IrrationalConstants: logtwo """ $(SIGNATURES) -Compute `sum(log.(X))`. +Compute `sum(log.(X))` with a single `log` evaluation. -`sum(log.(X))` can be evaluated much more quickly as `sum(log, X)`. However, -this still requires computing `log` for each element of `X`. - -`sumlog(X)` can be faster still, especially an the length of `X` increases. +This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in particular as `X` increases. This works by representing the `j`th element of `X` as `xⱼ = aⱼ * 2 ^ bⱼ`, allowing us to write From 77aa3d9fab0ff71a091a4faedd772a219f2f14c7 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 10:44:36 -0700 Subject: [PATCH 03/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index b653219e..89effbc6 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -22,7 +22,7 @@ function sumlog(x::AbstractArray{T}) where {T} sig *= significand(xj) ex += exponent(xj) - # Significands are in the rang [1,2), so multiplication will eventually overflow + # Significands are in the range [1,2), so multiplication will eventually overflow if sig > bound (a, b) = (significand(sig), exponent(sig)) sig = a From 88d6fb1c066b578aa12f2edb8bd90115cf9d4837 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 11:22:59 -0700 Subject: [PATCH 04/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 89effbc6..ea6c8867 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -9,9 +9,9 @@ This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in particula This works by representing the `j`th element of `X` as `xⱼ = aⱼ * 2 ^ bⱼ`, allowing us to write - - ∑ⱼ log(xⱼ) = log(∏ⱼ aⱼ) + log(2) * ∑ⱼ bⱼ - +```math +\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j +``` Since `log(2)` is constant, `sumlog` only requires a single `log` evaluation. """ function sumlog(x::AbstractArray{T}) where {T} From 9ecf5893d85748198a2008cee38c2722f86df14f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 11:23:09 -0700 Subject: [PATCH 05/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index ea6c8867..3c4302d6 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -12,7 +12,7 @@ allowing us to write ```math \\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j ``` -Since `log(2)` is constant, `sumlog` only requires a single `log` evaluation. +Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` evaluation. """ function sumlog(x::AbstractArray{T}) where {T} sig = one(T) From 9db732f0488fe44abffaf7038bf216246a23570d Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 11:23:13 -0700 Subject: [PATCH 06/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 3c4302d6..21857a0f 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -7,7 +7,7 @@ Compute `sum(log.(X))` with a single `log` evaluation. This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in particular as `X` increases. -This works by representing the `j`th element of `X` as `xⱼ = aⱼ * 2 ^ bⱼ`, +This works by representing the `j`th element of `X` as ``x_j = a_j 2^b_j``, allowing us to write ```math \\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j From 76533e12b126158840167ebcc81bd3ef718fb6ae Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 11:27:57 -0700 Subject: [PATCH 07/38] fall-back method --- src/sumlog.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 21857a0f..5a512a0f 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -14,7 +14,7 @@ allowing us to write ``` Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` evaluation. """ -function sumlog(x::AbstractArray{T}) where {T} +function sumlog(x::AbstractArray{T}) where {T<:AbstractFloat} sig = one(T) ex = zero(exponent(one(T))) bound = floatmax(T) / 2 @@ -30,4 +30,6 @@ function sumlog(x::AbstractArray{T}) where {T} end end log(sig) + logtwo * ex -end \ No newline at end of file +end + +sumlog(x) = sum(log, x) \ No newline at end of file From 5747205dc38224bc16f971f43d7ad39a305392a3 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 11:28:03 -0700 Subject: [PATCH 08/38] more tests --- test/sumlog.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/sumlog.jl b/test/sumlog.jl index 724d4ed0..ae637387 100644 --- a/test/sumlog.jl +++ b/test/sumlog.jl @@ -1,5 +1,7 @@ @testset "sumlog" begin - for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)] - @test (@inferred sumlog(x)) ≈ sum(log, x) + for T in [Int, Float16, Float32, Float64, BigFloat] + for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)] + @test (@inferred sumlog(x)) ≈ sum(log, x) + end end end \ No newline at end of file From 0f5a9275f5b3639064b0a520b1d29779a10be035 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 11:30:43 -0700 Subject: [PATCH 09/38] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 843f313f..d8eb2e0d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogExpFunctions" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" authors = ["StatsFun.jl contributors, Tamas K. Papp "] -version = "0.3.14" +version = "0.3.15" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 977723d4c0aa66971e47bc14d8ce7d12fbf6c7e1 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 12:03:55 -0700 Subject: [PATCH 10/38] cast to floating point when possible --- src/sumlog.jl | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 5a512a0f..e33dcef0 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -14,13 +14,18 @@ allowing us to write ``` Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` evaluation. """ -function sumlog(x::AbstractArray{T}) where {T<:AbstractFloat} +function sumlog(x::AbstractArray{<:Real}) + T = float(eltype(x)) + + # `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` + T isa AbstractFloat || return sum(log, x) sig = one(T) - ex = zero(exponent(one(T))) + ex = zero(exponent(sig)) bound = floatmax(T) / 2 - for xj in x - sig *= significand(xj) - ex += exponent(xj) + for xj in x + float_xj = float(xj) + sig *= significand(float_xj) + ex += exponent(float_xj) # Significands are in the range [1,2), so multiplication will eventually overflow if sig > bound From 4d488cd454a85b2d96a99acbdfad1453f157ee33 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 12:09:30 -0700 Subject: [PATCH 11/38] docstring fixes --- src/sumlog.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index e33dcef0..b991df2f 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -5,14 +5,16 @@ $(SIGNATURES) Compute `sum(log.(X))` with a single `log` evaluation. -This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in particular as `X` increases. +This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in +particular as the size of `X` increases. -This works by representing the `j`th element of `X` as ``x_j = a_j 2^b_j``, +This works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, allowing us to write ```math \\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j ``` -Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` evaluation. +Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` +evaluation. """ function sumlog(x::AbstractArray{<:Real}) T = float(eltype(x)) From afa5d945e11bf85f83c60e4877b3722db96dfade Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 12:20:39 -0700 Subject: [PATCH 12/38] performance fix --- src/sumlog.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index b991df2f..d0d71665 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -18,9 +18,10 @@ evaluation. """ function sumlog(x::AbstractArray{<:Real}) T = float(eltype(x)) + _sumlog(T, x) +end - # `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` - T isa AbstractFloat || return sum(log, x) +function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} sig = one(T) ex = zero(exponent(sig)) bound = floatmax(T) / 2 @@ -39,4 +40,7 @@ function sumlog(x::AbstractArray{<:Real}) log(sig) + logtwo * ex end +# `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` +_sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T} = sum(log, x) + sumlog(x) = sum(log, x) \ No newline at end of file From e400483380fad10f4d25fe60e659bacc23e56429 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Mon, 2 May 2022 12:22:35 -0700 Subject: [PATCH 13/38] inline _sumlog --- src/sumlog.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index d0d71665..2f6fdcd4 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -21,7 +21,7 @@ function sumlog(x::AbstractArray{<:Real}) _sumlog(T, x) end -function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} +@inline function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} sig = one(T) ex = zero(exponent(sig)) bound = floatmax(T) / 2 @@ -41,6 +41,6 @@ function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} end # `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` -_sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T} = sum(log, x) +@inline _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T} = sum(log, x) sumlog(x) = sum(log, x) \ No newline at end of file From cc1aaacac3247be720412e9b9d9995d4dd918a89 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 07:57:30 -0700 Subject: [PATCH 14/38] qualify IrrationalConstants.logtwo --- src/sumlog.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 2f6fdcd4..c9aaa570 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -1,5 +1,3 @@ -using IrrationalConstants: logtwo - """ $(SIGNATURES) @@ -37,7 +35,7 @@ end ex += b end end - log(sig) + logtwo * ex + log(sig) + IrrationalConstants.logtwo * ex end # `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` From 07809b76b8bfe5df0b557003505c0959c2d4640b Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 07:59:23 -0700 Subject: [PATCH 15/38] update comment --- src/sumlog.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index c9aaa570..3039ae1c 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -38,7 +38,7 @@ end log(sig) + IrrationalConstants.logtwo * ex end -# `T` might be a `Symbolics.Num`, which is not an `AbstractFloat` +# `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics @inline _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T} = sum(log, x) sumlog(x) = sum(log, x) \ No newline at end of file From 16ee153fa1dff99bd073f577dad78a91c40cd336 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 08:06:05 -0700 Subject: [PATCH 16/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 3039ae1c..4ec6d211 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -20,21 +20,23 @@ function sumlog(x::AbstractArray{<:Real}) end @inline function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} - sig = one(T) - ex = zero(exponent(sig)) - bound = floatmax(T) / 2 - for xj in x + sig, ex = mapreduce(_sumlog_op; init=(one(T), zero(exponent(one(T))))) do xj float_xj = float(xj) - sig *= significand(float_xj) - ex += exponent(float_xj) + return significand(float_xj), exponent(float_xj) + end + return log(sig) + IrrationalConstants.logtwo * ex +end - # Significands are in the range [1,2), so multiplication will eventually overflow - if sig > bound - (a, b) = (significand(sig), exponent(sig)) - sig = a - ex += b - end +function _sumlog_op((sig1, ex1), (sig2, ex2)) + sig = sig1 * sig2 + ex = ex1 + ex2 + # Significands are in the range [1,2), so multiplication will eventually overflow + if sig > floatmax(typeof(sig)) / 2 + ex += exponent(sig) + sig = significand(sig) end + return sig, ex +end log(sig) + IrrationalConstants.logtwo * ex end From 1af518b2f13768dd172057ccae220e8a85bd2603 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 08:14:36 -0700 Subject: [PATCH 17/38] bugfix --- src/sumlog.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 4ec6d211..3131b420 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -20,14 +20,14 @@ function sumlog(x::AbstractArray{<:Real}) end @inline function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} - sig, ex = mapreduce(_sumlog_op; init=(one(T), zero(exponent(one(T))))) do xj + sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj float_xj = float(xj) - return significand(float_xj), exponent(float_xj) + significand(float_xj), exponent(float_xj) end return log(sig) + IrrationalConstants.logtwo * ex end -function _sumlog_op((sig1, ex1), (sig2, ex2)) +@inline function _sumlog_op((sig1, ex1), (sig2, ex2)) sig = sig1 * sig2 ex = ex1 + ex2 # Significands are in the range [1,2), so multiplication will eventually overflow @@ -37,10 +37,8 @@ function _sumlog_op((sig1, ex1), (sig2, ex2)) end return sig, ex end - log(sig) + IrrationalConstants.logtwo * ex -end # `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics -@inline _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T} = sum(log, x) +@inline _sumlog(::Type{T}, x) where {T} = sum(log, x) sumlog(x) = sum(log, x) \ No newline at end of file From 0eaf8d2a2d7005bb2480ded39a0bd882438b0b9a Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 08:29:02 -0700 Subject: [PATCH 18/38] Make it work (and be fast) for Tuples and NamedTuples --- src/sumlog.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 3131b420..60d00dd3 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -14,12 +14,12 @@ allowing us to write Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` evaluation. """ -function sumlog(x::AbstractArray{<:Real}) +function sumlog(x) T = float(eltype(x)) - _sumlog(T, x) + _sumlog(T, values(x)) end -@inline function _sumlog(::Type{T}, x::AbstractArray{<:Real}) where {T<:AbstractFloat} +@inline function _sumlog(::Type{T}, x) where {T<:AbstractFloat} sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj float_xj = float(xj) significand(float_xj), exponent(float_xj) @@ -40,5 +40,3 @@ end # `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics @inline _sumlog(::Type{T}, x) where {T} = sum(log, x) - -sumlog(x) = sum(log, x) \ No newline at end of file From 1f478d0898e858e29969d158aff5574cb0bd8c32 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 15:16:54 -0700 Subject: [PATCH 19/38] add sumlog to docs --- docs/src/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/index.md b/docs/src/index.md index eb31ecfe..d3e1583a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -26,6 +26,7 @@ logaddexp logsubexp logsumexp logsumexp! +sumlog softmax! softmax ``` From 0807f7a46455fa15daf32ce80c2573a982138ee3 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 16:19:52 -0700 Subject: [PATCH 20/38] tests --- test/sumlog.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/sumlog.jl b/test/sumlog.jl index ae637387..bdd58525 100644 --- a/test/sumlog.jl +++ b/test/sumlog.jl @@ -2,6 +2,23 @@ for T in [Int, Float16, Float32, Float64, BigFloat] for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)] @test (@inferred sumlog(x)) ≈ sum(log, x) + + y = view(x, 1:100) + @test (@inferred sumlog(y)) ≈ sum(log, y) + + tup = tuple(y...) + @test (@inferred sumlog(tup)) ≈ sum(log, tup) + + gen = (sqrt(a) for a in y) + @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen) + + nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup) + @test (@inferred sumlog(y)) ≈ sum(log, y) + + i = Random.shuffle(x) + z = x .+ i * im + @test (@inferred sumlog(z)) ≈ sum(log, z) end + end end \ No newline at end of file From 2a0004d918c338ec7b64e98efc6356314cbcd26f Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 3 May 2022 16:21:38 -0700 Subject: [PATCH 21/38] comment that `eltype` of a `Base.Generator` returns `Any` --- test/sumlog.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/sumlog.jl b/test/sumlog.jl index bdd58525..8f0be09f 100644 --- a/test/sumlog.jl +++ b/test/sumlog.jl @@ -10,6 +10,7 @@ @test (@inferred sumlog(tup)) ≈ sum(log, tup) gen = (sqrt(a) for a in y) + # `eltype` of a `Base.Generator` returns `Any` @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen) nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup) From e5809d1b1b29045bb26c36c291acad133477e1a8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 7 May 2022 14:37:28 -0400 Subject: [PATCH 22/38] saturday --- src/sumlog.jl | 88 +++++++++++++++++++++++++++++++++++++++----------- test/sumlog.jl | 62 ++++++++++++++++++++++++++++------- 2 files changed, 120 insertions(+), 30 deletions(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 60d00dd3..7dc634ca 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -1,42 +1,94 @@ """ -$(SIGNATURES) + sumlog(X::AbstractArray{T}; dims) -Compute `sum(log.(X))` with a single `log` evaluation. +Compute `sum(log.(X))` with a single `log` evaluation, +provided `float(T) <: AbstractFloat`. -This is faster than computing `sum(log.(X))` or even `sum(log, X)`, in -particular as the size of `X` increases. - -This works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, +This is faster than computing `sum(log, X)`, especially for large `X`. +It works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, allowing us to write ```math \\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j ``` -Since ``\\log{2}`` is constant, `sumlog` only requires a single `log` -evaluation. """ -function sumlog(x) - T = float(eltype(x)) - _sumlog(T, values(x)) +sumlog(x::AbstractArray{T}; dims=:) where T = _sumlog(float(T), dims, x) + +function _sumlog(::Type{T}, ::Colon, x) where {T<:AbstractFloat} + sig, ex = mapreduce(_sumlog_op, x; init=(one(T), 0)) do xj + xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) + float_xj = float(xj) + significand(float_xj), _exponent(float_xj) + end + return log(sig) + IrrationalConstants.logtwo * T(ex) end -@inline function _sumlog(::Type{T}, x) where {T<:AbstractFloat} - sig, ex = mapreduce(_sumlog_op, x; init=(one(T), zero(exponent(one(T))))) do xj +function _sumlog(::Type{T}, dims, x) where {T<:AbstractFloat} + sig_ex = mapreduce(_sumlog_op, x; dims=dims, init=(one(T), 0)) do xj + xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) float_xj = float(xj) - significand(float_xj), exponent(float_xj) + significand(float_xj), _exponent(float_xj) + end + map(sig_ex) do (sig, ex) + log(sig) + IrrationalConstants.logtwo * T(ex) end - return log(sig) + IrrationalConstants.logtwo * ex end +# Fallback: `float(T)` is not always `<: AbstractFloat`, e.g. complex, dual numbers or symbolics +_sumlog(::Type, dims, x) = sum(log, x; dims) + @inline function _sumlog_op((sig1, ex1), (sig2, ex2)) sig = sig1 * sig2 + # sig = ifelse(sig2<0, sig2, sig1 * sig2) ex = ex1 + ex2 # Significands are in the range [1,2), so multiplication will eventually overflow if sig > floatmax(typeof(sig)) / 2 - ex += exponent(sig) + ex += _exponent(sig) sig = significand(sig) end return sig, ex end -# `float(T)` is not always `isa AbstractFloat`, e.g. dual numbers or symbolics -@inline _sumlog(::Type{T}, x) where {T} = sum(log, x) +# The exported `exponent(x)` checks for `NaN` etc, this function doesn't, which is fine as `sig` keeps track. +_exponent(x::Base.IEEEFloat) = Base.Math._exponent_finite_nonzero(x) +Base.@assume_effects :nothrow _exponent(x::AbstractFloat) = Int(exponent(x)) # e.g. for BigFloat + +""" + sumlog(x) + sumlog(f, x, ys...) + +For any iterator which produces `AbstractFloat` elements, +this can use `sumlog`'s fast reduction strategy. + +Signature with `f` is equivalent to `sum(log, map(f, x, ys...))` +or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations. + +Does not accept a `dims` keyword. +""" +sumlog(f, x) = sumlog(Iterators.map(f, x)) +sumlog(f, x, ys...) = sumlog(f(xy...) for xy in zip(x, ys...)) + +# Iterator version, uses the same `_sumlog_op`, should be the same speed. +function sumlog(x) + iter = iterate(x) + if isnothing(iter) + T = Base._return_type(first, Tuple{typeof(x)}) + return T <: Number ? zero(float(T)) : 0.0 + end + x1 = float(iter[1]) + x1 isa AbstractFloat || return sum(log, x) + x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1) + sig, ex = significand(x1), _exponent(x1) + nonfloat = zero(x1) + iter = iterate(x, iter[2]) + while iter !== nothing + xj = float(iter[1]) + if xj isa AbstractFloat + xj < 0 && Base.Math.throw_complex_domainerror(:log, xj) + sig, ex = _sumlog_op((sig, ex), (significand(xj), _exponent(xj))) + else + nonfloat += log(xj) + end + iter = iterate(x, iter[2]) + end + return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat +end diff --git a/test/sumlog.jl b/test/sumlog.jl index 8f0be09f..bb321e15 100644 --- a/test/sumlog.jl +++ b/test/sumlog.jl @@ -1,25 +1,63 @@ @testset "sumlog" begin - for T in [Int, Float16, Float32, Float64, BigFloat] - for x in [10 .* rand(1000), repeat([nextfloat(1.0)], 1000), repeat([prevfloat(2.0)], 1000)] + @testset for T in [Float16, Float32, Float64, BigFloat] + for x in ( + T[1,2,3], + 10 .* rand(T, 1000), + fill(nextfloat(T(1.0)), 1000), + fill(prevfloat(T(2.0)), 1000), + ) + @test sumlog(x) isa T + @test (@inferred sumlog(x)) ≈ sum(log, x) - y = view(x, 1:100) - @test (@inferred sumlog(y)) ≈ sum(log, y) + y = @view x[1:min(end, 100)] + @test (@inferred sumlog(y')) ≈ sum(log, y) tup = tuple(y...) @test (@inferred sumlog(tup)) ≈ sum(log, tup) + # + # gen = (sqrt(a) for a in y) + # # `eltype` of a `Base.Generator` returns `Any` + # @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen) - gen = (sqrt(a) for a in y) - # `eltype` of a `Base.Generator` returns `Any` - @test_broken (@inferred sumlog(gen)) ≈ sum(log, gen) + # nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup) + # @test (@inferred sumlog(y)) ≈ sum(log, y) - nt = NamedTuple{tuple(Symbol.(1:100)...)}(tup) - @test (@inferred sumlog(y)) ≈ sum(log, y) - - i = Random.shuffle(x) - z = x .+ i * im + z = x .+ im .* Random.shuffle(x) @test (@inferred sumlog(z)) ≈ sum(log, z) end + # With dims + m = 1 .+ rand(T, 10, 10) + sumlog(m; dims=1) ≈ sum(log, m; dims=1) + sumlog(m; dims=2) ≈ sum(log, m; dims=2) + + # Iterator + @test sumlog(x^2 for x in m) ≈ sumlog(abs2, m) ≈ sumlog(*, m, m) ≈ sum(log.(m.^2)) + @test sumlog(x for x in Any[1, 2, 3+im, 4]) ≈ sum(log, Any[1, 2, 3+im, 4]) + + # NaN, Inf + if T != BigFloat # exponent fails here + @test isnan(sumlog(T[1, 2, NaN])) + @test isinf(sumlog(T[1, 2, Inf])) + @test sumlog(T[1, 2, 0.0]) == -Inf + @test sumlog(T[1, 2, -0.0]) == -Inf + end + + # Empty + @test sumlog(T[]) isa T + @test eltype(sumlog(T[]; dims=1)) == T + @test sumlog(x for x in T[]) isa T + + # Negative + @test_throws DomainError sumlog(T[1, -2, 3]) # easy + @test_throws DomainError sumlog(T[1, -2, -3]) # harder + + end + @testset "Int" begin + @test sumlog([1,2,3]) isa Float64 + @test sumlog([1,2,3]) ≈ sum(log, [1,2,3]) + @test sumlog([1 2; 3 4]; dims=1) ≈ sum(log, [1 2; 3 4]; dims=1) + @test sumlog(Int(x) for x in Float64[1,2,3]) ≈ sum(log, [1,2,3]) end end \ No newline at end of file From 6fe8bb158d5fc2f006b350eb7c85687c6808c2a7 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Sat, 7 May 2022 13:18:46 -0700 Subject: [PATCH 23/38] Update src/sumlog.jl Co-authored-by: David Widmann --- src/sumlog.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sumlog.jl b/src/sumlog.jl index 7dc634ca..56d28a42 100644 --- a/src/sumlog.jl +++ b/src/sumlog.jl @@ -5,7 +5,7 @@ Compute `sum(log.(X))` with a single `log` evaluation, provided `float(T) <: AbstractFloat`. This is faster than computing `sum(log, X)`, especially for large `X`. -It works by representing the `j`th element of `X` as ``x_j = a_j 2^{b_j}``, +It works by representing the `j`th element of `x` as ``x_j = a_j 2^{b_j}`, allowing us to write ```math \\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j From 0def97dcc2fcdfc38d1d9022271d58d20cee51ab Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 05:23:10 -0700 Subject: [PATCH 24/38] change to `logprod` --- src/LogExpFunctions.jl | 3 +- src/logprod.jl | 78 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 src/logprod.jl diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index 8512e0ab..57559541 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -18,6 +18,7 @@ include("logsumexp.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") -include("sumlog.jl") +# include("sumlog.jl") +include("logprod.jl") end # module diff --git a/src/logprod.jl b/src/logprod.jl new file mode 100644 index 00000000..b3128fcd --- /dev/null +++ b/src/logprod.jl @@ -0,0 +1,78 @@ +""" + logprod(X::AbstractArray{T}; dims) + +Compute `sum(log.(X))` with a single `log` evaluation, +provided `float(T) <: AbstractFloat`. + +This is faster than computing `sum(log, X)`, especially for large `X`. +It works by representing the `j`th element of `x` as ``x_j = a_j 2^{b_j}`, +allowing us to write +```math +\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j +``` +""" +logprod(x) = first(logabsprod(x)) + +export logabsprod + +function logabsprod(x::AbstractArray{T}) where {T} + sig, ex = mapreduce(_logabsprod_op, x; init=frexp(one(T))) do xj + float_xj = float(xj) + frexp(float_xj) + end + sgn = signbit(sig) ? one(T) : -one(T) + return (log(abs(sig)) + IrrationalConstants.logtwo * T(ex), sgn) +end + +@inline function _logabsprod_op((sig1, ex1), (sig2, ex2)) + sig = sig1 * sig2 + # sig = ifelse(sig2<0, sig2, sig1 * sig2) + ex = ex1 + ex2 + # Significands are in the range [1,2), so multiplication will eventually overflow + if sig > floatmax(typeof(sig)) / 2 + (new_sig, Δex) = frexp(sig) + ex += Δex + sig = new_sig + end + return sig, ex +end + +""" + logprod(x) + logprod(f, x, ys...) + +For any iterator which produces `AbstractFloat` elements, +this can use `logprod`'s fast reduction strategy. + +Signature with `f` is equivalent to `sum(log, map(f, x, ys...))` +or `mapreduce(log∘f, +, x, ys...)`, without intermediate allocations. + +Does not accept a `dims` keyword. +""" +logprod(f, x) = logprod(Iterators.map(f, x)) +logprod(f, x, ys...) = logprod(f(xy...) for xy in zip(x, ys...)) + +# Iterator version, uses the same `_logprod_op`, should be the same speed. +function logprod(x) + iter = iterate(x) + if isnothing(iter) + T = Base._return_type(first, Tuple{typeof(x)}) + return T <: Number ? zero(float(T)) : 0.0 + end + x1 = float(iter[1]) + x1 isa AbstractFloat || return sum(log, x) + x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1) + sig, ex = significand(x1), _exponent(x1) + nonfloat = zero(x1) + iter = iterate(x, iter[2]) + while iter !== nothing + xj = float(iter[1]) + if xj isa AbstractFloat + sig, ex = _logprod_op((sig, ex), (significand(xj), _exponent(xj))) + else + nonfloat += log(xj) + end + iter = iterate(x, iter[2]) + end + return log(sig) + IrrationalConstants.logtwo * oftype(sig, ex) + nonfloat +end From 3ad95b2b130c5b63463116b5965b77713cb1119e Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 05:50:07 -0700 Subject: [PATCH 25/38] fix sign bit --- src/logprod.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logprod.jl b/src/logprod.jl index b3128fcd..df821b9f 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -20,7 +20,7 @@ function logabsprod(x::AbstractArray{T}) where {T} float_xj = float(xj) frexp(float_xj) end - sgn = signbit(sig) ? one(T) : -one(T) + sgn = signbit(sig) ? -one(T) : one(T) return (log(abs(sig)) + IrrationalConstants.logtwo * T(ex), sgn) end From a0a9348d3a5005158966b832963f5a6bdcd1401c Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 07:08:08 -0700 Subject: [PATCH 26/38] Update docs/src/index.md Co-authored-by: David Widmann --- docs/src/index.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/src/index.md b/docs/src/index.md index d3e1583a..90522dec 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -26,7 +26,8 @@ logaddexp logsubexp logsumexp logsumexp! -sumlog +logprod +logabsprod softmax! softmax ``` From fa667ecc8bc455c21bb5013c0d52354eb6350d90 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 07:08:26 -0700 Subject: [PATCH 27/38] Update src/LogExpFunctions.jl Co-authored-by: David Widmann --- src/LogExpFunctions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index 57559541..2d7a7385 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -11,7 +11,7 @@ import LinearAlgebra export xlogx, xlogy, xlog1py, xexpx, xexpy, logistic, logit, log1psq, log1pexp, log1mexp, log2mexp, logexpm1, softplus, invsoftplus, log1pmx, logmxp1, logaddexp, logsubexp, logsumexp, logsumexp!, softmax, - softmax!, logcosh, sumlog + softmax!, logcosh, logprod, logabsprod include("basicfuns.jl") include("logsumexp.jl") From 989a11184e73b2701db3c24d4cdafe0e1fcba53c Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 07:09:15 -0700 Subject: [PATCH 28/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/logprod.jl b/src/logprod.jl index df821b9f..e20eeb24 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -1,8 +1,7 @@ """ logprod(X::AbstractArray{T}; dims) -Compute `sum(log.(X))` with a single `log` evaluation, -provided `float(T) <: AbstractFloat`. +Compute `log(prod(x))` efficiently. This is faster than computing `sum(log, X)`, especially for large `X`. It works by representing the `j`th element of `x` as ``x_j = a_j 2^{b_j}`, From a54a024b8385631595d26dd0b881c3f41fa40051 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 07:09:28 -0700 Subject: [PATCH 29/38] Update src/LogExpFunctions.jl Co-authored-by: David Widmann --- src/LogExpFunctions.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/LogExpFunctions.jl b/src/LogExpFunctions.jl index 2d7a7385..6fca657e 100644 --- a/src/LogExpFunctions.jl +++ b/src/LogExpFunctions.jl @@ -18,7 +18,6 @@ include("logsumexp.jl") include("chainrules.jl") include("inverse.jl") include("with_logabsdet_jacobian.jl") -# include("sumlog.jl") include("logprod.jl") end # module From dc484334a75b8b34156c4826ef04cdc70524b809 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 07:25:57 -0700 Subject: [PATCH 30/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/logprod.jl b/src/logprod.jl index e20eeb24..7b970657 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -29,9 +29,8 @@ end ex = ex1 + ex2 # Significands are in the range [1,2), so multiplication will eventually overflow if sig > floatmax(typeof(sig)) / 2 - (new_sig, Δex) = frexp(sig) + sig, Δex = frexp(sig) ex += Δex - sig = new_sig end return sig, ex end From 39ca989d0e19ed23f1ac803b1c16ba3189376d2d Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:26:51 -0700 Subject: [PATCH 31/38] cleaning up --- src/logprod.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/logprod.jl b/src/logprod.jl index df821b9f..3691e724 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -26,13 +26,14 @@ end @inline function _logabsprod_op((sig1, ex1), (sig2, ex2)) sig = sig1 * sig2 - # sig = ifelse(sig2<0, sig2, sig1 * sig2) ex = ex1 + ex2 - # Significands are in the range [1,2), so multiplication will eventually overflow - if sig > floatmax(typeof(sig)) / 2 - (new_sig, Δex) = frexp(sig) + + # The significand from `frexp` has magnitude in the range [0.5, 1), + # so multiplication will eventually underflow + may_underflow(sig::T) where {T} = sig < sqrt(floatmin(T)) + if may_underflow(sig) + (sig, Δex) = frexp(sig) ex += Δex - sig = new_sig end return sig, ex end From 55d125ef26b6b6a7a65bb31f0c918d3444a0d8c6 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:32:20 -0700 Subject: [PATCH 32/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logprod.jl b/src/logprod.jl index d1cc6671..f2c6b827 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -7,7 +7,7 @@ This is faster than computing `sum(log, X)`, especially for large `X`. It works by representing the `j`th element of `x` as ``x_j = a_j 2^{b_j}`, allowing us to write ```math -\\sum_j \\log{x_j} = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j +\\log \\prod_k x_j = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j. ``` """ logprod(x) = first(logabsprod(x)) From bef4728a6ebe86bdaf07bba56adf5aac7b1ac019 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:32:45 -0700 Subject: [PATCH 33/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/logprod.jl b/src/logprod.jl index f2c6b827..b29185a3 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -10,7 +10,11 @@ allowing us to write \\log \\prod_k x_j = \\log(\\prod_j a_j) + \\log{2} \\sum_j b_j. ``` """ -logprod(x) = first(logabsprod(x)) +function logprod(x) + y, s = logabsprod(x) + y isa Real && s < 0 && throw(DomainError(x, "`prod(x)` must be non-negative")) + return y +end export logabsprod From 9572e4851b0af8cceaa6a023c17c26320a708c79 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:34:51 -0700 Subject: [PATCH 34/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/logprod.jl b/src/logprod.jl index b29185a3..5ed1b3f8 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -23,8 +23,7 @@ function logabsprod(x::AbstractArray{T}) where {T} float_xj = float(xj) frexp(float_xj) end - sgn = signbit(sig) ? -one(T) : one(T) - return (log(abs(sig)) + IrrationalConstants.logtwo * T(ex), sgn) + return (log(abs(sig)) + IrrationalConstants.logtwo * T(ex), sign(sig)) end @inline function _logabsprod_op((sig1, ex1), (sig2, ex2)) From 38488485656160ecf22c5b523ce496d197d87ffa Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:35:26 -0700 Subject: [PATCH 35/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/logprod.jl b/src/logprod.jl index 5ed1b3f8..c4a6662d 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -55,8 +55,8 @@ Does not accept a `dims` keyword. logprod(f, x) = logprod(Iterators.map(f, x)) logprod(f, x, ys...) = logprod(f(xy...) for xy in zip(x, ys...)) -# Iterator version, uses the same `_logprod_op`, should be the same speed. -function logprod(x) +# Iterator version, uses the same `_logabsprod_op`, should be the same speed. +function logabsprod(x) iter = iterate(x) if isnothing(iter) T = Base._return_type(first, Tuple{typeof(x)}) From c4c3e8937fbcbb6207d54fd73fcc1905ab175239 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:37:24 -0700 Subject: [PATCH 36/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/logprod.jl b/src/logprod.jl index c4a6662d..8ed56014 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -59,8 +59,8 @@ logprod(f, x, ys...) = logprod(f(xy...) for xy in zip(x, ys...)) function logabsprod(x) iter = iterate(x) if isnothing(iter) - T = Base._return_type(first, Tuple{typeof(x)}) - return T <: Number ? zero(float(T)) : 0.0 + y = prod(x) + return log(abs(y)), sign(y) end x1 = float(iter[1]) x1 isa AbstractFloat || return sum(log, x) From e0f410e814d9ee248afe813604ffcbeef760280b Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Tue, 10 May 2022 10:37:55 -0700 Subject: [PATCH 37/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/logprod.jl b/src/logprod.jl index 8ed56014..800ad66c 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -64,7 +64,6 @@ function logabsprod(x) end x1 = float(iter[1]) x1 isa AbstractFloat || return sum(log, x) - x1 < 0 && Base.Math.throw_complex_domainerror(:log, x1) sig, ex = significand(x1), _exponent(x1) nonfloat = zero(x1) iter = iterate(x, iter[2]) From 23b5bf10344382888a7d0bb57c8045f8f546c1a2 Mon Sep 17 00:00:00 2001 From: Chad Scherrer Date: Wed, 11 May 2022 07:16:50 -0700 Subject: [PATCH 38/38] Update src/logprod.jl Co-authored-by: David Widmann --- src/logprod.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/logprod.jl b/src/logprod.jl index 800ad66c..10750293 100644 --- a/src/logprod.jl +++ b/src/logprod.jl @@ -63,7 +63,10 @@ function logabsprod(x) return log(abs(y)), sign(y) end x1 = float(iter[1]) - x1 isa AbstractFloat || return sum(log, x) + if !(x1 isa AbstractFloat) + y = prod(x) + return log(abs(y)), sign(y) + end sig, ex = significand(x1), _exponent(x1) nonfloat = zero(x1) iter = iterate(x, iter[2])