Skip to content

Commit

Permalink
add parametrized vector width
Browse files Browse the repository at this point in the history
Signed-off-by: martinvuyk <[email protected]>
  • Loading branch information
martinvuyk committed Dec 10, 2024
1 parent 6ffd02c commit d8aa3fa
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions stdlib/src/memory/span.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct _SpanSIMDIter[
D: DType,
span_origin: Origin[False],
iter_origin: Origin[True],
width: Int,
forward: Bool = True,
]:
"""SIMD Iterator for Span.
Expand All @@ -57,6 +58,7 @@ struct _SpanSIMDIter[
D: The DType of the elements in the span.
span_origin: The origin of the Span.
iter_origin: The origin of the _SpanIter.
width: The width of the resulting vector.
forward: The iteration direction. `False` is backwards.
Notes:
Expand All @@ -79,7 +81,6 @@ struct _SpanSIMDIter[
.
"""

alias _width = simdwidthof[D]()
var src: Pointer[
_SpanIter[Scalar[D], span_origin, forward=forward], iter_origin
]
Expand All @@ -89,35 +90,35 @@ struct _SpanSIMDIter[
return self

@always_inline
fn __next__(mut self) -> SIMD[D, Self._width]:
fn __next__(mut self) -> SIMD[D, width]:
@parameter
if forward:
var i = self.src[].index
self.src[].index += Self._width
return (self.src[].src.unsafe_ptr() + i).load[width = Self._width]()
self.src[].index += width
return (self.src[].src.unsafe_ptr() + i).load[width=width]()
else:
self.src[].index -= Self._width
self.src[].index -= width
return (
(self.src[].src.unsafe_ptr() + self.src[].index)
.load[width = Self._width]()
.load[width=width]()
.reversed()
)

@always_inline
fn __has_next__(self) -> Bool:
@parameter
if forward:
return (len(self.src[].src) - self.src[].index) >= Self._width
return (len(self.src[].src) - self.src[].index) >= width
else:
return self.src[].index > 0

@always_inline
fn __len__(self) -> Int:
@parameter
if forward:
return (len(self.src[].src) - self.src[].index) // Self._width
return (len(self.src[].src) - self.src[].index) // width
else:
return self.src[].index // Self._width
return self.src[].index // width


@value
Expand Down Expand Up @@ -154,12 +155,17 @@ struct _SpanIter[
return Pointer.address_of(self.src[self.index])

fn vectorized[
D: DType, O: ImmutableOrigin, //
D: DType, O: ImmutableOrigin, //, width: Int = simdwidthof[D]()
](mut self: _SpanIter[Scalar[D], O, forward=True]) -> _SpanSIMDIter[
D, O, __origin_of(self), forward=True
D, O, __origin_of(self), width, forward=True
]:
"""Return a vectorized Span iterator.
Parameters:
D: The DType of the elements in the span.
O: The origin of the Span.
width: The width of the resulting vector.
Notes:
This iterator should be used on 2 loops.
Expand All @@ -179,7 +185,7 @@ struct _SpanIter[
```
.
"""
return _SpanSIMDIter[D, O, __origin_of(self), forward=True](
return _SpanSIMDIter[D, O, __origin_of(self), width, forward=True](
Pointer.address_of(self)
)

Expand Down

0 comments on commit d8aa3fa

Please sign in to comment.