Skip to content

CSHARP-3458: Extend IAsyncCursor and IAsyncCursorSource to support IAsyncEnumerable #1708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/MongoDB.Driver/Core/IAsyncCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,17 @@ public static class IAsyncCursorExtensions
return new AsyncCursorEnumerableOneTimeAdapter<TDocument>(cursor, cancellationToken);
}

/// <summary>
/// Wraps a cursor in an IAsyncEnumerable that can be enumerated one time.
/// </summary>
/// <typeparam name="TDocument">The type of the document.</typeparam>
/// <param name="cursor">The cursor.</param>
/// <returns>An IAsyncEnumerable.</returns>
public static IAsyncEnumerable<TDocument> ToAsyncEnumerable<TDocument>(this IAsyncCursor<TDocument> cursor)
{
return new AsyncCursorEnumerableOneTimeAdapter<TDocument>(cursor, CancellationToken.None);
}

/// <summary>
/// Returns a list containing all the documents returned by a cursor.
/// </summary>
Expand Down
12 changes: 12 additions & 0 deletions src/MongoDB.Driver/Core/IAsyncCursorSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,18 @@ public static class IAsyncCursorSourceExtensions
return new AsyncCursorSourceEnumerableAdapter<TDocument>(source, cancellationToken);
}

/// <summary>
/// Wraps a cursor source in an IAsyncEnumerable. Each time GetAsyncEnumerator is called a new enumerator is returned and a new cursor
/// is fetched from the cursor source on the first call to MoveNextAsync.
/// </summary>
/// <typeparam name="TDocument">The type of the document.</typeparam>
/// <param name="source">The source.</param>
/// <returns>An IAsyncEnumerable.</returns>
public static IAsyncEnumerable<TDocument> ToAsyncEnumerable<TDocument>(this IAsyncCursorSource<TDocument> source)
{
return new AsyncCursorSourceEnumerableAdapter<TDocument>(source, CancellationToken.None);
}

/// <summary>
/// Returns a list containing all the documents returned by the cursor returned by a cursor source.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace MongoDB.Driver.Core.Operations
{
internal sealed class AsyncCursorEnumerableOneTimeAdapter<TDocument> : IEnumerable<TDocument>
internal sealed class AsyncCursorEnumerableOneTimeAdapter<TDocument> : IEnumerable<TDocument>, IAsyncEnumerable<TDocument>
{
private readonly CancellationToken _cancellationToken;
private readonly IAsyncCursor<TDocument> _cursor;
Expand All @@ -33,6 +33,16 @@ public AsyncCursorEnumerableOneTimeAdapter(IAsyncCursor<TDocument> cursor, Cance
_cancellationToken = cancellationToken;
}

public IAsyncEnumerator<TDocument> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
if (_hasBeenEnumerated)
{
throw new InvalidOperationException("An IAsyncCursor can only be enumerated once.");
}
_hasBeenEnumerated = true;
return new AsyncCursorEnumerator<TDocument>(_cursor, cancellationToken);
}

public IEnumerator<TDocument> GetEnumerator()
{
if (_hasBeenEnumerated)
Expand Down
58 changes: 45 additions & 13 deletions src/MongoDB.Driver/Core/Operations/AsyncCursorEnumerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver.Core.Operations
{
internal class AsyncCursorEnumerator<TDocument> : IEnumerator<TDocument>
internal sealed class AsyncCursorEnumerator<TDocument> : IEnumerator<TDocument>, IAsyncEnumerator<TDocument>
{
// private fields
private IEnumerator<TDocument> _batchEnumerator;
Expand Down Expand Up @@ -72,6 +73,15 @@ public void Dispose()
}
}

public ValueTask DisposeAsync()
{
// TODO: implement true async disposal (CSHARP-5630)
Dispose();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add a comment like:

 Dispose(); // TODO: implement true async disposal


// TODO: convert to ValueTask.CompletedTask once we stop supporting older target frameworks
return default; // Equivalent to ValueTask.CompletedTask which is not available on older target frameworks.
}

public bool MoveNext()
{
ThrowIfDisposed();
Expand All @@ -82,24 +92,46 @@ public bool MoveNext()
return true;
}

while (true)
while (_cursor.MoveNext(_cancellationToken))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good refactoring

{
if (_cursor.MoveNext(_cancellationToken))
_batchEnumerator?.Dispose();
_batchEnumerator = _cursor.Current.GetEnumerator();
if (_batchEnumerator.MoveNext())
{
_batchEnumerator?.Dispose();
_batchEnumerator = _cursor.Current.GetEnumerator();
if (_batchEnumerator.MoveNext())
{
return true;
}
return true;
}
else
}

_batchEnumerator?.Dispose();
_batchEnumerator = null;
_finished = true;
return false;
}

public async ValueTask<bool> MoveNextAsync()
{
ThrowIfDisposed();
_started = true;

if (_batchEnumerator != null && _batchEnumerator.MoveNext())
{
return true;
}

while (await _cursor.MoveNextAsync(_cancellationToken).ConfigureAwait(false))
{
_batchEnumerator?.Dispose();
_batchEnumerator = _cursor.Current.GetEnumerator();
if (_batchEnumerator.MoveNext())
{
_batchEnumerator = null;
_finished = true;
return false;
return true;
}
}

_batchEnumerator?.Dispose();
_batchEnumerator = null;
_finished = true;
return false;
}

public void Reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
* limitations under the License.
*/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver.Core.Operations
{
internal class AsyncCursorSourceEnumerableAdapter<TDocument> : IEnumerable<TDocument>
internal sealed class AsyncCursorSourceEnumerableAdapter<TDocument> : IEnumerable<TDocument>, IAsyncEnumerable<TDocument>
{
// private fields
private readonly CancellationToken _cancellationToken;
Expand All @@ -34,6 +33,11 @@ public AsyncCursorSourceEnumerableAdapter(IAsyncCursorSource<TDocument> source,
_cancellationToken = cancellationToken;
}

public IAsyncEnumerator<TDocument> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new AsyncCursorSourceEnumerator<TDocument>(_source, cancellationToken);
}

// public methods
public IEnumerator<TDocument> GetEnumerator()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MongoDB.Driver.Core.Misc;

namespace MongoDB.Driver.Core.Operations
{
#pragma warning disable CA1001
// we are suppressing this warning as we currently use the old Microsoft.CodeAnalysis.FxCopAnalyzers which doesn't
// have a concept of IAsyncDisposable.
// TODO: remove this suppression once we update our analyzers to use Microsoft.CodeAnalysis.NetAnalyzers
internal sealed class AsyncCursorSourceEnumerator<TDocument> : IAsyncEnumerator<TDocument>
#pragma warning restore CA1001
{
private readonly CancellationToken _cancellationToken;
private AsyncCursorEnumerator<TDocument> _cursorEnumerator;
private readonly IAsyncCursorSource<TDocument> _cursorSource;
private bool _disposed;

public AsyncCursorSourceEnumerator(IAsyncCursorSource<TDocument> cursorSource, CancellationToken cancellationToken)
{
_cursorSource = Ensure.IsNotNull(cursorSource, nameof(cursorSource));
_cancellationToken = cancellationToken;
}

public TDocument Current
{
get
{
if (_cursorEnumerator == null)
{
throw new InvalidOperationException("Enumeration has not started. Call MoveNextAsync.");
}
return _cursorEnumerator.Current;
}
}

public async ValueTask DisposeAsync()
{
if (!_disposed)
{
_disposed = true;

if (_cursorEnumerator != null)
{
await _cursorEnumerator.DisposeAsync().ConfigureAwait(false);
}
}
}

public async ValueTask<bool> MoveNextAsync()
{
ThrowIfDisposed();

if (_cursorEnumerator == null)
{
var cursor = await _cursorSource.ToCursorAsync(_cancellationToken).ConfigureAwait(false);
_cursorEnumerator = new AsyncCursorEnumerator<TDocument>(cursor, _cancellationToken);
}

return await _cursorEnumerator.MoveNextAsync().ConfigureAwait(false);
}

public void Reset()
{
ThrowIfDisposed();
throw new NotSupportedException();
}

// private methods
private void ThrowIfDisposed()
{
if (_disposed)
{
throw new ObjectDisposedException(GetType().Name);
}
}
}
}
12 changes: 12 additions & 0 deletions src/MongoDB.Driver/Linq/MongoQueryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3385,6 +3385,18 @@ public static IQueryable<TSource> Take<TSource>(this IQueryable<TSource> source,
Expression.Constant(count)));
}

/// <summary>
/// Returns an <see cref="IAsyncEnumerable{T}" /> which can be enumerated asynchronously.
/// </summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <param name="source">A sequence of values.</param>
/// <returns>An IAsyncEnumerable for the query results.</returns>
public static IAsyncEnumerable<TSource> ToAsyncEnumerable<TSource>(this IQueryable<TSource> source)
{
var cursorSource = GetCursorSource(source);
return cursorSource.ToAsyncEnumerable();
}

/// <summary>
/// Executes the LINQ query and returns a cursor to the results.
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion tests/MongoDB.Driver.Tests/BulkWriteErrorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using System.Threading.Tasks;
using FluentAssertions;
using MongoDB.Bson;
using MongoDB.Driver.Core.Operations;
using Xunit;

namespace MongoDB.Driver.Tests
Expand All @@ -34,7 +35,7 @@ public class BulkWriteErrorTests
[InlineData(12582, ServerErrorCategory.DuplicateKey)]
public void Should_translate_category_correctly(int code, ServerErrorCategory expectedCategory)
{
var coreError = new Core.Operations.BulkWriteOperationError(0, code, "blah", new BsonDocument());
var coreError = new BulkWriteOperationError(0, code, "blah", new BsonDocument());
var subject = BulkWriteError.FromCore(coreError);

subject.Category.Should().Be(expectedCategory);
Expand Down
51 changes: 51 additions & 0 deletions tests/MongoDB.Driver.Tests/Core/IAsyncCursorExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using MongoDB.Bson;
using MongoDB.Bson.Serialization.Serializers;
Expand Down Expand Up @@ -201,6 +203,55 @@ public void SingleOrDefault_should_throw_when_cursor_has_wrong_number_of_documen
action.ShouldThrow<InvalidOperationException>();
}

[Fact]
public void ToAsyncEnumerable_result_should_only_be_enumerable_one_time()
{
var cursor = CreateCursor(2);
var enumerable = cursor.ToAsyncEnumerable();
enumerable.GetAsyncEnumerator();

Record.Exception(() => enumerable.GetAsyncEnumerator()).Should().BeOfType<InvalidOperationException>();
}

[Fact]
public async Task ToAsyncEnumerable_should_respect_cancellation_token()
{
var source = CreateCursor(5);
using var cts = new CancellationTokenSource();

var count = 0;
var exception = await Record.ExceptionAsync(async () =>
{
await foreach (var doc in source.ToAsyncEnumerable().WithCancellation(cts.Token))
{
count++;
if (count == 2)
cts.Cancel();
}
});

exception.Should().BeOfType<OperationCanceledException>();
}

[Fact]
public async Task ToAsyncEnumerable_should_return_expected_result()
{
var cursor = CreateCursor(2);
var expectedDocuments = new[]
{
new BsonDocument("_id", 0),
new BsonDocument("_id", 1)
};

var result = new List<BsonDocument>();
await foreach (var doc in cursor.ToAsyncEnumerable())
{
result.Add(doc);
}

result.Should().Equal(expectedDocuments);
}

[Fact]
public void ToEnumerable_result_should_only_be_enumerable_one_time()
{
Expand Down
Loading