From 39a038743a6a5055a6059f389789c71b052b470e Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 21 Feb 2021 22:02:02 +0100 Subject: [PATCH 1/6] Add cat,vcat and hcat for Fill --- src/FillArrays.jl | 1 + src/fillcat.jl | 16 ++++++++++++++ test/runtests.jl | 56 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 src/fillcat.jl diff --git a/src/FillArrays.jl b/src/FillArrays.jl index cbf9405a..ddd10c35 100644 --- a/src/FillArrays.jl +++ b/src/FillArrays.jl @@ -589,6 +589,7 @@ end include("fillalgebra.jl") include("fillbroadcast.jl") include("trues.jl") +include("fillcat.jl") ## # print diff --git a/src/fillcat.jl b/src/fillcat.jl new file mode 100644 index 00000000..69eb32c0 --- /dev/null +++ b/src/fillcat.jl @@ -0,0 +1,16 @@ + +function Base.cat_t(::Type{T}, fs::Fill...; dims) where T + allvals = unique([f.value for f in fs]) + length(allvals) > 1 && return Base._cat_t(dims, T, fs...) + + catdims = Base.dims2cat(dims) + + # Note, when dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros too + allvals[] !== zero(T) && sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) + + shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims + return Fill(convert(T, fs[1].value), shape) +end + +Base.vcat(vs::Fill...) = cat(vs...;dims=Val(1)) +Base.hcat(vs::Fill...) = cat(vs...;dims=Val(2)) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 37d32c82..a43a8708 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1232,3 +1232,59 @@ end @test convert(Fill, transpose(a)) ≡ Fill(2.0,1,5) end end + +@testset "Concatenation" begin + + @testset "Fill" begin + @testset "cat shape $s" for s in + ( + 0, + 1, + (2, 0), + (2, 3), + (0, 2), + (2,3,4) + ) + + @testset "Dim $dims" for dims in (1,2,3, Val(4)) + res = cat(Fill(1, s), Fill(1, s); dims=dims) + + @test res isa Fill + @test res == cat(fill(1, s), fill(1, s); dims=dims) + + res = cat(Fill(1.0, s), Fill(1, s); dims=dims) + + @test res isa Fill + @test res == cat(fill(1.0, s), fill(1, s); dims=dims) + + @test cat(Fill(1, s), Fill(2, s);dims=dims) == cat(fill(1, s), fill(2, s);dims=dims) + end + @testset "Dim $dims" for dims in ( + (1,2), + (2,3), + (1,3,4) + ) + # This inserts a bunch of zeros so we can no longer assume the answer is a Fill + @test cat(Fill(1, s), Fill(1, s); dims=dims) == cat(fill(1, s), fill(1,s); dims=dims) + @test cat(Fill(0, s), Fill(0, s); dims=dims) isa Fill + @test cat(Fill(0.0, s), Fill(0.0, s); dims=dims) isa Fill + + @test cat(Fill(1, s), Fill(2, s);dims=dims) == cat(fill(1, s), fill(2, s);dims=dims) + end + end + + @testset "vcat" begin + # Vcat just delegates to cat, so we basically just test that here + res = vcat(Fill(1, 3), Fill(1, 4)) + @test res isa Fill + @test res == vcat(fill(1, 3), fill(1,4)) + end + + @testset "hcat" begin + # Vcat just delegates to cat, so we basically just test that here + res = hcat(Fill(1, 2), Fill(1, 2)) + @test res isa Fill + @test res == hcat(fill(1, 2), fill(1,2)) + end + end +end From 88f7c70dcfac072e43f368fa440f8df82b762ad6 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 21 Feb 2021 22:34:50 +0100 Subject: [PATCH 2/6] Add cat, vcat, hcat for Zeros --- src/fillcat.jl | 15 +++++++++++++-- test/runtests.jl | 50 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/src/fillcat.jl b/src/fillcat.jl index 69eb32c0..dfd0ec44 100644 --- a/src/fillcat.jl +++ b/src/fillcat.jl @@ -5,7 +5,7 @@ function Base.cat_t(::Type{T}, fs::Fill...; dims) where T catdims = Base.dims2cat(dims) - # Note, when dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros too + # Note, when dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros allvals[] !== zero(T) && sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims @@ -13,4 +13,15 @@ function Base.cat_t(::Type{T}, fs::Fill...; dims) where T end Base.vcat(vs::Fill...) = cat(vs...;dims=Val(1)) -Base.hcat(vs::Fill...) = cat(vs...;dims=Val(2)) \ No newline at end of file +Base.hcat(vs::Fill...) = cat(vs...;dims=Val(2)) + + +function Base.cat_t(::Type{T}, fs::Zeros...; dims) where T + catdims = Base.dims2cat(dims) + shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims + return Zeros{T}(shape) +end + +Base.vcat(vs::Zeros...) = cat(vs...;dims=Val(1)) +Base.hcat(vs::Zeros...) = cat(vs...;dims=Val(2)) + diff --git a/test/runtests.jl b/test/runtests.jl index a43a8708..1063148c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1241,19 +1241,16 @@ end 0, 1, (2, 0), - (2, 3), (0, 2), (2,3,4) ) @testset "Dim $dims" for dims in (1,2,3, Val(4)) res = cat(Fill(1, s), Fill(1, s); dims=dims) - @test res isa Fill @test res == cat(fill(1, s), fill(1, s); dims=dims) res = cat(Fill(1.0, s), Fill(1, s); dims=dims) - @test res isa Fill @test res == cat(fill(1.0, s), fill(1, s); dims=dims) @@ -1274,17 +1271,60 @@ end end @testset "vcat" begin - # Vcat just delegates to cat, so we basically just test that here + # vcat just delegates to cat, so we basically just test that here res = vcat(Fill(1, 3), Fill(1, 4)) @test res isa Fill @test res == vcat(fill(1, 3), fill(1,4)) end @testset "hcat" begin - # Vcat just delegates to cat, so we basically just test that here + # hcat just delegates to cat, so we basically just test that here res = hcat(Fill(1, 2), Fill(1, 2)) @test res isa Fill @test res == hcat(fill(1, 2), fill(1,2)) end end + + @testset "Zeros" begin + @testset "cat shape $s" for s in + ( + 0, + 1, + (2, 0), + (0, 2), + (2,3,4) + ) + + @testset "Dim $dims" for dims in ( + 1, + 2, + Val(3), + (1,2), + (2,3) + ) + res = cat(Zeros(s), Zeros(s); dims=dims) + @test res isa Zeros + @test res == cat(zeros(s), zeros(s); dims=dims) + + res = cat(Zeros{Float64}(s), Zeros{Int}(s); dims=dims) + @test res isa Zeros + @test res == cat(zeros(Float64, s), zeros(Int, s); dims=dims) + end + end + + @testset "vcat" begin + # vcat just delegates to cat, so we basically just test that here + res = vcat(Zeros(3), Zeros(4)) + @test res isa Zeros + @test res == vcat(zeros(3), zeros(4)) + end + + @testset "hcat" begin + # hcat just delegates to cat, so we basically just test that here + res = vcat(Zeros(2), Zeros(2)) + @test res isa Zeros + @test res == vcat(zeros(2), zeros(2)) + end + end + end From a8082c83c755c24eca4d9b515daa03d9ab80e3fa Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 21 Feb 2021 22:45:15 +0100 Subject: [PATCH 3/6] Add cat, vcat and hcat for Ones --- src/fillcat.jl | 18 +++++++++++++++++- test/runtests.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/fillcat.jl b/src/fillcat.jl index dfd0ec44..40d1eebf 100644 --- a/src/fillcat.jl +++ b/src/fillcat.jl @@ -5,7 +5,8 @@ function Base.cat_t(::Type{T}, fs::Fill...; dims) where T catdims = Base.dims2cat(dims) - # Note, when dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros + # When dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros + # There might be some cases when it does not get padded which are not considered here allvals[] !== zero(T) && sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims @@ -25,3 +26,18 @@ end Base.vcat(vs::Zeros...) = cat(vs...;dims=Val(1)) Base.hcat(vs::Zeros...) = cat(vs...;dims=Val(2)) + +function Base.cat_t(::Type{T}, fs::Ones...; dims) where T + catdims = Base.dims2cat(dims) + + # When dims is a tuple the output gets zero padded so we can't return a Ones + # There might be some cases when it does not get padded which are not considered here + sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) + + shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims + return Ones{T}(shape) +end + +Base.vcat(vs::Ones...) = cat(vs...;dims=Val(1)) +Base.hcat(vs::Ones...) = cat(vs...;dims=Val(2)) + diff --git a/test/runtests.jl b/test/runtests.jl index 1063148c..a9810540 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1327,4 +1327,47 @@ end end end + @testset "Ones" begin + @testset "cat shape $s" for s in + ( + 0, + 1, + (2, 0), + (0, 2), + (2,3,4) + ) + + @testset "Dim $dims" for dims in (1,2,3, Val(4)) + res = cat(Ones(s), Ones(s); dims=dims) + @test res isa Ones + @test res == cat(ones(s), ones(s); dims=dims) + + res = cat(Ones{Float64}(s), Ones{Int}(s); dims=dims) + @test res isa Ones + @test res == cat(ones(Float64, s), ones(Int, s); dims=dims) + end + @testset "Dim $dims" for dims in ( + (1,2), + (2,3), + (1,3,4) + ) + # This inserts a bunch of zeros so we can no longer assume the answer is a Fill + @test cat(Ones(s), Ones(s); dims=dims) == cat(ones(s), ones(s); dims=dims) + end + end + + @testset "vcat" begin + # vcat just delegates to cat, so we basically just test that here + res = vcat(Ones(3), Ones(4)) + @test res isa Ones + @test res == vcat(ones(3), fill(1,4)) + end + + @testset "hcat" begin + # hcat just delegates to cat, so we basically just test that here + res = hcat(Ones(2), Ones(2)) + @test res isa Ones + @test res == hcat(ones(2), fill(1,2)) + end + end end From 772d3be2fd7c498073629adf85512143e67ed103 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 21 Feb 2021 22:51:06 +0100 Subject: [PATCH 4/6] Add testcase for cat with dims iterator --- test/runtests.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index a9810540..5d1d437b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1259,7 +1259,8 @@ end @testset "Dim $dims" for dims in ( (1,2), (2,3), - (1,3,4) + (1,3,4), + Iterators.take(3:5, 2) ) # This inserts a bunch of zeros so we can no longer assume the answer is a Fill @test cat(Fill(1, s), Fill(1, s); dims=dims) == cat(fill(1, s), fill(1,s); dims=dims) @@ -1300,7 +1301,8 @@ end 2, Val(3), (1,2), - (2,3) + (2,3), + Iterators.take(3:5, 2) ) res = cat(Zeros(s), Zeros(s); dims=dims) @test res isa Zeros @@ -1349,7 +1351,8 @@ end @testset "Dim $dims" for dims in ( (1,2), (2,3), - (1,3,4) + (1,3,4), + Iterators.take(3:5, 2) ) # This inserts a bunch of zeros so we can no longer assume the answer is a Fill @test cat(Ones(s), Ones(s); dims=dims) == cat(ones(s), ones(s); dims=dims) From 7bca7093fe7b81d2f4b776a714a1af6136b8f366 Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 21 Feb 2021 23:35:59 +0100 Subject: [PATCH 5/6] Fix cat_shape for julia < 1.6 --- src/fillcat.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/fillcat.jl b/src/fillcat.jl index 40d1eebf..7ea486ad 100644 --- a/src/fillcat.jl +++ b/src/fillcat.jl @@ -9,7 +9,7 @@ function Base.cat_t(::Type{T}, fs::Fill...; dims) where T # There might be some cases when it does not get padded which are not considered here allvals[] !== zero(T) && sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) - shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims + shape = cat_shape_fill(catdims, fs) return Fill(convert(T, fs[1].value), shape) end @@ -19,7 +19,7 @@ Base.hcat(vs::Fill...) = cat(vs...;dims=Val(2)) function Base.cat_t(::Type{T}, fs::Zeros...; dims) where T catdims = Base.dims2cat(dims) - shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims + shape = cat_shape_fill(catdims, fs) return Zeros{T}(shape) end @@ -34,10 +34,16 @@ function Base.cat_t(::Type{T}, fs::Ones...; dims) where T # There might be some cases when it does not get padded which are not considered here sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) - shape = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims + shape = cat_shape_fill(catdims, fs) return Ones{T}(shape) end Base.vcat(vs::Ones...) = cat(vs...;dims=Val(1)) Base.hcat(vs::Ones...) = cat(vs...;dims=Val(2)) + +if VERSION < v"1.6-" + cat_shape_fill(catdims, fs) = Base.cat_shape(catdims, (), map(Base.cat_size, fs)...) +else + cat_shape_fill(catdims, fs) = Base.cat_shape(catdims, map(Base.cat_size, fs)::Tuple{Vararg{Union{Int,Dims}}})::Dims +end From 06724c6ce078c5e402d8beaac7ba5cd7998474ef Mon Sep 17 00:00:00 2001 From: DrChainsaw Date: Sun, 21 Feb 2021 23:58:45 +0100 Subject: [PATCH 6/6] Add handling in cat for non-numeric element types --- src/fillcat.jl | 6 +++++- test/runtests.jl | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/fillcat.jl b/src/fillcat.jl index 7ea486ad..4f30d824 100644 --- a/src/fillcat.jl +++ b/src/fillcat.jl @@ -7,7 +7,11 @@ function Base.cat_t(::Type{T}, fs::Fill...; dims) where T # When dims is a tuple the output gets zero padded and we can't use a Fill unless it is all zeros # There might be some cases when it does not get padded which are not considered here - allvals[] !== zero(T) && sum(catdims) > 1 && return Base._cat_t(dims, T, fs...) + + if sum(catdims) > 1 + allvals[] isa Number || return Base._cat_t(dims, T, fs...) + allvals[] !== zero(T) && return Base._cat_t(dims, T, fs...) + end shape = cat_shape_fill(catdims, fs) return Fill(convert(T, fs[1].value), shape) diff --git a/test/runtests.jl b/test/runtests.jl index 5d1d437b..1d8e7bff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1254,6 +1254,10 @@ end @test res isa Fill @test res == cat(fill(1.0, s), fill(1, s); dims=dims) + res = cat(Fill(:a, s), Fill(:a, s); dims=dims) + @test res isa Fill + @test res == cat(fill(:a, s), fill(:a, s);dims=dims) + @test cat(Fill(1, s), Fill(2, s);dims=dims) == cat(fill(1, s), fill(2, s);dims=dims) end @testset "Dim $dims" for dims in (