Skip to content

Commit

Permalink
Add tests for stacking issue SciML#254
Browse files Browse the repository at this point in the history
  • Loading branch information
dingraha committed Oct 2, 2024
1 parent f60730a commit 0c99a38
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/array_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,29 @@ function Base.permutedims(x::ComponentArray, dims)
return ComponentArray(permutedims(getdata(x), dims), map(i->axs[i], dims)...)
end

function stack_ca(iter; dims=:)
# Check if all the component axes are the same in iter.
if all(arr -> getaxes(arr) == getaxes(first(iter)), iter)
# Create a new array with `FlatAxis()` in the `dims` dimension.
return _stack_ca(dims, iter)
else
# Fall back to plain arrays.
return stack(getdata.(iter); dims=dims)
end
end

function _stack_ca(dims::Colon, iter)
oldax = getaxes(first(iter))
return ComponentArray(stack(getdata.(iter)), oldax..., FlatAxis())
end

function _stack_ca(dims::Integer, iter)
oldax = getaxes(first(iter))
oldndims = length(oldax)
outax = ntuple(d -> d==dims ? FlatAxis() : oldax[d - (d>dims)], oldndims+1)
return ComponentArray(stack(getdata.(iter); dims=dims), outax...)
end

## Indexing
Base.IndexStyle(::Type{<:ComponentArray{T,N,<:A,<:Axes}}) where {T,N,A,Axes} = IndexStyle(A)

Expand Down
66 changes: 66 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,72 @@ end
XY_d1 = stack([X, Y]; dims=1)
@test all(XY_d1[1, :a, :b] .== XY[:a, :b, 1])
@test all(XY_d1[2, :a, :b] .== XY[:a, :b, 2])

# Issue #254, tuple of arrays:
x = ComponentVector(a=[1, 2])
y = ComponentVector(b=[3, 4])
Xstack1 = stack((x, y, x); dims=1)
Xstack1_noca = stack((getdata(x), getdata(y), getdata(x)); dims=1)
@test all(Xstack1 .== Xstack1_noca)
@test all(Xstack1[1, :a] .== Xstack1_noca[1, :])
@test all(Xstack1[2, :a] .== Xstack1_noca[2, :])

# Issue #254, Array of tuples.
Xstack2 = stack(ComponentArray(a=(1,2,3), b=(4,5,6)))
Xstack2_noca = stack([(1,2,3), (4,5,6)])
@test all(Xstack2 .== Xstack2_noca)
@test all(Xstack2[:, :a] .== Xstack2_noca[:, 1])
@test all(Xstack2[:, :b] .== Xstack2_noca[:, 2])

Xstack2_d1 = stack(ComponentArray(a=(1,2,3), b=(4,5,6)); dims=1)
Xstack2_noca_d1 = stack([(1,2,3), (4,5,6)]; dims=1)
@test all(Xstack2_d1 .== Xstack2_noca_d1)
@test all(Xstack2_d1[:a, :] .== Xstack2_noca_d1[1, :])
@test all(Xstack2_d1[:b, :] .== Xstack2_noca_d1[2, :])

# Issue #254, generator of arrays.
Xstack3 = stack(ComponentArray(z=[x,x]) for x in 1:4)
Xstack3_noca = stack([x, x] for x in 1:4)
# That should give me
# [1 2 3 4;
# 1 2 3 4]
@test all(Xstack3 .== Xstack3_noca)
@test all(Xstack3[:z, 1] .== Xstack3_noca[:, 1])
@test all(Xstack3[:z, 2] .== Xstack3_noca[:, 2])
@test all(Xstack3[:z, 3] .== Xstack3_noca[:, 3])
@test all(Xstack3[:z, 4] .== Xstack3_noca[:, 4])

Xstack3_d1 = stack(ComponentArray(z=[x,x]) for x in 1:4; dims=1)
Xstack3_noca_d1 = stack([x, x] for x in 1:4; dims=1)
# That should give me
# [1 1;
# 2 2;
# 3 3;
# 4 4;]
@test all(Xstack3_d1 .== Xstack3_noca_d1)
@test all(Xstack3_d1[1, :z] .== Xstack3_noca_d1[1, :])
@test all(Xstack3_d1[2, :z] .== Xstack3_noca_d1[2, :])
@test all(Xstack3_d1[3, :z] .== Xstack3_noca_d1[3, :])
@test all(Xstack3_d1[4, :z] .== Xstack3_noca_d1[4, :])

# Issue #254, map then stack.
Xstack4_d1 = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=1) # map then stack
Xstack4_noca_d1 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=1) # map then stack
@test all(Xstack4_d1 .== Xstack4_noca_d1)
@test all(Xstack4_d1[:, :a] .== Xstack4_noca_d1[:, 1])
@test all(Xstack4_d1[:, :b] .== Xstack4_noca_d1[:, 2:3])

Xstack4_d2 = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=2) # map then stack
Xstack4_noca_d2 = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=2) # map then stack
@test all(Xstack4_d2 .== Xstack4_noca_d2)
@test all(Xstack4_d2[:a, :] .== Xstack4_noca_d2[1, :])
@test all(Xstack4_d2[:b, :] .== Xstack4_noca_d2[2:3, :])

Xstack4_dcolon = stack(x -> ComponentArray(a=x, b=[x+1,x+2]), [5 6; 7 8]; dims=:) # map then stack
Xstack4_noca_dcolon = stack(x -> [x, x+1, x+2], [5 6; 7 8]; dims=:) # map then stack
@test all(Xstack4_dcolon .== Xstack4_noca_dcolon)
@test all(Xstack4_dcolon[:a, :, :] .== Xstack4_noca_dcolon[1, :, :])
@test all(Xstack4_dcolon[:b, :, :] .== Xstack4_noca_dcolon[2:3, :, :])
end

@testset "axpy! / axpby!" begin
Expand Down

0 comments on commit 0c99a38

Please sign in to comment.