Skip to content

Commit

Permalink
Work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
jakobnissen committed Dec 11, 2024
1 parent 4b17278 commit 671c357
Showing 1 changed file with 76 additions and 20 deletions.
96 changes: 76 additions & 20 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,7 @@ function Base.getindex(v::MemoryView, i::Integer)
@inbounds ref[]
end

function Base.similar(
mem::MemoryView{T1, M},
::Type{T2},
dims::Tuple{Int},
) where {T1, T2, M}
function Base.similar(::MemoryView{T1, M}, ::Type{T2}, dims::Tuple{Int}) where {T1, T2, M}
len = Int(only(dims))::Int
memory = Memory{T2}(undef, len)
MemoryView{T2, M}(unsafe, memoryref(memory), len)
Expand Down Expand Up @@ -89,6 +85,23 @@ end
Base.getindex(v::MemoryView, ::Colon) = v
Base.view(v::MemoryView, idx::AbstractUnitRange) = v[idx]

function truncate(mem::MemoryView, include_last::Integer)
lst = Int(include_last)::Int
@boundscheck if (lst % UInt) > length(mem) % UInt
throw(BoundsError(mem, lst))
end
typeof(mem)(unsafe, mem.ref, lst)
end

function truncate_start_nonempty(mem::MemoryView, from::Integer)
frm = Int(from)::Int
@boundscheck if ((frm - 1) % UInt) length(mem) % UInt
throw(BoundsError(mem, frm))
end
newref = @inbounds memoryref(mem.ref, frm)
typeof(mem)(unsafe, newref, length(mem) - frm + 1)
end

function Base.unsafe_copyto!(dst::MutableMemoryView{T}, src::MemoryView{T}) where {T}
iszero(length(src)) && return dst
@inbounds unsafe_copyto!(dst.ref, src.ref, length(src))
Expand Down Expand Up @@ -116,16 +129,6 @@ function Base.findnext(p::Function, mem::MemoryView, start::Integer)
nothing
end

function Base.findprev(p::Function, mem::MemoryView, start::Integer)
i = Int(start)::Int
@boundscheck (i > length(mem) && throw(BoundsError(mem, i)))
@inbounds while i > 0
p(mem[i]) && return i
i -= 1
end
nothing
end

# The following two methods could be collapsed, but they aren't for two reasons:
# * To prevent ambiguity with Base
# * Because we DON'T want this code to run with MemoryView{Union{UInt8, Int8}}.
Expand All @@ -147,16 +150,17 @@ function Base.findnext(
_findnext(mem, p.x, start)
end

@inline function _findnext(
Base.@propagate_inbounds function _findnext(
mem::MemoryView{T},
byte::Union{T},
start::Integer,
) where {T <: Union{UInt8, Int8}}
start = Int(start)::Int
real_start = max(start, 1)
v = @inbounds ImmutableMemoryView(mem[real_start:end])
v_ind = @something memchr(v, byte) return nothing
v_ind + real_start - 1
@boundscheck(start < 1 && throw(BoundsError(mem, start)))
start > length(mem) && return nothing
im = @inbounds truncate_start_nonempty(ImmutableMemoryView(mem), start)
v_ind = @something memchr(im, byte) return nothing
v_ind + start - 1
end

function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
Expand All @@ -172,6 +176,58 @@ function memchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UI
p == C_NULL ? nothing : (p - ptr) % Int + 1
end

function Base.findprev(p::Function, mem::MemoryView, start::Integer)
i = Int(start)::Int
@boundscheck (i > length(mem) && throw(BoundsError(mem, i)))
@inbounds while i > 0
p(mem[i]) && return i
i -= 1
end
nothing
end

function Base.findprev(
p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, UInt8},
mem::MemoryView{UInt8},
start::Integer,
)
_findprev(mem, p.x, start)
end

function Base.findprev(
p::Base.Fix2{<:Union{typeof(==), typeof(isequal)}, Int8},
mem::MemoryView{Int8},
start::Integer,
)
_findprev(mem, p.x, start)
end

Base.@propagate_inbounds function _findprev(
mem::MemoryView{T},
byte::Union{T},
start::Integer,
) where {T <: Union{UInt8, Int8}}
start = Int(start)::Int
@boundscheck (start > length(mem) && throw(BoundsError(mem, start)))
start < 1 && return nothing
im = @inbounds truncate(ImmutableMemoryView(mem), start)
v_ind = @something memrchr(im, byte) return nothing
v_ind + start - 1
end

function memrchr(mem::ImmutableMemoryView{T}, byte::T) where {T <: Union{Int8, UInt8}}
isempty(mem) && return nothing
GC.@preserve mem begin
ptr = Ptr{UInt8}(pointer(mem))
p = @ccall memrchr(
ptr::Ptr{UInt8},
(byte % UInt8)::UInt8,
length(mem)::Int,
)::Ptr{Cvoid}
end
p == C_NULL ? nothing : (p - ptr) % Int + 1
end

const Bits =
Union{Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, UInt128, Char}

Expand Down

0 comments on commit 671c357

Please sign in to comment.