diff --git a/src/Statistics.jl b/src/Statistics.jl index 560b227d..cdab5990 100644 --- a/src/Statistics.jl +++ b/src/Statistics.jl @@ -44,6 +44,8 @@ if !isdefined(Base, :mean) """ mean(itr) = mean(identity, itr) + _mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y) + """ mean(f, itr) @@ -178,7 +180,28 @@ if !isdefined(Base, :mean) """ mean(A::AbstractArray; dims=:) = _mean(identity, A, dims) - _mean_promote(x::T, y::S) where {T,S} = convert(promote_type(T, S), y) + promote_add_type(x::S, y::T) where {S, T} = + promote_type(typeof(zero(S)/1), typeof(zero(T)/1)) + function promote_add(x::Any, y::Any) + T = promote_add_type(x, y) + return Base.add_sum(convert(T, x), convert(T, y)) + end + + function Base.reducedim_init(f, op::typeof(promote_add), A::AbstractArray, region) + Base._reducedim_init(f, op, zero, mean, A, region) + end + function Base._reducedim_init(f, op::typeof(promote_add), fv, fop, A, region) + T = Base._realtype(f, Base.promote_union(eltype(A))) + if T !== Any && applicable(zero, T) + x = f(zero(T))/1 # /1 added for mean + z = op(fv(x), fv(x)) + Tr = z isa T ? T : typeof(z) + else + z = fv(fop(f, A)) + Tr = typeof(z) + end + return Base.reducedim_initarray(A, region, z, Tr) + end # ::Dims is there to force specializing on Colon (as it is a Function) function _mean(f, A::AbstractArray, dims::Dims=:) where Dims @@ -188,8 +211,7 @@ if !isdefined(Base, :mean) else n = mapreduce(i -> size(A, i), *, unique(dims); init=1) end - x1 = f(first(A)) / 1 - result = sum(x -> _mean_promote(x1, f(x)), A, dims=dims) + result = mapreduce(f, promote_add, A, dims=dims) if dims === (:) return result / n else