From 0c99a38a3e742263b216b51ded19b17afeadc116 Mon Sep 17 00:00:00 2001 From: Daniel Ingraham Date: Wed, 2 Oct 2024 14:37:54 -0400 Subject: [PATCH] Add tests for stacking issue #254 --- src/array_interface.jl | 23 +++++++++++++++ test/runtests.jl | 66 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/src/array_interface.jl b/src/array_interface.jl index c9060eb..80a83b8 100644 --- a/src/array_interface.jl +++ b/src/array_interface.jl @@ -79,6 +79,29 @@ function Base.permutedims(x::ComponentArray, dims) return ComponentArray(permutedims(getdata(x), dims), map(i->axs[i], dims)...) end +function stack_ca(iter; dims=:) + # Check if all the component axes are the same in iter. + if all(arr -> getaxes(arr) == getaxes(first(iter)), iter) + # Create a new array with `FlatAxis()` in the `dims` dimension. + return _stack_ca(dims, iter) + else + # Fall back to plain arrays. + return stack(getdata.(iter); dims=dims) + end +end + +function _stack_ca(dims::Colon, iter) + oldax = getaxes(first(iter)) + return ComponentArray(stack(getdata.(iter)), oldax..., FlatAxis()) +end + +function _stack_ca(dims::Integer, iter) + oldax = getaxes(first(iter)) + oldndims = length(oldax) + outax = ntuple(d -> d==dims ? FlatAxis() : oldax[d - (d>dims)], oldndims+1) + return ComponentArray(stack(getdata.(iter); dims=dims), outax...) +end + ## Indexing Base.IndexStyle(::Type{<:ComponentArray{T,N,<:A,<:Axes}}) where {T,N,A,Axes} = IndexStyle(A) diff --git a/test/runtests.jl b/test/runtests.jl index c659656..98cf785 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -745,6 +745,72 @@ end XY_d1 = stack([X, Y]; dims=1) @test all(XY_d1[1, :a, :b] .== XY[:a, :b, 1]) @test all(XY_d1[2, :a, :b] .== XY[:a, :b, 2]) + + # Issue #254, tuple of arrays: + x = ComponentVector(a=[1, 2]) + y = ComponentVector(b=[3, 4]) + Xstack1 = stack((x, y, x); dims=1) + Xstack1_noca = stack((getdata(x), getdata(y), getdata(x)); dims=1) + @test all(Xstack1 .== Xstack1_noca) + @test all(Xstack1[1, :a] .== Xstack1_noca[1, :]) + @test all(Xstack1[2, :a] .== Xstack1_noca[2, :]) + + # Issue #254, Array of tuples. + Xstack2 = stack(ComponentArray(a=(1,2,3), b=(4,5,6))) + Xstack2_noca = stack([(1,2,3), (4,5,6)]) + @test all(Xstack2 .== Xstack2_noca) + @test all(Xstack2[:, :a] .== Xstack2_noca[:, 1]) + @test all(Xstack2[:, :b] .== Xstack2_noca[:, 2]) + + Xstack2_d1 = stack(ComponentArray(a=(1,2,3), b=(4,5,6)); dims=1) + Xstack2_noca_d1 = stack([(1,2,3), (4,5,6)]; dims=1) + @test all(Xstack2_d1 .== Xstack2_noca_d1) + @test all(Xstack2_d1[:a, :] .== Xstack2_noca_d1[1, :]) + @test all(Xstack2_d1[:b, :] .== Xstack2_noca_d1[2, :]) + + # Issue #254, generator of arrays. + Xstack3 = stack(ComponentArray(z=[x,x]) for x in 1:4) + Xstack3_noca = stack([x, x] for x in 1:4) + # That should give me + # [1 2 3 4; + # 1 2 3 4] + @test all(Xstack3 .== Xstack3_noca) + @test all(Xstack3[:z, 1] .== Xstack3_noca[:, 1]) + @test all(Xstack3[:z, 2] .== Xstack3_noca[:, 2]) + @test all(Xstack3[:z, 3] .== Xstack3_noca[:, 3]) + @test all(Xstack3[:z, 4] .== Xstack3_noca[:, 4]) + + Xstack3_d1 = stack(ComponentArray(z=[x,x]) for x in 1:4; dims=1) + Xstack3_noca_d1 = stack([x, x] for x in 1:4; dims=1) + # That should give me + # [1 1; + # 2 2; + # 3 3; + # 4 4;] + @test all(Xstack3_d1 .== Xstack3_noca_d1) + @test all(Xstack3_d1[1, :z] .== Xstack3_noca_d1[1, :]) + @test all(Xstack3_d1[2, :z] .== Xstack3_noca_d1[2, :]) + @test all(Xstack3_d1[3, :z] .== Xstack3_noca_d1[3, :]) + @test all(Xstack3_d1[4, :z] .== Xstack3_noca_d1[4, :]) + + # Issue #254, map then stack. + Xstack4_d1 = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=1) # map then stack + Xstack4_noca_d1 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=1) # map then stack + @test all(Xstack4_d1 .== Xstack4_noca_d1) + @test all(Xstack4_d1[:, :a] .== Xstack4_noca_d1[:, 1]) + @test all(Xstack4_d1[:, :b] .== Xstack4_noca_d1[:, 2:3]) + + Xstack4_d2 = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=2) # map then stack + Xstack4_noca_d2 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=2) # map then stack + @test all(Xstack4_d2 .== Xstack4_noca_d2) + @test all(Xstack4_d2[:a, :] .== Xstack4_noca_d2[1, :]) + @test all(Xstack4_d2[:b, :] .== Xstack4_noca_d2[2:3, :]) + + Xstack4_dcolon = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=:) # map then stack + Xstack4_noca_dcolon = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=:) # map then stack + @test all(Xstack4_dcolon .== Xstack4_noca_dcolon) + @test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :]) + @test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :]) end @testset "axpy! / axpby!" begin