Skip to content

Commit

Permalink
overhaul constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
colinxs authored and oschulz committed Apr 11, 2020
1 parent f8de994 commit 119b7d5
Showing 1 changed file with 37 additions and 39 deletions.
76 changes: 37 additions & 39 deletions src/elasticarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
using Base: @propagate_inbounds
using Base.MultiplicativeInverses: SignedMultiplicativeInverse

export ElasticArray


"""
ElasticArray{T,N,M} <: DenseArray{T,N}
Expand All @@ -16,55 +18,45 @@ Constructors:
convert(ElasticArray, A::AbstractArray)
"""
struct ElasticArray{T,N,M,V<:DenseVector{T}} <: DenseArray{T,N}
kernel_size::NTuple{M,Int}
kernel_size::Dims{M}
kernel_length::SignedMultiplicativeInverse{Int}
data::V
function ElasticArray{T,N,M,V}(kernel_size, kernel_length, data) where {T,N,M,V}
if M::Int != N::Int - 1
throw(ArgumentError("ElasticArray parameter M does not satisfy requirement M == N - 1"))
end
if rem(length(eachindex(data)), kernel_length) != 0
throw(ArgumentError("length(data) must be integer multiple of prod(kernel_size)"))
end
new(kernel_size, kernel_length, data)
end
end

export ElasticArray


function ElasticArray{T}(::UndefInitializer, dims::Integer...) where {T}
kernel_size, size_lastdim = _split_dims(dims)
kernel_length = prod(kernel_size)
data = Vector{T}(undef, kernel_length * size_lastdim)
ElasticArray{T,length(dims),length(kernel_size),Vector{T}}(
kernel_size,
SignedMultiplicativeInverse{Int}(kernel_length),
data
)
end

ElasticArray{T,N}(A::AbstractArray{U,N}) where {T,N,U} = copyto!(ElasticArray{T}(undef, size(A)...), A)
ElasticArray{T}(A::AbstractArray{U,N}) where {T,N,U} = ElasticArray{T,N}(A)
ElasticArray(A::AbstractArray{T,N}) where {T,N} = ElasticArray{T,N}(A)

function ElasticArray{T,N,M}(A::AbstractArray) where {T,N,M}
M == N - 1 || throw(ArgumentError("ElasticArray{T,N=$N,M=$M} does not satisfy requirement M == N-1"))
ElasticArray{T,N}(A)
function ElasticArray{T,N,M,V}(::UndefInitializer, dims::NTuple{N,Integer}) where {T,N,M,V}
kernel_size, size_lastdim = _split_dims(dims)
kernel_length = prod(kernel_size)
data = similar(V, kernel_length * size_lastdim)
ElasticArray{T,N,M,V}(kernel_size, SignedMultiplicativeInverse{Int}(kernel_length), data)
end
ElasticArray{T,N,M}(::UndefInitializer, dims::NTuple{N,Integer}) where {T,N,M} = ElasticArray{T,N,M,Vector{T}}(undef, dims)
ElasticArray{T,N}(::UndefInitializer, dims::NTuple{N,Integer}) where {T,N} = ElasticArray{T,N,N-1}(undef, dims)
ElasticArray{T}(::UndefInitializer, dims::NTuple{N,Integer}) where {T,N} = ElasticArray{T,N}(undef, dims)

ElasticArray(kernel_size::NTuple{M,Int}, kernel_length::SignedMultiplicativeInverse{Int}, data::V) where {T,M,V<:DenseVector{T}} = ElasticArray{T,M+1,M,V}(kernel_size, kernel_length, data)

ElasticArray{T,N,M,V}(::UndefInitializer, dims::Vararg{Integer,N}) where {T,N,M,V} = ElasticArray{T,N,M,V}(undef, dims)
ElasticArray{T,N,M}(::UndefInitializer, dims::Vararg{Integer,N}) where {T,N,M,V} = ElasticArray{T,N,M}(undef, dims)
ElasticArray{T,N}(::UndefInitializer, dims::Vararg{Integer,N}) where {T,N,M,V} = ElasticArray{T,N}(undef, dims)
ElasticArray{T}(::UndefInitializer, dims::Vararg{Integer,N}) where {T,N,M,V} = ElasticArray{T,N}(undef, dims)

Base.convert(::Type{ElasticArray{T,N,M}}, A::ElasticArray{T,N,M}) where {T,N,M} = A
Base.convert(::Type{ElasticArray{T,N,M}}, A::AbstractArray) where {T,N,M} = ElasticArray{T,N,M}(A)
ElasticArray{T,N,M,V}(A::AbstractArray{<:Any,N}) where {T,N,M,V} = copyto!(ElasticArray{T,N,M,V}(undef, size(A)), A)
ElasticArray{T,N,M}(A::AbstractArray{<:Any,N}) where {T,N,M} = copyto!(ElasticArray{T,N,M}(undef, size(A)), A)
ElasticArray{T,N}(A::AbstractArray{<:Any,N}) where {T,N} = copyto!(ElasticArray{T,N}(undef, size(A)), A)
ElasticArray{T}(A::AbstractArray{<:Any,N}) where {T,N} = copyto!(ElasticArray{T,N}(undef, size(A)), A)
ElasticArray(A::AbstractArray{T,N}) where {T,N} = copyto!(ElasticArray{T,N}(undef, size(A)), A)

Base.convert(::Type{ElasticArray{T,N}}, A::ElasticArray{T,N}) where {T,N} = A
Base.convert(::Type{ElasticArray{T,N}}, A::AbstractArray) where {T,N} = ElasticArray{T,N}(A)
Base.convert(::Type{T}, A::AbstractArray) where {T<:ElasticArray} = A isa T ? A : T(A)

Base.convert(::Type{ElasticArray{T}}, A::ElasticArray{T}) where {T} = A
Base.convert(::Type{ElasticArray{T}}, A::AbstractArray) where {T} = ElasticArray{T}(A)

Base.convert(::Type{ElasticArray}, A::ElasticArray) = A
Base.convert(::Type{ElasticArray}, A::AbstractArray) = ElasticArray(A)


function _split_resize_dims(A::ElasticArray, dims::NTuple{N,Integer}) where {N}
kernel_size, size_lastdim = _split_dims(dims)
kernel_size != A.kernel_size && throw(ArgumentError("Can only resize last dimension of an ElasticArray"))
kernel_size, size_lastdim
end


import Base.==
Expand All @@ -83,19 +75,25 @@ Base.length(A::ElasticArray) = length(A.data)

Base.dataids(A::ElasticArray) = Base.dataids(A.data)


@inline function Base.resize!(A::ElasticArray{T,N}, dims::Vararg{Integer,N}) where {T,N}
kernel_size, size_lastdim = _split_resize_dims(A, dims)
resize!(A.data, A.kernel_length.divisor * size_lastdim)
A
end


@inline function Base.sizehint!(A::ElasticArray{T,N}, dims::Vararg{Integer,N}) where {T,N}
kernel_size, size_lastdim = _split_resize_dims(A, dims)
sizehint!(A.data, A.kernel_length.divisor * size_lastdim)
A
end

function _split_resize_dims(A::ElasticArray, dims::NTuple{N,Integer}) where {N}
kernel_size, size_lastdim = _split_dims(dims)
kernel_size != A.kernel_size && throw(ArgumentError("Can only resize last dimension of an ElasticArray"))
kernel_size, size_lastdim
end


function Base.append!(dest::ElasticArray, src::AbstractArray)
rem(length(eachindex(src)), dest.kernel_length) != 0 && throw(DimensionMismatch("Can't append, length of source array is incompatible"))
Expand Down

0 comments on commit 119b7d5

Please sign in to comment.