Skip to content

Commit

Permalink
Call user function only once in mean
Browse files Browse the repository at this point in the history
Override the standard `mapreduce` machinery to promote accumulator type.
This avoids calling the function twice, which can be confusing.
  • Loading branch information
nalimilan committed Sep 9, 2023
1 parent 81a90af commit 6877ebf
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
33 changes: 30 additions & 3 deletions src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ if !isdefined(Base, :mean)
"""
function mean(itr)
    # The one-argument form is the `f = identity` case of `mean(f, itr)`.
    return mean(identity, itr)
end

# Convert `y` to the promotion of the types of `x` and `y`.
# Only the type of `x` takes part; its value is never read.
function _mean_promote(x::T, y::S) where {T,S}
    R = promote_type(T, S)
    return convert(R, y)
end

"""
mean(f, itr)
Expand Down Expand Up @@ -178,7 +180,33 @@ if !isdefined(Base, :mean)
"""
function mean(A::AbstractArray; dims=:)
    # Arrays go through `_mean`, with `identity` as the mapped function.
    return _mean(identity, A, dims)
end

# Return `y` converted to `promote_type(T, S)`, where `T` is the type of `x`
# and `S` the type of `y`. The value of `x` itself is not used.
function _mean_promote(x::T, y::S) where {T,S}
    return convert(promote_type(T, S), y)
end
# Field-less marker type. An instance is passed as the `init` value of the
# `mapreduce` call in `_mean`, and the `Base` methods below dispatch on it.
struct _InitType
end

# Accumulation step when the left operand is still the `_InitType` marker:
# the marker is dropped and the right operand is returned divided by 1.
function Base.add_sum(::_InitType, y)
    return y / 1
end

# Reduction over explicit `dims` when `init` is the `_InitType` marker: build the
# destination with `Base.reducedim_init`, then reduce into it in place.
function Base._mapreduce_dim(
    f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, dims
)
    dest = Base.reducedim_init(f, op, A, dims)
    return Base.mapreducedim!(f, op, dest, A)
end
# Reduction over the whole array (`dims = :`) when `init` is the `_InitType`
# marker: forward to `Base.mapfoldl_impl`, seeding it with a fresh marker.
function Base._mapreduce_dim(
    f, op, ::_InitType, A::Base.AbstractArrayOrBroadcasted, ::Colon
)
    return Base.mapfoldl_impl(f, op, _InitType(), A)
end
# Add `x` and `y` with `Base.add_sum` after converting both operands to their
# common promoted type.
function promote_add(x::T, y::S) where {T,S}
    R = promote_type(T, S)
    return Base.add_sum(convert(R, x), convert(R, y))
end

# Destination-array setup for `promote_add` reductions: forward to the
# `_reducedim_init` method, passing `zero` as the value initializer and `mean`
# as the fallback reduction.
Base.reducedim_init(f, op::typeof(promote_add), A::AbstractArray, region) =
    Base._reducedim_init(f, op, zero, mean, A, region)
# Build the initial accumulator array for a dimension-wise `promote_add` reduction.
#
# Arguments: `f` is the mapped function, `op` the reduction operator, `fv` turns a
# sample value into the starting value, `fop` is the fallback reduction used to get a
# sample when no zero of the element type is available, `A` is the input array and
# `region` the dimensions being reduced.
# Returns whatever `Base.reducedim_initarray` builds from the starting value `z` and
# the element type `Tr`.
function Base._reducedim_init(f, op::typeof(promote_add), fv, fop, A, region)
    T = Base._realtype(f, Base.promote_union(eltype(A)))
    if T !== Any && applicable(zero, T)
        # Apply `f` to a zero of the element type and divide the *result* by 1.
        # `f` then only sees the kind of value stored in `A`, and the sample has the
        # type that `add_sum(::_InitType, y) = y/1` gives the first accumulated value.
        x = f(zero(T)) / 1
        z = op(fv(x), fv(x))
        Tr = z isa T ? T : typeof(z)
    else
        # No usable zero for `T`: take the sample from the fallback reduction instead.
        z = fv(fop(f, A))
        Tr = typeof(z)
    end
    return Base.reducedim_initarray(A, region, z, Tr)
end

# ::Dims is there to force specializing on Colon (as it is a Function)
function _mean(f, A::AbstractArray, dims::Dims=:) where Dims
Expand All @@ -188,8 +216,7 @@ if !isdefined(Base, :mean)
else
n = mapreduce(i -> size(A, i), *, unique(dims); init=1)
end
x1 = f(first(A)) / 1
result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims)
result = mapreduce(f, promote_add, A, dims=dims, init=_InitType())
if dims === (:)
return result / n
else
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ end
float(typemax(Int)))
end
let x = rand(10000) # mean should use sum's accurate pairwise algorithm
@test mean(x) == sum(x) / length(x)
@test mean(x) == sum(x; init=0.0) / length(x)
end
@test mean(Number[1, 1.5, 2+3im]) === 1.5+1im # mixed-type array
@test mean(v for v in Number[1, 1.5, 2+3im]) === 1.5+1im
Expand Down

0 comments on commit 6877ebf

Please sign in to comment.