Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use interface static methods to define tensor operators #36

Merged
merged 2 commits into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 0 additions & 64 deletions src/NetFabric.Numerics.Benchmarks/SpanVector2SumBenchmarks.cs

This file was deleted.

33 changes: 5 additions & 28 deletions src/NetFabric.Numerics.Tensors/Add.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,20 @@

namespace NetFabric.Numerics;

public static partial class Tensor

Check warning on line 5 in src/NetFabric.Numerics.Tensors/Add.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'Tensor'

Check warning on line 5 in src/NetFabric.Numerics.Tensors/Add.cs

View workflow job for this annotation

GitHub Actions / build

Missing XML comment for publicly visible type or member 'Tensor'
{
/// <summary>
/// Adds a value to each element in the source span and stores the result in the destination span.
/// </summary>
/// <typeparam name="T">The type of the elements in the spans.</typeparam>
/// <param name="source">The source span.</param>
/// <param name="value">The value to add to each element.</param>
/// <param name="left">The source span.</param>
/// <param name="right">The value to add to each element.</param>
/// <param name="destination">The destination span to store the result.</param>
/// <exception cref="ArgumentException">Thrown when the source and destination spans have different lengths.</exception>
/// <exception cref="InvalidOperationException">Thrown when the type <typeparamref name="T"/> does not implement the <see cref="IAdditionOperators{T, T, T}"/> interface.</exception>
public static void Add<T>(ReadOnlySpan<T> source, T value, Span<T> destination)
public static void Add<T>(ReadOnlySpan<T> left, T right, Span<T> destination)
where T : struct, IAdditionOperators<T, T, T>
{
var add = new AddValueOperation<T>(value);
Apply(source, destination, ref add);
}

/// <summary>
/// Adds a value to a pair of elements in the source span and stores the result in the destination span.
/// </summary>
/// <typeparam name="T">The type of the elements in the spans.</typeparam>
/// <param name="source">The source span.</param>
/// <param name="value1">The value to be added to the first element of the pair.</param>
/// <param name="value2">The value to be added to the second element of the pair.</param>
/// <param name="destination">The destination span to store the result.</param>
/// <exception cref="ArgumentException">Thrown when the source and destination spans have different lengths.</exception>
/// <exception cref="InvalidOperationException">Thrown when the type <typeparamref name="T"/> does not implement the <see cref="IAdditionOperators{T, T, T}"/> interface.</exception>
public static void Add<T>(ReadOnlySpan<T> source, T value1, T value2, Span<T> destination)
where T : struct, IAdditionOperators<T, T, T>
{
var add = new AddValueOperation2D<T>(value1, value2);
Apply2D(source, destination, ref add);
}
=> Apply<T, AddOperator<T>>(left, right, destination);

/// <summary>
/// Adds corresponding elements in the left and right spans and stores the result in the destination span.
Expand All @@ -48,8 +28,5 @@
/// <exception cref="InvalidOperationException">Thrown when the type <typeparamref name="T"/> does not implement the <see cref="IAdditionOperators{T, T, T}"/> interface.</exception>
public static void Add<T>(ReadOnlySpan<T> left, ReadOnlySpan<T> right, Span<T> destination)
where T : struct, IAdditionOperators<T, T, T>
{
var add = new AddOperation<T>();
Apply(left, right, destination, ref add);
}
=> Apply<T, AddOperator<T>>(left, right, destination);
}
46 changes: 46 additions & 0 deletions src/NetFabric.Numerics.Tensors/Aggregate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Numerics
{
public static partial class Tensor
{
/// <summary>
/// Aggregates the elements in the specified <see cref="ReadOnlySpan{T}"/> using the specified <see cref="IAggregationOperator{T}"/>.
/// </summary>
/// <typeparam name="T">The type of the elements in the span.</typeparam>
/// <typeparam name="TOperator">The type of the aggregation operator.</typeparam>
/// <param name="source">The source span.</param>
/// <returns>The aggregated value.</returns>
public static T Aggregate<T, TOperator>(ReadOnlySpan<T> source)
where T : struct
where TOperator : struct, IAggregationOperator<T>
{
var result = TOperator.Seed;
var resultVector = new Vector<T>(TOperator.Seed);
nint index = 0;

if (Vector.IsHardwareAccelerated &&
Vector<T>.IsSupported &&
source.Length >= Vector<T>.Count)
{
var sourceVectors = MemoryMarshal.Cast<T, Vector<T>>(source);

ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors);
for (nint indexVector = 0; indexVector < sourceVectors.Length; indexVector++)
resultVector = TOperator.Invoke(resultVector, Unsafe.Add(ref sourceVectorsRef, indexVector));

index = source.Length - source.Length % Vector<T>.Count;
}

ref var sourceRef = ref MemoryMarshal.GetReference(source);
for (; index < source.Length; index++)
result = TOperator.Invoke(result, Unsafe.Add(ref sourceRef, index));

return TOperator.ResultSelector(result, resultVector);
}
}
}
Loading