Skip to content

Commit

Permalink
Use interface static methods to define tensor operators
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Dec 9, 2023
1 parent 7814756 commit 8902122
Show file tree
Hide file tree
Showing 24 changed files with 373 additions and 686 deletions.
64 changes: 0 additions & 64 deletions src/NetFabric.Numerics.Benchmarks/SpanVector2SumBenchmarks.cs

This file was deleted.

33 changes: 5 additions & 28 deletions src/NetFabric.Numerics.Tensors/Add.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,34 +8,14 @@ public static partial class Tensor
/// Adds a value to each element in the source span and stores the result in the destination span.
/// </summary>
/// <typeparam name="T">The type of the elements in the spans.</typeparam>
/// <param name="source">The source span.</param>
/// <param name="value">The value to add to each element.</param>
/// <param name="left">The source span.</param>
/// <param name="right">The value to add to each element.</param>
/// <param name="destination">The destination span to store the result.</param>
/// <exception cref="ArgumentException">Thrown when the source and destination spans have different lengths.</exception>
/// <exception cref="InvalidOperationException">Thrown when the type <typeparamref name="T"/> does not implement the <see cref="IAdditionOperators{T, T, T}"/> interface.</exception>
public static void Add<T>(ReadOnlySpan<T> source, T value, Span<T> destination)
public static void Add<T>(ReadOnlySpan<T> left, T right, Span<T> destination)
where T : struct, IAdditionOperators<T, T, T>
{
var add = new AddValueOperation<T>(value);
Apply(source, destination, ref add);
}

/// <summary>
/// Adds a value to a pair of elements in the source span and stores the result in the destination span.
/// </summary>
/// <typeparam name="T">The type of the elements in the spans.</typeparam>
/// <param name="source">The source span.</param>
/// <param name="value1">The value to be added to the first element of the pair.</param>
/// <param name="value2">The value to be added to the second element of the pair.</param>
/// <param name="destination">The destination span to store the result.</param>
/// <exception cref="ArgumentException">Thrown when the source and destination spans have different lengths.</exception>
/// <exception cref="InvalidOperationException">Thrown when the type <typeparamref name="T"/> does not implement the <see cref="IAdditionOperators{T, T, T}"/> interface.</exception>
public static void Add<T>(ReadOnlySpan<T> source, T value1, T value2, Span<T> destination)
where T : struct, IAdditionOperators<T, T, T>
{
var add = new AddValueOperation2D<T>(value1, value2);
Apply2D(source, destination, ref add);
}
=> Apply<T, AddOperator<T>>(left, right, destination);

/// <summary>
/// Adds corresponding elements in the left and right spans and stores the result in the destination span.
Expand All @@ -48,8 +28,5 @@ public static void Add<T>(ReadOnlySpan<T> source, T value1, T value2, Span<T> de
/// <exception cref="InvalidOperationException">Thrown when the type <typeparamref name="T"/> does not implement the <see cref="IAdditionOperators{T, T, T}"/> interface.</exception>
public static void Add<T>(ReadOnlySpan<T> left, ReadOnlySpan<T> right, Span<T> destination)
where T : struct, IAdditionOperators<T, T, T>
{
var add = new AddOperation<T>();
Apply(left, right, destination, ref add);
}
=> Apply<T, AddOperator<T>>(left, right, destination);
}
46 changes: 46 additions & 0 deletions src/NetFabric.Numerics.Tensors/Aggregate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Numerics
{
public static partial class Tensor
{
/// <summary>
/// Aggregates the elements in the specified <see cref="ReadOnlySpan{T}"/> using the specified <see cref="IAggregationOperator{T}"/>.
/// </summary>
/// <typeparam name="T">The type of the elements in the span.</typeparam>
/// <typeparam name="TOperator">The type of the aggregation operator.</typeparam>
/// <param name="source">The source span.</param>
/// <returns>The aggregated value.</returns>
public static T Aggregate<T, TOperator>(ReadOnlySpan<T> source)
where T : struct
where TOperator : struct, IAggregationOperator<T>
{
var result = TOperator.Seed;
var resultVector = new Vector<T>(TOperator.Seed);
nint index = 0;

if (Vector.IsHardwareAccelerated &&
Vector<T>.IsSupported &&
source.Length >= Vector<T>.Count)
{
var sourceVectors = MemoryMarshal.Cast<T, Vector<T>>(source);

ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors);
for (nint indexVector = 0; indexVector < sourceVectors.Length; indexVector++)
resultVector = TOperator.Invoke(resultVector, Unsafe.Add(ref sourceVectorsRef, indexVector));

index = source.Length - source.Length % Vector<T>.Count;
}

ref var sourceRef = ref MemoryMarshal.GetReference(source);
for (; index < source.Length; index++)
result = TOperator.Invoke(result, Unsafe.Add(ref sourceRef, index));

return TOperator.ResultSelector(result, resultVector);
}
}
}
Loading

0 comments on commit 8902122

Please sign in to comment.