Skip to content

Commit

Permalink
some updates for mapreducedim refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
adienes committed Dec 21, 2024
1 parent 793733e commit 465e3a0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion ext/SparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end
# This is the function that does the reduction underlying var/std
function centralize_sumabs2!(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray) where {S,Tv,Ti}
require_one_based_indexing(R, A, means)
lsiz = Base.check_reducedims(R,A)
lsiz = Base._linear_reduction_length(A, axes(R))
for i in 1:max(ndims(R), ndims(means))
if axes(means, i) != axes(R, i)
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
Expand Down
14 changes: 10 additions & 4 deletions src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ centralize_sumabs2(A::AbstractArray, m, ifirst::Int, ilast::Int) =

function centralize_sumabs2!(R::AbstractArray{S}, A::AbstractArray, means::AbstractArray) where S
# following the implementation of _mapreducedim! at base/reducedim.jl
lsiz = Base.check_reducedims(R,A)
lsiz = Base._linear_reduction_length(A, axes(R))
for i in 1:max(ndims(R), ndims(means))
if axes(means, i) != axes(R, i)
throw(DimensionMismatch("dimension $i of `mean` should have indices $(axes(R, i)), but got $(axes(means, i))"))
Expand Down Expand Up @@ -339,9 +339,15 @@ over dimensions. In that case, `mean` must be an array with the same shape as
"""
varm(A::AbstractArray, m::AbstractArray; corrected::Bool=true, dims=:) = _varm(A, m, corrected, dims)

_varm(A::AbstractArray{T}, m, corrected::Bool, region) where {T} =
varm!(Base.reducedim_init(t -> abs2(t)/2, +, A, region), A, m; corrected=corrected)

function _varm(A::AbstractArray{T}, m, corrected::Bool, region) where {T}
outaxes = Base.reduced_indices(A, region)
nmT = Base.nonmissingtype(T)
base = typeof(abs2(zero(nmT)) / 2)
elT = (T == nmT) ? base : Union{Missing, base}
R = Base.mapreduce_similar(A, elT, outaxes)
fill!(R, zero(base))
return varm!(R, A, m; corrected=corrected)
end
varm(A::AbstractArray, m; corrected::Bool=true) = _varm(A, m, corrected, :)

function _varm(A::AbstractArray{T}, m, corrected::Bool, ::Colon) where T
Expand Down
15 changes: 9 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ end
end
end

is_missing_or_approx_equals(X, Y) =
all(splat((x,y) -> isequal(x, y) || isapprox(x, y)), zip(X, Y))

function safe_cov(x, y, zm::Bool, cr::Bool)
n = length(x)
if !zm
Expand Down Expand Up @@ -522,9 +525,9 @@ Y = [6.0 2.0;
@testset "cov with missing" begin
@test cov([missing]) === cov([1, missing]) === missing
@test cov([1, missing], [2, 3]) === cov([1, 3], [2, missing]) === missing
@test_throws Exception cov([1 missing; 2 3])
@test_throws Exception cov([1 missing; 2 3], [1, 2])
@test_throws Exception cov([1, 2], [1 missing; 2 3])
@test isequal(cov([1 missing; 2 3]), [0.5 missing; missing missing])
@test isequal(cov([1 missing; 2 3], [1, 2]), [0.5; missing;;])
@test isequal(cov([1, 2], [1 missing; 2 3]), [0.5 missing])
@test isequal(cov([1 2; 2 3], [1, missing]), [missing missing]')
@test isequal(cov([1, missing], [1 2; 2 3]), [missing missing])
end
Expand Down Expand Up @@ -633,9 +636,9 @@ end
@test cor([missing]) === missing
@test cor([1, missing]) == 1
@test cor([1, missing], [2, 3]) === cor([1, 3], [2, missing]) === missing
@test_throws Exception cor([1 missing; 2 3])
@test_throws Exception cor([1 missing; 2 3], [1, 2])
@test_throws Exception cor([1, 2], [1 missing; 2 3])
@test is_missing_or_approx_equals(cor([1 missing; 2 3]), [1 missing; missing 1])
@test is_missing_or_approx_equals(cor([1 missing; 2 3], [1, 2]), [1; missing;;])
@test is_missing_or_approx_equals(cor([1, 2], [1 missing; 2 3]), transpose([1, missing]))
@test isequal(cor([1 2; 2 3], [1, missing]), [missing missing]')
@test isequal(cor([1, missing], [1 2; 2 3]), [missing missing])
end
Expand Down

0 comments on commit 465e3a0

Please sign in to comment.