Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A better default dataids(::AbstractArray) #26237

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
31 changes: 25 additions & 6 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1522,14 +1522,33 @@ _isdisjoint(as::Tuple, bs::Tuple) = !(as[1] in bs) && _isdisjoint(tail(as), bs)
"""
Base.dataids(A::AbstractArray)

Return a tuple of `UInt`s that represent the mutable data segments of an array.
Return a tuple of `UInt`s that identify the mutable data segments of an array.

Custom arrays that would like to opt-in to aliasing detection of their component
parts can specialize this method to return the concatenation of the `dataids` of
their component parts. A typical definition for an array that wraps a parent is
`Base.dataids(C::CustomArray) = dataids(C.parent)`.
These values are used to determine if two arrays might share memory with [`Base.mightalias`](@ref).
The default implementation recursively combines the `dataids` of all fields of the struct.

Custom arrays only need to implement a custom `dataids` method if:

* they wish to ignore some fields (with non-empty `dataids`) in aliasing considerations;
for example this can be the case if an array is used to store intentionally-shared
metadata or other data that is not mutated by `setindex!`

* or they depend upon non-array fields (with empty `dataids`) to define their indexable
contents that they wish to include in aliasing considerations.
"""
dataids(A::AbstractArray) = (UInt(objectid(A)),)
mbauman marked this conversation as resolved.
Show resolved Hide resolved
function dataids(A::AbstractArray)
@inline
if @generated
:(ids = tuple($([:(dataids(getfield(A, $i))...) for i in 1:fieldcount(A)]...)))
else
ids = _splatmap(dataids, ntuple(i -> getfield(A, i), Val(nfields(A))))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check all registered packages for undef fields in AbstractArray before merging this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PkgEval runs didn't come up with any such arrays to my reading, which is indeed a bit surprising (but welcome)!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha maybe its literally only Tridiagonal.

How certain are we than package tests actually call dataids somewhere on all AbstractArray?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question — they'll hit it if they touch any of the builtin implementations of broadcasting, view, or reshape... which I'd rate as fairly likely (even if not explicitly as a targeted test, they'd probably do at least one of those in some implementation). I'd think, anyhow.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it sounds unlikely none of those are called.

end
if isimmutable(A) || !isempty(ids)
mbauman marked this conversation as resolved.
Show resolved Hide resolved
return ids
else
return (UInt(pointer_from_objref(A)),)
end
end
dataids(A::Array) = (UInt(pointer(A)),)
dataids(::AbstractRange) = ()
dataids(x) = ()
Expand Down
1 change: 1 addition & 0 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,7 @@ LogicalIndex{Int}(mask::AbstractArray) = LogicalIndex{Int, typeof(mask)}(mask)
size(L::LogicalIndex) = (L.sum,)
length(L::LogicalIndex) = L.sum
collect(L::LogicalIndex) = [i for i in L]
unaliascopy(L::TL) where {TL <: LogicalIndex} = TL(unaliascopy(L.mask))
show(io::IO, r::LogicalIndex) = print(io,collect(r))
print_array(io::IO, X::LogicalIndex) = print_array(io, collect(X))
# Iteration over LogicalIndex is very performance-critical, but it also must
Expand Down
4 changes: 4 additions & 0 deletions base/permuteddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ function _copy!(P::PermutedDimsArray{T,N,perm}, src) where {T,N,perm}
return P
end

function Base.unaliascopy(P::PermutedDimsArray)::typeof(P)
return (typeof(P))(Base.unaliascopy(P.parent))
end

@noinline function _permutedims!(P::PermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp)
ip, is = axes(src, dp), axes(src, ds)
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,8 @@ similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = sim
# Operations on Tridiagonal matrices
copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest)

Base.dataids(A::Tridiagonal) = (Base.dataids(A.dl), Base.dataids(A.d), Base.dataids(A.du))
Copy link
Contributor

@rafaqz rafaqz Oct 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea here that any non-base package also needs to define a similar dataids method to avoid an error with undef fields? (when they have those)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they would need to (or otherwise not have the undef fields).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should mention that in the docs along with the other reasons to define custom dataids

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean its obscure and no one should do it, but its technically a breaking change


#Elementary operations
for func in (:conj, :copy, :real, :imag)
@eval function ($func)(M::Tridiagonal)
Expand Down
66 changes: 66 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1158,6 +1158,7 @@ end
Base.strides(S::Strider) = S.strides
Base.elsize(::Type{<:Strider{T}}) where {T} = Base.elsize(Vector{T})
Base.unsafe_convert(::Type{Ptr{T}}, S::Strider{T}) where {T} = pointer(S.data, S.offset)
Base.unaliascopy(S::Strider)::typeof(S) = (typeof(S))(Base.unaliascopy(S.data), S.offset, S.strides, S.size)

@testset "Simple 3d strided views and permutes" for sz in ((5, 3, 2), (7, 11, 13))
A = collect(reshape(1:prod(sz), sz))
Expand Down Expand Up @@ -1287,6 +1288,18 @@ end
end
end

@testset "PermutedDimsArray unaliasing" begin
A = [1 2; 3 4]
P = permutedims(A, (2,1))
A .= PermutedDimsArray(A, (2,1))
@test A == P

A = [1 2; 3 4]
S = Strider(vec(A), strides(A), size(A))
A .= PermutedDimsArray(S, (2, 1))
@test A == P
end

@testset "first/last n elements of $(typeof(itr))" for itr in (collect(1:9),
[1 4 7; 2 5 8; 3 6 9],
ntuple(identity, 9))
Expand All @@ -1309,6 +1322,59 @@ end
end
end

# Ensure dataids are inferrable for custom arrays
struct M0 <: AbstractArray{Int,2} end
struct M1{T} <: AbstractArray{Int,2}
x::T
end
struct M2{T,S} <: AbstractArray{Int,2}
x::T
y::S
end
struct M10{A,B,C,D,E,F,G,H,I,J} <: AbstractArray{Int,2}
a::A
b::B
c::C
d::D
e::E
f::F
g::G
h::H
i::I
j::J
end

@testset "dataids" begin
@test @inferred(Base.dataids(M0())) === ()
@test @inferred(Base.dataids(M1(1))) === ()
@test @inferred(Base.dataids(M1(1:10))) === ()
@test @inferred(Base.dataids(M10(1,2,3,4,5,6,7,8,9,0))) === ()

@test @inferred(Base.dataids(M1(M1([1])))) != Base.dataids(M1(M1([1])))
@test @inferred(Base.dataids(M1(M2([1],2)))) != Base.dataids(M1(M2([1],2)))
@test @inferred(Base.dataids(M1(M2([1],[2])))) != Base.dataids(M1(M2([1],[2])))
@test @inferred(Base.dataids(M10([1],[2],[3],[4],[5],[6],[7],[8],[9],[0]))) != Base.dataids(M10([1],[2],[3],[4],[5],[6],[7],[8],[9],[0]))

x = [1]
y = [1]
mx = M1(x)
mxx = M2(x,x)
mxy = M2(x,y)
@test @inferred(Base.mightalias(mx,mx))
@test @inferred(Base.mightalias(mx,mxx))
@test @inferred(Base.mightalias(mx,mxy))
@test @inferred(Base.mightalias(mxx,x))
@test @inferred(Base.mightalias(x,mxy))
@test !@inferred(Base.mightalias(mx, y))
@test !@inferred(Base.mightalias(mxx, y))
@test !@inferred(Base.mightalias(mxx, [1]))
@test !@inferred(Base.mightalias(mxy, 1:10))
@test !@inferred(Base.mightalias(mxy, M0()))
@test !@inferred(Base.mightalias(mxy, [1]))
@test !@inferred(Base.mightalias(mxy, M1(1:10)))
@test !@inferred(Base.mightalias(mxy, M1([1])))
end

@testset "Base.rest" begin
a = reshape(1:4, 2, 2)'
@test Base.rest(a) == a[:]
Expand Down