Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Optimize _StringSliceIter to not have branching in forward iteration #3546

Closed
wants to merge 21 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 26 additions & 39 deletions stdlib/src/utils/string_slice.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,9 @@ fn _utf8_first_byte_sequence_length(b: Byte) -> Int:

debug_assert(
(b & 0b1100_0000) != 0b1000_0000,
(
"Function `_utf8_first_byte_sequence_length()` does not work"
" correctly if given a continuation byte."
),
"Function does not work correctly if given a continuation byte.",
)
var flipped = ~b
return int(count_leading_zeros(flipped) + (flipped >> 7))
return int(count_leading_zeros(~b)) + int(b < 0b1000_0000)


fn _shift_unicode_to_utf8(ptr: UnsafePointer[UInt8], c: Int, num_bytes: Int):
Expand Down Expand Up @@ -182,62 +178,53 @@ struct _StringSliceIter[
forward: The iteration direction. `False` is backwards.
"""

alias _S = StringSlice[origin]
alias _U = UnsafePointer[Byte]
var index: Int
var continuation_bytes: Int
var ptr: UnsafePointer[UInt8]
var ptr: Self._U
var length: Int

fn __init__(
inout self, *, unsafe_pointer: UnsafePointer[UInt8], length: Int
):
fn __init__(inout self, *, unsafe_pointer: Self._U, length: Int):
self.index = 0 if forward else length
self.ptr = unsafe_pointer
self.length = length
alias S = Span[Byte, StaticConstantOrigin]
var s = S(ptr=self.ptr, length=self.length)
self.continuation_bytes = _count_utf8_continuation_bytes(s)

fn __iter__(self) -> Self:
return self

fn __next__(inout self) -> StringSlice[origin]:
fn __next__(inout self) -> Self._S:
@parameter
if forward:
var byte_len = 1
if self.continuation_bytes > 0:
var byte_type = _utf8_byte_type(self.ptr[self.index])
if byte_type != 0:
byte_len = int(byte_type)
self.continuation_bytes -= byte_len - 1
byte_len = _utf8_first_byte_sequence_length(self.ptr[self.index])
i = self.index
self.index += byte_len
return StringSlice[origin](
ptr=self.ptr + (self.index - byte_len), length=byte_len
)
return Self._S(ptr=self.ptr + i, length=byte_len)
else:
var byte_len = 1
if self.continuation_bytes > 0:
var byte_type = _utf8_byte_type(self.ptr[self.index - 1])
if byte_type != 0:
while byte_type == 1:
byte_len += 1
var b = self.ptr[self.index - byte_len]
byte_type = _utf8_byte_type(b)
self.continuation_bytes -= byte_len - 1
byte_len = 1
while _utf8_byte_type(self.ptr[self.index - byte_len]) == 1:
byte_len += 1
self.index -= byte_len
return StringSlice[origin](
ptr=self.ptr + self.index, length=byte_len
)
return Self._S(ptr=self.ptr + self.index, length=byte_len)

@always_inline
fn __has_next__(self) -> Bool:
return self.__len__() > 0
@parameter
if forward:
return self.index < self.length
else:
return self.index > 0

fn __len__(self) -> Int:
alias S = Span[Byte, ImmutableAnyOrigin]
alias _count = _count_utf8_continuation_bytes

@parameter
if forward:
return self.length - self.index - self.continuation_bytes
remaining = self.length - self.index
cont = _count(S(ptr=self.ptr + self.index, length=remaining))
return remaining - cont
else:
return self.index - self.continuation_bytes
return self.index - _count(S(ptr=self.ptr, length=self.index))


@value
Expand Down
30 changes: 17 additions & 13 deletions stdlib/test/collections/test_string.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -1272,19 +1272,23 @@ def test_string_iter():

var idx = -1
vs = String("mojo🔥")
for item in vs:
idx += 1
if idx == 0:
assert_equal("m", item)
elif idx == 1:
assert_equal("o", item)
elif idx == 2:
assert_equal("j", item)
elif idx == 3:
assert_equal("o", item)
elif idx == 4:
assert_equal("🔥", item)
assert_equal(4, idx)
var iterator = vs.__iter__()
assert_equal(5, len(iterator))
var item = iterator.__next__()
assert_equal("m", item)
assert_equal(4, len(iterator))
item = iterator.__next__()
assert_equal("o", item)
assert_equal(3, len(iterator))
item = iterator.__next__()
assert_equal("j", item)
assert_equal(2, len(iterator))
item = iterator.__next__()
assert_equal("o", item)
assert_equal(1, len(iterator))
item = iterator.__next__()
assert_equal("🔥", item)
assert_equal(0, len(iterator))

var items = List[String](
"mojo🔥",
Expand Down