Skip to content

Commit

Permalink
Support sum of 2D vectors using tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Dec 31, 2023
1 parent e39e44d commit f951128
Show file tree
Hide file tree
Showing 8 changed files with 510 additions and 382 deletions.
60 changes: 60 additions & 0 deletions src/NetFabric.Numerics.Tensors/AggregatePair.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Numerics
{
public static partial class Tensor
{
/// <summary>
/// Aggregates the elements in the specified <see cref="ReadOnlySpan{T}"/> using the specified <see cref="IAggregationOperator{T}"/>.
/// </summary>
/// <typeparam name="T">The type of the elements in the span.</typeparam>
/// <typeparam name="TOperator">The type of the aggregation operator.</typeparam>
/// <param name="source">The source span.</param>
/// <returns>The aggregated value.</returns>
public static ValueTuple<T, T> AggregatePair<T, TOperator>(ReadOnlySpan<T> source)
where T : struct
where TOperator : struct, IPairAggregationOperator<T>
{
var result = TOperator.Seed;
var resultVector = Vector<T>.Zero;

nint index = 0;

if (Vector.IsHardwareAccelerated &&
Vector<T>.IsSupported &&
Vector<T>.Count % 2 == 0 &&
source.Length >= Vector<T>.Count)
{
var resutArray = new T[Vector<T>.Count];
ref var resultRef = ref MemoryMarshal.GetReference<T>(resutArray);
for (nint indexVector = 0; indexVector < resutArray.Length; indexVector += 2)
{
Unsafe.Add(ref resultRef, indexVector) = TOperator.Seed.Item1;
Unsafe.Add(ref resultRef, indexVector + 1) = TOperator.Seed.Item2;
}
resultVector = new Vector<T>(resutArray);

var sourceVectors = MemoryMarshal.Cast<T, Vector<T>>(source);

ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors);
for (nint indexVector = 0; indexVector < sourceVectors.Length; indexVector++)
resultVector = TOperator.Invoke(resultVector, Unsafe.Add(ref sourceVectorsRef, indexVector));

index = source.Length - source.Length % Vector<T>.Count;
}

ref var sourceRef = ref MemoryMarshal.GetReference(source);
for (; index < source.Length; index += 2)
{
result.Item1 = TOperator.Invoke(result.Item1, Unsafe.Add(ref sourceRef, index));
result.Item2 = TOperator.Invoke(result.Item2, Unsafe.Add(ref sourceRef, index + 1));
}

return TOperator.ResultSelector(result, resultVector);
}
}
}
381 changes: 0 additions & 381 deletions src/NetFabric.Numerics.Tensors/Apply.cs

Large diffs are not rendered by default.

114 changes: 114 additions & 0 deletions src/NetFabric.Numerics.Tensors/ApplyBinary.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Numerics;

public static partial class Tensor
{
public static void Apply<T, TOperator>(ReadOnlySpan<T> x, ReadOnlySpan<T> y, Span<T> destination, bool useIntrinsics = true)

Check warning on line 10 in src/NetFabric.Numerics.Tensors/ApplyBinary.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'Tensor.Apply<T, TOperator>(ReadOnlySpan<T>, ReadOnlySpan<T>, Span<T>, bool)'
where T : struct
where TOperator : struct, IBinaryOperator<T>
{
if(x.Length != y.Length)
Throw.ArgumentException(nameof(y), "x and y spans must have the same length.");
if (x.Length > destination.Length)
Throw.ArgumentException(nameof(destination), "Destination span is too small.");
if(SpansOverlapAndAreNotSame(x, destination))
Throw.ArgumentException(nameof(destination), "Destination span overlaps with x.");
if(SpansOverlapAndAreNotSame(y, destination))
Throw.ArgumentException(nameof(destination), "Destination span overlaps with y.");

// Initialize the index to 0.
nint index = 0;

// Check if hardware acceleration and Vector<T> support are available,
// and if the length of the x is greater than the Vector<T>.Count.
if (useIntrinsics &&
Vector.IsHardwareAccelerated &&
Vector<T>.IsSupported &&
x.Length >= Vector<T>.Count)
{
// Cast the spans to vectors for hardware acceleration.
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(x);
var yVectors = MemoryMarshal.Cast<T, Vector<T>>(y);
var destinationVectors = MemoryMarshal.Cast<T, Vector<T>>(destination);

// Iterate through the vectors.
ref var xVectorsRef = ref MemoryMarshal.GetReference(xVectors);
ref var yVectorsRef = ref MemoryMarshal.GetReference(yVectors);
ref var destinationVectorsRef = ref MemoryMarshal.GetReference(destinationVectors);
for (nint indexVector = 0; indexVector < xVectors.Length; indexVector++)
{
Unsafe.Add(ref destinationVectorsRef, indexVector) = TOperator.Invoke(
Unsafe.Add(ref xVectorsRef, indexVector),
Unsafe.Add(ref yVectorsRef, indexVector));
}

// Update the index to the end of the last complete vector.
index = x.Length - x.Length % Vector<T>.Count;
}

// Iterate through the remaining elements.
ref var xRef = ref MemoryMarshal.GetReference(x);
ref var yRef = ref MemoryMarshal.GetReference(y);
ref var destinationRef = ref MemoryMarshal.GetReference(destination);
for (; index < x.Length; index++)
{
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(
Unsafe.Add(ref xRef, index),
Unsafe.Add(ref yRef, index));
}
}

public static void Apply<T, TOperator>(ReadOnlySpan<T> x, T y, Span<T> destination, bool useIntrinsics = true)

Check warning on line 65 in src/NetFabric.Numerics.Tensors/ApplyBinary.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'Tensor.Apply<T, TOperator>(ReadOnlySpan<T>, T, Span<T>, bool)'
where T : struct
where TOperator : struct, IBinaryOperator<T>
{
if (x.Length > destination.Length)
Throw.ArgumentException(nameof(destination), "Destination span is too small.");
if(SpansOverlapAndAreNotSame(x, destination))
Throw.ArgumentException(nameof(destination), "Destination span overlaps with x.");

// Initialize the index to 0.
nint index = 0;

// Check if hardware acceleration and Vector<T> support are available,
// and if the length of the x is greater than the Vector<T>.Count.
if (useIntrinsics &&
Vector.IsHardwareAccelerated &&
Vector<T>.IsSupported &&
x.Length >= Vector<T>.Count)
{
// Cast the spans to vectors for hardware acceleration.
var xVectors = MemoryMarshal.Cast<T, Vector<T>>(x);
var valueVector = new Vector<T>(y);
var destinationVectors = MemoryMarshal.Cast<T, Vector<T>>(destination);

// Iterate through the vectors.
ref var xVectorsRef = ref MemoryMarshal.GetReference(xVectors);
ref var destinationVectorsRef = ref MemoryMarshal.GetReference(destinationVectors);
for (nint indexVector = 0; indexVector < xVectors.Length; indexVector++)
{
Unsafe.Add(ref destinationVectorsRef, indexVector) = TOperator.Invoke(
Unsafe.Add(ref xVectorsRef, indexVector),
valueVector);
}

// Update the index to the end of the last complete vector.
index = x.Length - x.Length % Vector<T>.Count;
}

// Iterate through the remaining elements.
ref var xRef = ref MemoryMarshal.GetReference(x);
ref var destinationRef = ref MemoryMarshal.GetReference(destination);
for (; index < x.Length; index++)
{
Unsafe.Add(ref destinationRef, index) = TOperator.Invoke(
Unsafe.Add(ref xRef, index),
y);
}
}

}
Loading

0 comments on commit f951128

Please sign in to comment.