-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support sum of 2D vectors using tensor
- Loading branch information
Showing
8 changed files
with
510 additions
and
382 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
|
||
using System; | ||
using System.Numerics; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.InteropServices; | ||
|
||
namespace NetFabric.Numerics | ||
{ | ||
public static partial class Tensor | ||
{ | ||
/// <summary> | ||
/// Aggregates the elements in the specified <see cref="ReadOnlySpan{T}"/> using the specified <see cref="IAggregationOperator{T}"/>. | ||
/// </summary> | ||
/// <typeparam name="T">The type of the elements in the span.</typeparam> | ||
/// <typeparam name="TOperator">The type of the aggregation operator.</typeparam> | ||
/// <param name="source">The source span.</param> | ||
/// <returns>The aggregated value.</returns> | ||
public static ValueTuple<T, T> AggregatePair<T, TOperator>(ReadOnlySpan<T> source) | ||
where T : struct | ||
where TOperator : struct, IPairAggregationOperator<T> | ||
{ | ||
var result = TOperator.Seed; | ||
var resultVector = Vector<T>.Zero; | ||
|
||
nint index = 0; | ||
|
||
if (Vector.IsHardwareAccelerated && | ||
Vector<T>.IsSupported && | ||
Vector<T>.Count % 2 == 0 && | ||
source.Length >= Vector<T>.Count) | ||
{ | ||
var resutArray = new T[Vector<T>.Count]; | ||
ref var resultRef = ref MemoryMarshal.GetReference<T>(resutArray); | ||
for (nint indexVector = 0; indexVector < resutArray.Length; indexVector += 2) | ||
{ | ||
Unsafe.Add(ref resultRef, indexVector) = TOperator.Seed.Item1; | ||
Unsafe.Add(ref resultRef, indexVector + 1) = TOperator.Seed.Item2; | ||
} | ||
resultVector = new Vector<T>(resutArray); | ||
|
||
var sourceVectors = MemoryMarshal.Cast<T, Vector<T>>(source); | ||
|
||
ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors); | ||
for (nint indexVector = 0; indexVector < sourceVectors.Length; indexVector++) | ||
resultVector = TOperator.Invoke(resultVector, Unsafe.Add(ref sourceVectorsRef, indexVector)); | ||
|
||
index = source.Length - source.Length % Vector<T>.Count; | ||
} | ||
|
||
ref var sourceRef = ref MemoryMarshal.GetReference(source); | ||
for (; index < source.Length; index += 2) | ||
{ | ||
result.Item1 = TOperator.Invoke(result.Item1, Unsafe.Add(ref sourceRef, index)); | ||
result.Item2 = TOperator.Invoke(result.Item2, Unsafe.Add(ref sourceRef, index + 1)); | ||
} | ||
|
||
return TOperator.ResultSelector(result, resultVector); | ||
} | ||
} | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
using System; | ||
using System.Numerics; | ||
using System.Runtime.CompilerServices; | ||
using System.Runtime.InteropServices; | ||
|
||
namespace NetFabric.Numerics; | ||
|
||
public static partial class Tensor | ||
{ | ||
public static void Apply<T, TOperator>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination, bool useIntrinsics = true) | ||
where T : struct | ||
where TOperator : struct, IBinaryOperator<T> | ||
{ | ||
if(x.Length != y.Length) | ||
Throw.ArgumentException(nameof(y), "x and y spans must have the same length."); | ||
if (x.Length > destination.Length) | ||
Throw.ArgumentException(nameof(destination), "Destination span is too small."); | ||
if(SpansOverlapAndAreNotSame(x, destination)) | ||
Throw.ArgumentException(nameof(destination), "Destination span overlaps with x."); | ||
if(SpansOverlapAndAreNotSame(y, destination)) | ||
Throw.ArgumentException(nameof(destination), "Destination span overlaps with y."); | ||
|
||
// Initialize the index to 0. | ||
nint index = 0; | ||
|
||
// Check if hardware acceleration and Vector<T> support are available, | ||
// and if the length of the x is greater than the Vector<T>.Count. | ||
if (useIntrinsics && | ||
Vector.IsHardwareAccelerated && | ||
Vector<T>.IsSupported && | ||
x.Length >= Vector<T>.Count) | ||
{ | ||
// Cast the spans to vectors for hardware acceleration. | ||
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(x); | ||
var yVectors = MemoryMarshal.Cast<T, Vector<T>>(y); | ||
var destinationVectors = MemoryMarshal.Cast<T, Vector<T>>(destination); | ||
|
||
// Iterate through the vectors. | ||
ref var xVectorsRef = ref MemoryMarshal.GetReference(xVectors); | ||
ref var yVectorsRef = ref MemoryMarshal.GetReference(yVectors); | ||
ref var destinationVectorsRef = ref MemoryMarshal.GetReference(destinationVectors); | ||
for (nint indexVector = 0; indexVector < xVectors.Length; indexVector++) | ||
{ | ||
Unsafe.Add(ref destinationVectorsRef, indexVector) = TOperator.Invoke( | ||
Unsafe.Add(ref xVectorsRef, indexVector), | ||
Unsafe.Add(ref yVectorsRef, indexVector)); | ||
} | ||
|
||
// Update the index to the end of the last complete vector. | ||
index = x.Length - x.Length % Vector<T>.Count; | ||
} | ||
|
||
// Iterate through the remaining elements. | ||
ref var xRef = ref MemoryMarshal.GetReference(x); | ||
ref var yRef = ref MemoryMarshal.GetReference(y); | ||
ref var destinationRef = ref MemoryMarshal.GetReference(destination); | ||
for (; index < x.Length; index++) | ||
{ | ||
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke( | ||
Unsafe.Add(ref xRef, index), | ||
Unsafe.Add(ref yRef, index)); | ||
} | ||
} | ||
|
||
public static void Apply<T, TOperator>(ReadOnlySpan<T> x, T y, Span<T> destination, bool useIntrinsics = true) | ||
where T : struct | ||
where TOperator : struct, IBinaryOperator<T> | ||
{ | ||
if (x.Length > destination.Length) | ||
Throw.ArgumentException(nameof(destination), "Destination span is too small."); | ||
if(SpansOverlapAndAreNotSame(x, destination)) | ||
Throw.ArgumentException(nameof(destination), "Destination span overlaps with x."); | ||
|
||
// Initialize the index to 0. | ||
nint index = 0; | ||
|
||
// Check if hardware acceleration and Vector<T> support are available, | ||
// and if the length of the x is greater than the Vector<T>.Count. | ||
if (useIntrinsics && | ||
Vector.IsHardwareAccelerated && | ||
Vector<T>.IsSupported && | ||
x.Length >= Vector<T>.Count) | ||
{ | ||
// Cast the spans to vectors for hardware acceleration. | ||
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(x); | ||
var valueVector = new Vector<T>(y); | ||
var destinationVectors = MemoryMarshal.Cast<T, Vector<T>>(destination); | ||
|
||
// Iterate through the vectors. | ||
ref var xVectorsRef = ref MemoryMarshal.GetReference(xVectors); | ||
ref var destinationVectorsRef = ref MemoryMarshal.GetReference(destinationVectors); | ||
for (nint indexVector = 0; indexVector < xVectors.Length; indexVector++) | ||
{ | ||
Unsafe.Add(ref destinationVectorsRef, indexVector) = TOperator.Invoke( | ||
Unsafe.Add(ref xVectorsRef, indexVector), | ||
valueVector); | ||
} | ||
|
||
// Update the index to the end of the last complete vector. | ||
index = x.Length - x.Length % Vector<T>.Count; | ||
} | ||
|
||
// Iterate through the remaining elements. | ||
ref var xRef = ref MemoryMarshal.GetReference(x); | ||
ref var destinationRef = ref MemoryMarshal.GetReference(destination); | ||
for (; index < x.Length; index++) | ||
{ | ||
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke( | ||
Unsafe.Add(ref xRef, index), | ||
y); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.