Skip to content

Commit

Permalink
fix strides for AbstractRanges, fixes #18
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jun 15, 2022
1 parent 7465b68 commit ac943fa
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StrideArraysCore"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.3.9"
version = "0.3.10"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
12 changes: 3 additions & 9 deletions src/StrideArraysCore.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
module StrideArraysCore

using LayoutPointers, ArrayInterface, ThreadingUtilities, ManualMemory, IfElse, Static
using Static:
StaticInt,
StaticBool,
True,
False,
Zero,
One

using Static: StaticInt, StaticBool, True, False, Zero, One

using ArrayInterface:
OptionallyStaticUnitRange,
size,
Expand Down Expand Up @@ -53,7 +47,7 @@ function __init__()
if Base.JLOptions().check_bounds == 1
@eval boundscheck() = true
end
# # @require LoopVectorization="bdcacae8-1622-11e9-2a5c-532679323890" @eval using StrideArrays
# # @require LoopVectorization="bdcacae8-1622-11e9-2a5c-532679323890" @eval using StrideArrays
end

end
24 changes: 12 additions & 12 deletions src/ptr_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ ptrarray0(ptr::Ptr, s::Tuple{Vararg{Integer}}, ::StaticInt{1}) = ptrarray0(ptr,
perm = Expr(:tuple)
resize!(perm.args, N)
perm.args[C] = 1
for n in 1:N
for n = 1:N
if n != C
push!(d.args, Expr(:call, getfield, :s, n))
perm.args[n] = n + (n < C)
Expand Down Expand Up @@ -223,9 +223,9 @@ end
@inline function Base.stride(A::AbstractStrideArray, i::Int)
x = Base.strides(A)
@assert i > 0
i <= length(x) ? @inbounds(x[i]) : x[end]*Base.size(A)[end]
i <= length(x) ? @inbounds(x[i]) : x[end] * Base.size(A)[end]
end
@generated _oneto(x) = Expr(:new, Base.OneTo{Int}, :(x%Int))
@generated _oneto(x) = Expr(:new, Base.OneTo{Int}, :(x % Int))

@inline create_axis(s, ::Zero) = CloseOpen(s)
@inline create_axis(s, ::One) = _oneto(unsigned(s))
Expand Down Expand Up @@ -322,8 +322,7 @@ end
@inline ArrayInterface.size(A::AbstractStrideArray, ::StaticInt{N}) where {N} = size(A)[N]
@inline ArrayInterface.size(A::AbstractStrideArray, i::Integer) =
type_stable_select(size(A), i)
@inline ArrayInterface.size(A::AbstractStrideArray, i::Int) =
type_stable_select(size(A), i)
@inline ArrayInterface.size(A::AbstractStrideArray, i::Int) = type_stable_select(size(A), i)
@inline Base.size(A::AbstractStrideArray, i::Integer) = size(A, i)


Expand Down Expand Up @@ -660,11 +659,13 @@ end
@assert 1 C N
else#if C < 1
known_offsets = known(O)
first_offset = if all(Base.Fix2(isa,Int), known_offsets) && all(==(first(known_offsets)), known_offsets)
first(known_offsets)
else
1
end
first_offset =
if all(Base.Fix2(isa, Int), known_offsets) &&
all(==(first(known_offsets)), known_offsets)
first(known_offsets)
else
1
end
push!(offs_expr.args, static(first_offset))
push!(Dnew.args, true)
push!(Rnew.args, 1)
Expand Down Expand Up @@ -724,5 +725,4 @@ end
@inline Base.reinterpret(::typeof(reshape), ::Type{T}, A::AbstractStrideArray) where {T} =
StrideArray(reinterpret(reshape, T, PtrArray(A)), preserve_buffer(A))

Base.LinearIndices(x::AbstractStrideVector) = axes(x,static(1))

Base.LinearIndices(x::AbstractStrideVector) = axes(x, static(1))
5 changes: 4 additions & 1 deletion src/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ function view_quote(i, K, S, D, N, C, B, R, zero_offsets::Bool = false)
if i[k] <: AbstractRange
Nnew += 1
push!(s.args, Expr(:call, :static_length, iₖ))
push!(x.args, Expr(:ref, :x, k))
push!(
x.args,
Expr(:call, :(*), Expr(:call, ArrayInterface.static_step, iₖ), Expr(:ref, :x, k)),
)
push!(o.args, zero_offsets ? :(Zero()) : :(One()))
if k == C
Cnew = Nnew
Expand Down
41 changes: 30 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ allocated_cartesianindexsum(x) = @allocated cartesianindexsum(x)

@testset "StrideArrays Basic" begin
@test (Base.JLOptions().check_bounds == 1) == StrideArraysCore.boundscheck()

Acomplex = StrideArray{Complex{Float64}}(undef, (StaticInt(4), StaticInt(5)))
@test @inferred(StrideArraysCore.ArrayInterface.known_size(Acomplex)) === (4, 5)
Acomplex .= rand.(Complex{Float64})
Expand Down Expand Up @@ -86,17 +86,32 @@ allocated_cartesianindexsum(x) = @allocated cartesianindexsum(x)

D = copy(A)
Cslice = view(C, 23:48, 17:89)
@test Base.stride(Cslice,1) == Base.stride(C,1) == StrideArraysCore.stride(Cslice,1) == StrideArraysCore.stride(Cslice,static(1)) == StrideArraysCore.stride(C,1) == StrideArraysCore.stride(C,static(1))
@test Base.stride(Cslice,2) == Base.stride(C,2) == StrideArraysCore.stride(Cslice,2) == StrideArraysCore.stride(Cslice,static(2)) == StrideArraysCore.stride(C,2) == StrideArraysCore.stride(C,static(2))
@test Base.stride(Cslice, 1) ==
Base.stride(C, 1) ==
StrideArraysCore.stride(Cslice, 1) ==
StrideArraysCore.stride(Cslice, static(1)) ==
StrideArraysCore.stride(C, 1) ==
StrideArraysCore.stride(C, static(1))
@test Base.stride(Cslice, 2) ==
Base.stride(C, 2) ==
StrideArraysCore.stride(Cslice, 2) ==
StrideArraysCore.stride(Cslice, static(2)) ==
StrideArraysCore.stride(C, 2) ==
StrideArraysCore.stride(C, static(2))
if VERSION >= v"1.9.0-DEV.569"
@test Base.stride(C,3) == StrideArraysCore.stride(C,3) == StrideArraysCore.stride(C,static(3))
@test Base.stride(C, 3) ==
StrideArraysCore.stride(C, 3) ==
StrideArraysCore.stride(C, static(3))

end
@test Base.stride(C,3) == StrideArraysCore.stride(C,3)
@test Base.stride(Cslice,3) == StrideArraysCore.stride(Cslice,3)
@test_broken Base.stride(Cslice,3) == StrideArraysCore.stride(Cslice,static(3))
@test Base.strides(Cslice) == Base.strides(C) == StrideArraysCore.strides(Cslice) == StrideArraysCore.strides(C)

@test Base.stride(C, 3) == StrideArraysCore.stride(C, 3)
@test Base.stride(Cslice, 3) == StrideArraysCore.stride(Cslice, 3)
@test_broken Base.stride(Cslice, 3) == StrideArraysCore.stride(Cslice, static(3))
@test Base.strides(Cslice) ==
Base.strides(C) ==
StrideArraysCore.strides(Cslice) ==
StrideArraysCore.strides(C)

Cslice .= 2
@test D != C
D[23:48, 17:89] .= 2
Expand Down Expand Up @@ -261,6 +276,10 @@ allocated_cartesianindexsum(x) = @allocated cartesianindexsum(x)
@test all(isone, StrideArray(one, 100, 200))
end
@testset "views" begin
B0 = reshape(collect(1:12), 3, 4)
B1 = StrideArray(A)
@test view(B0, :, 4:-1:1) == view(B1, :, 4:-1:1)
@test view(B0, :, 1:2:4) == view(B1, :, 1:2:4)
A = StrideArray{Float64}(undef, (100, 100)) .= rand.()
vA = view(A, 3:40, 2:50)
vAslice = view(A, :, 2:50)
Expand All @@ -286,11 +305,11 @@ Bool[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]"""
Bool[0 0 0 0 0 1 1 1 1 1]"""
end
@testset "ptrarray0" begin
x = collect(0:3);
x = collect(0:3)
pzx = StrideArraysCore.ptrarray0(pointer(x), (4,))
GC.@preserve x begin
for i = 0:3
@test pzx[i] == pzx[i,1] == i
@test pzx[i] == pzx[i, 1] == i
end
end
end
Expand Down

2 comments on commit ac943fa

@chriselrod
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62378

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.10 -m "<description of version>" ac943fa1a603804de7e442c957b37ba321ee6c18
git push origin v0.3.10

Please sign in to comment.