Skip to content

Commit

Permalink
Add Adapt.jl extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokazama committed Jan 19, 2024
1 parent 5fba850 commit 4bd282c
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 84 deletions.
9 changes: 8 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,21 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

[extensions]
MetadataArraysAdapt = "Adapt"

[compat]
ArrayInterface = "7"
Adapt = "4"
Aqua = "0.8"
DataAPI = "1.14"
LinearAlgebra = "1"
Statistics = "1"
Test = "1"
julia = "1.0"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Expand Down
10 changes: 10 additions & 0 deletions ext/MetadataArraysAdapt/MetadataArraysAdapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module MetadataArraysAdapt

using MetadataArrays
using Adapt

function Adapt.adapt_structure(to, mda::MetadataArray)
MetadataArrays._MetadataArray(adapt(to, getfield(mda, :parent)), getfield(mda, :metadata))
end

end
16 changes: 4 additions & 12 deletions src/MetadataArrays.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module MetadataArrays

using ArrayInterface
import ArrayInterface: parent_type, is_forwarding_wrapper, can_setindex,
can_change_size
import ArrayInterface: parent_type, is_forwarding_wrapper, can_setindex, can_change_size
using Base: BroadcastStyle
using DataAPI
import DataAPI: metadata, metadata!, metadatakeys, metadatasupport, deletemetadata!,
Expand All @@ -22,19 +21,11 @@ export

const MDType = Union{NamedTuple, AbstractDict{Symbol}, AbstractDict{String}}

struct MetadataStyle{S}
style::S

MetadataStyle(style::S) where {S}= new{S}(style)
MetadataStyle() = MetadataStyle(nothing)
global const DEFAULT_META_STYLE = MetadataStyle()
end


include("MetadataDict.jl")
include("MetadataDicts.jl")
include("types.jl")
include("array.jl")
include("metadata.jl")
include("reduce.jl")

Base.write(io::IO, mda::MetadataArray) = write(io, getfield(mda, :parent))
function Base.read!(io::IO, mda::MetadataArray)
Expand All @@ -50,3 +41,4 @@ include("resizing.jl")
include("broadcasting.jl")

end # module

104 changes: 67 additions & 37 deletions src/MetadataDict.jl → src/MetadataDicts.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,53 @@
module MetadataDicts

struct MetadataDict{K, V, P <: Union{AbstractDict{K, V}, NamedTuple{<:Any, <:Tuple{Vararg{V}}}}, M <: MDType, S<:MetadataStyle} <: AbstractDict{K, V}
using ArrayInterface
import ArrayInterface: parent_type, can_setindex, can_change_size
using DataAPI
import DataAPI: metadata, metadata!, metadatakeys, metadatasupport, deletemetadata!,
emptymetadata!

export
MetadataDict,
MetadataStyle

#region styles
struct MetadataBranch{S}
style::S
end
struct MetadataLeaf{S}
style::S
end
const MetadataStyle{S} = Union{MetadataLeaf{S}, MetadataBranch{S}}
#endregion styles

const MDType = Union{NamedTuple, AbstractDict{Symbol}, AbstractDict{String}}

struct MetadataDict{K, V, P <: Union{AbstractDict{K, V}, NamedTuple{<:Any, <:Tuple{Vararg{V}}}}, M <: MDType} <: AbstractDict{K, V}
parent::P
metadata::M
style::S

function MetadataDict{K, V, P, M, S}(p::Union{AbstractDict, NamedTuple}=D(), m=M(), s=S()) where {K, V, P, M, S}
new{K, V, P, M, S}(p, m, s)
function MetadataDict{K, V, P, M}(p::Union{AbstractDict, NamedTuple}=D(), m=M()) where {K, V, P, M}
new{K, V, P, M}(p, m)
end
function MetadataDict{K, V, P, M}(p::Union{AbstractDict, NamedTuple}=P(), m=NamedTuple(), s=MetadataStyle()) where {K, V, P, M}
MetadataDict{K, V, P, M, typeof(s)}(p, m, s)
function MetadataDict{K, V, P}(p::Union{AbstractDict, NamedTuple}=P(), m=NamedTuple()) where {K, V, P}
MetadataDict{K, V, P, typeof(m)}(p, m)
end
function MetadataDict{K, V, P}(p::Union{AbstractDict, NamedTuple}=P(), m=NamedTuple(), s=MetadataStyle()) where {K, V, P}
MetadataDict{K, V, P, typeof(m)}(p, m, s)
function MetadataDict{K, V}(d::Union{AbstractDict, NamedTuple}, m=NamedTuple()) where {K, V}
MetadataDict{K, V, typeof(d)}(d, m)
end
function MetadataDict{K, V}(d::Union{AbstractDict, NamedTuple}, m=NamedTuple(), s=MetadataStyle()) where {K, V}
MetadataDict{K, V, typeof(d)}(d, m, s)
function MetadataDict{K}(d::AbstractDict, m=NamedTuple()) where {K}
MetadataDict{K, valtype(d)}(d, m)
end
function MetadataDict{K}(d::AbstractDict, m=NamedTuple(), s=MetadataStyle()) where {K}
MetadataDict{K, valtype(d)}(d, m, s)
end
function MetadataDict(d::AbstractDict, m::MDType=NamedTuple(), s=MetadataStyle())
MetadataDict{keytype(d)}(d, m, s)
function MetadataDict(d::AbstractDict, m::MDType=NamedTuple())
MetadataDict{keytype(d)}(d, m)
end

# NamedTuple support
function MetadataDict{Symbol}(p::NamedTuple, m::MDType=NamedTuple(), s=MetadataStyle())
MetadataDict{Symbol, eltype(p)}(p, m, s)
function MetadataDict{Symbol}(p::NamedTuple, m::MDType=NamedTuple())
MetadataDict{Symbol, eltype(p)}(p, m)
end
function MetadataDict(p::NamedTuple, m::MDType=NamedTuple(), s=MetadataStyle())
MetadataDict{Symbol}(p, m, s)
function MetadataDict(p::NamedTuple, m::MDType=NamedTuple())
MetadataDict{Symbol}(p, m)
end

function Base.copy(mdd::MetadataDict{K, V, P, M}) where {K, V, P, M}
Expand All @@ -38,32 +57,39 @@ struct MetadataDict{K, V, P <: Union{AbstractDict{K, V}, NamedTuple{<:Any, <:Tup
else
m = copy(getfield(mdd, :metadata))
end
s = getfield(mdd, :style)
new{K, V, P, M, typeof(s)}(p, m, s)
new{K, V, P, M}(p, m)
end
end

const NamedMetadataDict{K, V, P, MDNS, MDTYS} = MetadataDict{K, V, P, NamedTuple{MDNS, MDTYS}}

Base.parent(mdd::MetadataDict) = getfield(mdd, :parent)
ArrayInterface.parent_type(@nospecialize(T::Type{<:MetadataDict})) = fieldtype(T, :parent)
function ArrayInterface.parent_type(@nospecialize(T::Type{<:MetadataDict{<:Any, <:Any, <:Any, <:Any}}))
fieldtype(T, :parent)
end

Base.propertynames(mda::MetadataDict) = propertynames(getfield(mda, :parent))
Base.hasproperty(mda::MetadataDict, s::Symbol) = hasproperty(getfield(mda, :parent), s)

Base.getproperty(mda::MetadataDict, s::Symbol) = getproperty(getfield(mda, :parent), s)
function Base.setproperty!(mda::MetadataDict, s::Symbol, v)
setproperty!(getfield(mda, :parent), s, v)
function Base.getproperty(mda::MetadataDict, s::Symbol, order::Symbol)
getproperty(getfield(mda, :parent), s, order)
end

Base.setproperty!(mda::MetadataDict, s::Symbol, v) = setproperty!(getfield(mda, :parent), s, v)
function Base.setproperty!(mda::MetadataDict, s::Symbol, v, order::Symbol)
setproperty!(getfield(mda, :parent), s, v, order)
end

function ArrayInterface.can_setindex(@nospecialize(T::Type{<:MetadataDict}))
function ArrayInterface.can_setindex(@nospecialize(T::Type{<:MetadataDict{<:Any, <:Any, <:Any, <:Any}}))
can_setindex(fieldtype(T, :parent))
end

function ArrayInterface.can_change_size(@nospecialize(T::Type{<:MetadataDict}))
function ArrayInterface.can_change_size(@nospecialize(T::Type{<:MetadataDict{<:Any, <:Any, <:Any, <:Any}}))
can_change_size(fieldtype(T, :parent))
end

ArrayInterface.is_forwarding_wrapper(@nospecialize(T::Type{<:MetadataDict})) = true
ArrayInterface.is_forwarding_wrapper(@nospecialize(T::Type{<:MetadataDict{<:Any, <:Any, <:Any, <:Any}})) = true

function Base.sizehint!(mdd::MetadataDict, n::Integer)
sizehint!(getfield(mdd, :parent), n)
Expand Down Expand Up @@ -131,7 +157,7 @@ function _promote_valtypes(V, d, ds...) # give up if promoted to any
end

Base.merge(pd::MetadataDict) = copy(pd)
Base.merge(pd::NamedMetadataDict, pds::NamedMetadataDict...) = _mergeprops(_getarg2, pd, pds...)
Base.merge(pd::NamedMetadataDict, pds::NamedMetadataDict...) = _mergewith(_getarg2, pd, pds...)
_getarg2(@nospecialize(arg1), @nospecialize(arg2)) = arg2
function Base.merge(pd::MetadataDict, pds::MetadataDict...)
K = _promote_keytypes((pd, pds...))
Expand All @@ -155,17 +181,21 @@ function Base.mergewith(combine, pd::MetadataDict, pds::MetadataDict...)
mergewith!(combine, out, pds...)
end
@inline function Base.mergewith(combine, pd::NamedMetadataDict, pds::NamedMetadataDict...)
_mergeprops(combine, pd, pds...)
_mergewith(combine, pd, pds...)
end
_mergeprops(combine, @nospecialize(x::NamedMetadataDict)) = x
@inline function _mergeprops(combine, x::NamedMetadataDict, y::NamedMetadataDict)
_mergewith(combine, @nospecialize(x::NamedMetadataDict)) = x
@inline function _mergewith(combine, x::NamedMetadataDict, y::NamedMetadataDict)
MetadataDict(mergewith(combine, getfield(x, :data), getfield(y, :data)))
end
@inline function _mergeprops(combine, x::NamedMetadataDict, y::NamedMetadataDict, zs::NamedMetadataDict...)
_mergeprops(combine, _mergeprops(combine, x, y), zs...)
@inline function _mergewith(combine, x::NamedMetadataDict, y::NamedMetadataDict, zs::NamedMetadataDict...)
_mergewith(combine, _mergewith(combine, x, y), zs...)
end

metadata(mdd::MetadataDict) = getfield(mdd, :metadata)
#region metadata interface
metadatakeys(mdd::MetadataDict) = keys(getfield(mdd, :metadata))
function metadatasupport(T::Type{<:MetadataDict})
(read=true, write=ArrayInterface.can_setindex(fieldtype(T, :metadata)))
end
function metadata(mdd::MetadataDict, key; style::Bool=false)
md = getfield(mdd, :metadata)[key]
if style
Expand Down Expand Up @@ -200,9 +230,9 @@ function emptymetadata!(mdd::MetadataDict)
empty!(getfield(mdd, :metadata))
return mdd
end
#endregion metadata interface

metadatakeys(mdd::MetadataDict) = keys(getfield(mdd, :metadata))
function metadatasupport(T::Type{<:MetadataDict})
(read=true, write=ArrayInterface.can_setindex(fieldtype(T, :metadata)))
end

using .MetadataDicts

38 changes: 36 additions & 2 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,36 @@ Base.IndexStyle(T::Type{<:MetadataArray}) = IndexStyle(fieldtype(T, :parent))

Base.iterate(mda::MetadataArray) = iterate(getfield(mda, :parent))
Base.iterate(mda::MetadataArray, state) = iterate(getfield(mda, :parent), state)
Base.iterate(mda::MetadataUnitRange) = iterate(getfield(mda, :parent))
Base.iterate(mda::MetadataUnitRange, state) = iterate(getfield(mda, :parent), state)

Base.first(mda::MetadataArray) = first(getfield(mda, :parent))
Base.first(x::MetadataUnitRange) = first(getfield(x, :parent))

Base.step(mda::MetadataArray) = step(getfield(mda, :parent))

Base.last(mda::MetadataArray) = last(getfield(mda, :parent))
Base.last(x::MetadataUnitRange) = last(getfield(x, :parent))

Base.size(mda::MetadataArray) = size(getfield(mda, :parent))

Base.axes(mda::MetadataArray) = axes(getfield(mda, :parent))

Base.strides(mda::MetadataArray) = strides(getfield(mda, :parent))

Base.length(mda::MetadataArray) = length(getfield(mda, :parent))
Base.length(mda::MetadataUnitRange) = length(getfield(mda, :parent))

Base.firstindex(mda::MetadataArray) = firstindex(getfield(mda, :parent))
Base.firstindex(mda::MetadataUnitRange) = firstindex(getfield(mda, :parent))

Base.lastindex(mda::MetadataArray) = lastindex(getfield(mda, :parent))
Base.lastindex(mda::MetadataUnitRange) = lastindex(getfield(mda, :parent))

Base.pointer(mda::MetadataArray) = pointer(getfield(mda, :parent))
Base.pointer(mda::MetadataArray, i::Integer) = pointer(getfield(mda, :parent), i)

Base.in(val, mda::MetadataArray) = in(val, getfield(mda, :parent))
Base.keys(mda::MetadataArray) = keys(getfield(mda, :parent))
Base.isempty(mda::MetadataArray) = isempty(getfield(mda, :parent))

Base.@propagate_inbounds function Base.isassigned(mda::MetadataArray, i::Integer...)
isassigned(getfield(mda, :parent), i...)
Expand All @@ -32,3 +44,25 @@ end
function Base.dataids(mda::MetadataArray)
(Base.dataids(getfield(mda, :parent))..., Base.dataids(getfield(mda, :metadata))...)
end

function Base.transpose(mda::MetadataArray)
p = transpose(getfield(mda, :parent))
m = permute_dimsmetadata(mda)
_MetadataArray(p, m)
end
function Base.adjoint(mda::MetadataArray)
p = adjoint(getfield(mda, :parent))
m = permute_dimsmetadata(mda)
_MetadataArray(p, m)
end
function Base.permutedims(mda::MetadataArray)
p = permutedims(getfield(mda, :parent))
m = permute_dimsmetadata(mda)
_MetadataArray(p, m)
end
function Base.permutedims(mda::MetadataArray, perm)
p = permutedims(getfield(mda, :parent), perm)
m = permute_dimsmetadata(mda, perm)
_MetadataArray(p, m)
end

24 changes: 24 additions & 0 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@

Base.@propagate_inbounds function Base.getindex(mdu::MetadataUnitRange, i::Integer)
getindex(getfield(mdu, :parent), i)
end
Base.@propagate_inbounds function Base.getindex(mdu::MetadataUnitRange, i)
propagate_metadata(mdu, getindex(getfield(mdu, :parent), i))
end

Base.@propagate_inbounds function Base.getindex(
mdu::MetadataUnitRange,
s::StepRange{T}
) where T<:Integer
propagate_metadata(mdu, getindex(getfield(mdu, :parent), s))
end
Base.@propagate_inbounds function Base.getindex(
mdu::MetadataUnitRange,
s::AbstractUnitRange{T}
) where {T<:Integer}
propagate_metadata(mdu, getindex(getfield(mdu, :parent), s))
end

Base.getindex(mdu::MetadataUnitRange, ::Colon) = copy(mdu)

Base.copy(mdu::MetadataUnitRange) = copy_metadata(mdu, copy(getfield(mdu, :parent)))

Base.@propagate_inbounds function Base.getindex(mda::MetadataArray, i::Int...)
getfield(mda, :parent)[i...]
end
Expand Down
Loading

0 comments on commit 4bd282c

Please sign in to comment.