Skip to content

Commit

Permalink
Merge pull request #107 from Cysharp/hadashiA/aggregate-by
Browse files Browse the repository at this point in the history
Add AggregateBy operator
  • Loading branch information
neuecc authored Feb 16, 2024
2 parents f4c338c + 60b9463 commit 383e756
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/R3/Operators/AggregateAsync.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ protected override void OnCompletedCore(Result result)
}
}


internal sealed class AggregateAsync<T, TAccumulate, TResult>(
TAccumulate seed,
Func<TAccumulate, T, TAccumulate> func,
Expand Down
105 changes: 105 additions & 0 deletions src/R3/Operators/AggregateByAsync.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
namespace R3;

public static partial class ObservableExtensions
{
public static Task<IEnumerable<KeyValuePair<TKey, TAccumulate>>> AggregateByAsync<TSource, TKey, TAccumulate>(
this Observable<TSource> source,
Func<TSource, TKey> keySelector,
TAccumulate seed,
Func<TAccumulate, TSource, TAccumulate> func,
IEqualityComparer<TKey>? keyComparer = null,
CancellationToken cancellationToken = default)
where TKey : notnull
{
var observer = new AggregateByAsync<TSource, TKey, TAccumulate>(keySelector, seed, func, keyComparer, cancellationToken);
source.Subscribe(observer);
return observer.Task;
}

public static Task<IEnumerable<KeyValuePair<TKey, TAccumulate>>> AggregateByAsync<TSource, TKey, TAccumulate>(
this Observable<TSource> source,
Func<TSource, TKey> keySelector,
Func<TKey, TAccumulate> seedSelector,
Func<TAccumulate, TSource, TAccumulate> func,
IEqualityComparer<TKey>? keyComparer = null,
CancellationToken cancellationToken = default)
where TKey : notnull
{
var observer = new AggregateByAsyncSeedSelector<TSource, TKey, TAccumulate>(keySelector, seedSelector, func, keyComparer, cancellationToken);
source.Subscribe(observer);
return observer.Task;
}
}

internal sealed class AggregateByAsync<TSource, TKey, TAccumulate>(
Func<TSource, TKey> keySelector,
TAccumulate seed,
Func<TAccumulate, TSource, TAccumulate> func,
IEqualityComparer<TKey>? keyComparer,
CancellationToken cancellationToken)
: TaskObserverBase<TSource, IEnumerable<KeyValuePair<TKey, TAccumulate>>>(cancellationToken)
{
readonly Dictionary<TKey, TAccumulate> dictionary = new(keyComparer);

protected override void OnNextCore(TSource value)
{
var key = keySelector(value);
if (!dictionary.TryGetValue(key, out var currentAccumulate))
{
currentAccumulate = seed;
}
dictionary[key] = func(currentAccumulate, value);
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}
TrySetResult(dictionary);
}
}

internal sealed class AggregateByAsyncSeedSelector<TSource, TKey, TAccumulate>(
Func<TSource, TKey> keySelector,
Func<TKey, TAccumulate> seedSelector,
Func<TAccumulate, TSource, TAccumulate> func,
IEqualityComparer<TKey>? keyComparer,
CancellationToken cancellationToken)
: TaskObserverBase<TSource, IEnumerable<KeyValuePair<TKey, TAccumulate>>>(cancellationToken)
{
readonly Dictionary<TKey, TAccumulate> dictionary = new(keyComparer);

protected override void OnNextCore(TSource value)
{
var key = keySelector(value);
if (!dictionary.TryGetValue(key, out var currentAccumulate))
{
currentAccumulate = seedSelector(key);
}
dictionary[key] = func(currentAccumulate, value);
}

protected override void OnErrorResumeCore(Exception error)
{
TrySetException(error);
}

protected override void OnCompletedCore(Result result)
{
if (result.IsFailure)
{
TrySetException(result.Exception);
return;
}
TrySetResult(dictionary);
}
}

81 changes: 81 additions & 0 deletions tests/R3.Tests/OperatorTests/AggregateByTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
namespace R3.Tests.OperatorTests;

public class AggregateByTest
{
[Fact]
public async Task AggregateBy()
{
var publisher = new Subject<int>();

var task = publisher.AggregateByAsync(x => x % 2 == 0, 100, (sum, x) => x + sum);

publisher.OnNext(1);
publisher.OnNext(2);
publisher.OnNext(3);
publisher.OnNext(4);
publisher.OnNext(5);

task.Status.Should().Be(TaskStatus.WaitingForActivation);

publisher.OnCompleted();

var result = await task;
result.FirstOrDefault(x => x.Key).Value.Should().Be(106);
result.FirstOrDefault(x => !x.Key).Value.Should().Be(109);
}

[Fact]
public async Task AggregateBy_Empty()
{
var publisher = new Subject<int>();

var task = publisher.AggregateByAsync(x => x % 2 == 0, 100, (sum, x) => x + sum);

task.Status.Should().Be(TaskStatus.WaitingForActivation);
publisher.OnCompleted();

(await task).Should().BeEmpty();
}

[Fact]
public async Task AggregateBy_One()
{
var publisher = new Subject<int>();

var task = publisher.AggregateByAsync(x => x % 2 == 0, 100, (sum, x) => x + sum);

publisher.OnNext(2);

task.Status.Should().Be(TaskStatus.WaitingForActivation);
publisher.OnCompleted();

var result = await task;
result.Count().Should().Be(1);
result.First().Value.Should().Be(102);
}

[Fact]
public async Task AggregateBy_SeedSelector()
{
var publisher = new Subject<int>();

var task = publisher.AggregateByAsync(
x => x % 2 == 0,
key => key ? 100 : 0,
(sum, x) => x + sum);

publisher.OnNext(1);
publisher.OnNext(2);
publisher.OnNext(3);
publisher.OnNext(4);
publisher.OnNext(5);

task.Status.Should().Be(TaskStatus.WaitingForActivation);

publisher.OnCompleted();

var result = await task;
result.FirstOrDefault(x => x.Key).Value.Should().Be(106);
result.FirstOrDefault(x => !x.Key).Value.Should().Be(9);
}
}

0 comments on commit 383e756

Please sign in to comment.