From d8aa3facf0fa9f0195924ee77c52bce2d2f0e846 Mon Sep 17 00:00:00 2001 From: martinvuyk Date: Tue, 10 Dec 2024 09:18:30 -0300 Subject: [PATCH] add parametrized vector width Signed-off-by: martinvuyk --- stdlib/src/memory/span.mojo | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/stdlib/src/memory/span.mojo b/stdlib/src/memory/span.mojo index 16abf73511..39b9ef798a 100644 --- a/stdlib/src/memory/span.mojo +++ b/stdlib/src/memory/span.mojo @@ -49,6 +49,7 @@ struct _SpanSIMDIter[ D: DType, span_origin: Origin[False], iter_origin: Origin[True], + width: Int, forward: Bool = True, ]: """SIMD Iterator for Span. @@ -57,6 +58,7 @@ struct _SpanSIMDIter[ D: The DType of the elements in the span. span_origin: The origin of the Span. iter_origin: The origin of the _SpanIter. + width: The width of the resulting vector. forward: The iteration direction. `False` is backwards. Notes: @@ -79,7 +81,6 @@ struct _SpanSIMDIter[ . """ - alias _width = simdwidthof[D]() var src: Pointer[ _SpanIter[Scalar[D], span_origin, forward=forward], iter_origin ] @@ -89,17 +90,17 @@ struct _SpanSIMDIter[ return self @always_inline - fn __next__(mut self) -> SIMD[D, Self._width]: + fn __next__(mut self) -> SIMD[D, width]: @parameter if forward: var i = self.src[].index - self.src[].index += Self._width - return (self.src[].src.unsafe_ptr() + i).load[width = Self._width]() + self.src[].index += width + return (self.src[].src.unsafe_ptr() + i).load[width=width]() else: - self.src[].index -= Self._width + self.src[].index -= width return ( (self.src[].src.unsafe_ptr() + self.src[].index) - .load[width = Self._width]() + .load[width=width]() .reversed() ) @@ -107,7 +108,7 @@ struct _SpanSIMDIter[ fn __has_next__(self) -> Bool: @parameter if forward: - return (len(self.src[].src) - self.src[].index) >= Self._width + return (len(self.src[].src) - self.src[].index) >= width else: return self.src[].index > 0 @@ -115,9 +116,9 @@ struct _SpanSIMDIter[ fn __len__(self) -> Int: @parameter if forward: - return (len(self.src[].src) - self.src[].index) // Self._width + return (len(self.src[].src) - self.src[].index) // width else: - return self.src[].index // Self._width + return self.src[].index // width @value @@ -154,12 +155,17 @@ struct _SpanIter[ return Pointer.address_of(self.src[self.index]) fn vectorized[ - D: DType, O: ImmutableOrigin, // + D: DType, O: ImmutableOrigin, //, width: Int = simdwidthof[D]() ](mut self: _SpanIter[Scalar[D], O, forward=True]) -> _SpanSIMDIter[ - D, O, __origin_of(self), forward=True + D, O, __origin_of(self), width, forward=True ]: """Return a vectorized Span iterator. + Parameters: + D: The DType of the elements in the span. + O: The origin of the Span. + width: The width of the resulting vector. + Notes: This iterator should be used on 2 loops. @@ -179,7 +185,7 @@ struct _SpanIter[ ``` . """ - return _SpanSIMDIter[D, O, __origin_of(self), forward=True]( + return _SpanSIMDIter[D, O, __origin_of(self), width, forward=True]( Pointer.address_of(self) )