From 8d577ab3f24212a9762ad6673e8a8fb0ecd5a140 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 14 Apr 2024 20:21:22 +0530 Subject: [PATCH] `zeros`/`ones`/`fill` may accept arbitrary axes that are supported by `similar` (#53965) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The idea is that functions like `zeros` are essentially constructing a container and filling it with a value. `similar` seems perfectly placed to construct such a container, so we may accept arbitrary axes in `zeros` as long as there's a corresponding `similar` method that is defined for the axes. Packages therefore would only need to define `similar`, and would get `zeros`/`ones` and `fill` for free. For example, the following will work after this: ```julia julia> using StaticArrays julia> zeros(SOneTo(2), 2) 2×2 Matrix{Float64}: 0.0 0.0 0.0 0.0 julia> zeros(SOneTo(2), Base.OneTo(2)) 2×2 Matrix{Float64}: 0.0 0.0 0.0 0.0 ``` Neither of these work on the current master, as `StaticArrays` doesn't define `zeros` for these combinations, even though it does define `similar`. One may argue for these methods to be added to `StaticArrays`, but this seems to be adding redundancy. The flip side is that `OffsetArrays` defines exactly these methods, so adding them to `Base` would break precompilation for the package. However, `OffsetArrays` really shouldn't be defining these methods, as this is type-piracy. The methods may be version-limited in `OffsetArrays` if this PR is merged. On the face of it, `trues` and `falses` should also work similarly, but currently these seem to be bypassing `similar` and constructing a `BitArray` explicitly. I have not added the corresponding methods for these functions, but they may be added as well. --- base/array.jl | 6 ++++++ base/bitarray.jl | 2 ++ test/abstractarray.jl | 22 ++++++++++++++++++++++ test/bitarray.jl | 22 ++++++++++++++++++++++ test/testhelpers/SizedArrays.jl | 15 +++++++++++++++ 5 files changed, 67 insertions(+) diff --git a/base/array.jl b/base/array.jl index 7676b380923ee..fb702725b389a 100644 --- a/base/array.jl +++ b/base/array.jl @@ -529,6 +529,7 @@ function fill end fill(v, dims::DimOrInd...) = fill(v, dims) fill(v, dims::NTuple{N, Union{Integer, OneTo}}) where {N} = fill(v, map(to_dim, dims)) fill(v, dims::NTuple{N, Integer}) where {N} = (a=Array{typeof(v),N}(undef, dims); fill!(a, v); a) +fill(v, dims::NTuple{N, DimOrInd}) where {N} = (a=similar(Array{typeof(v),N}, dims); fill!(a, v); a) fill(v, dims::Tuple{}) = (a=Array{typeof(v),0}(undef, dims); fill!(a, v); a) """ @@ -589,6 +590,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) fill!(a, $felt(T)) return a end + function $fname(::Type{T}, dims::NTuple{N, DimOrInd}) where {T,N} + a = similar(Array{T,N}, dims) + fill!(a, $felt(T)) + return a + end end end diff --git a/base/bitarray.jl b/base/bitarray.jl index 079dbefe03a94..f7eeafbb62231 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -404,6 +404,7 @@ falses(dims::DimOrInd...) = falses(dims) falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims)) falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false) falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false) +falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), false) """ trues(dims) @@ -422,6 +423,7 @@ trues(dims::DimOrInd...) = trues(dims) trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims)) trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true) trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true) +trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), true) function one(x::BitMatrix) m, n = size(x) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index a0a6ba6b2229a..feb6adaf39fdd 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -11,6 +11,9 @@ using .Main.StructArrays isdefined(Main, :FillArrays) || @eval Main include("testhelpers/FillArrays.jl") using .Main.FillArrays +isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl") +using .Main.SizedArrays + A = rand(5,4,3) @testset "Bounds checking" begin @test checkbounds(Bool, A, 1, 1, 1) == true @@ -2097,3 +2100,22 @@ end @test r2[i] == z[j] end end + +@testset "zero for arbitrary axes" begin + r = SizedArrays.SOneTo(2) + s = Base.OneTo(2) + _to_oneto(x::Integer) = Base.OneTo(2) + _to_oneto(x::Union{Base.OneTo, SizedArrays.SOneTo}) = x + for (f, v) in ((zeros, 0), (ones, 1), ((x...)->fill(3,x...),3)) + for ax in ((r,r), (s, r), (2, r)) + A = f(ax...) + @test axes(A) == map(_to_oneto, ax) + if all(x -> x isa SizedArrays.SOneTo, ax) + @test A isa SizedArrays.SizedArray && parent(A) isa Array + else + @test A isa Array + end + @test all(==(v), A) + end + end +end diff --git a/test/bitarray.jl b/test/bitarray.jl index 056a201bd4f6f..2cf285370441e 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -3,6 +3,9 @@ using Base: findprevnot, findnextnot using Random, LinearAlgebra, Test +isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl") +using .Main.SizedArrays + tc(r1::NTuple{N,Any}, r2::NTuple{N,Any}) where {N} = all(x->tc(x...), [zip(r1,r2)...]) tc(r1::BitArray{N}, r2::Union{BitArray{N},Array{Bool,N}}) where {N} = true tc(r1::SubArray{Bool,N1,BitArray{N2}}, r2::SubArray{Bool,N1,<:Union{BitArray{N2},Array{Bool,N2}}}) where {N1,N2} = true @@ -82,6 +85,25 @@ allsizes = [((), BitArray{0}), ((v1,), BitVector), @test !isassigned(b, length(b) + 1) end +@testset "trues and falses with custom axes" begin + for ax in ((SizedArrays.SOneTo(2),), (SizedArrays.SOneTo(2), Base.OneTo(2))) + t = trues(ax) + if all(x -> x isa SizedArrays.SOneTo, ax) + @test t isa SizedArrays.SizedArray && parent(t) isa BitArray + else + @test t isa BitArray + end + @test all(t) + + f = falses(ax) + if all(x -> x isa SizedArrays.SOneTo, ax) + @test t isa SizedArrays.SizedArray && parent(t) isa BitArray + else + @test t isa BitArray + end + @test !any(f) + end +end @testset "Conversions for size $sz" for (sz, T) in allsizes b1 = rand!(falses(sz...)) diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 43bc27e630479..2d37cead61a08 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -43,10 +43,25 @@ Base.size(a::SizedArray) = size(typeof(a)) Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ Base.axes(a::SizedArray) = map(SOneTo, size(a)) Base.getindex(A::SizedArray, i...) = getindex(A.data, i...) +Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...) Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T))) +Base.parent(S::SizedArray) = S.data +(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data) ==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data +homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...) +homogenize_shape(::Tuple{}) = () +_homogenize_shape(x::Integer) = x +_homogenize_shape(x::AbstractUnitRange) = length(x) +const Dims = Union{Integer, Base.OneTo, SOneTo} +function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray} + similar(A, homogenize_shape(shape)) +end +function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray} + R = similar(A, length.(shape)) + SizedArray{length.(shape)}(R) +end + const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}} _data(S::SizedArray) = S.data