Skip to content

Commit 629f60e

Browse files
authored
[QUIC] Add QuicStream.WaitForWriteCompletionAsync (#58236)
* Add QuicStream.WaitForWriteCompletionAsync * Fix flakey mock stream tests
1 parent 81f884d commit 629f60e

File tree

7 files changed

+392
-5
lines changed

7 files changed

+392
-5
lines changed

src/libraries/System.Net.Quic/ref/System.Net.Quic.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ public override void Flush() { }
109109
public override void SetLength(long value) { }
110110
public void Shutdown() { }
111111
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
112+
public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
112113
public override void Write(byte[] buffer, int offset, int count) { }
113114
public override void Write(System.ReadOnlySpan<byte> buffer) { }
114115
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
using System.Threading;
99
using System.Threading.Channels;
1010
using System.Threading.Tasks;
11+
using System.Collections.Concurrent;
12+
using System.Collections.Generic;
1113

1214
namespace System.Net.Quic.Implementations.Mock
1315
{
@@ -244,6 +246,9 @@ internal MockStream OpenStream(long streamId, bool bidirectional)
244246
}
245247

246248
MockStream.StreamState streamState = new MockStream.StreamState(streamId, bidirectional);
249+
// TODO Streams are never removed from a connection. Consider cleaning up in the future.
250+
state._streams[streamState._streamId] = streamState;
251+
247252
Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
248253
streamChannel.Writer.TryWrite(streamState);
249254

@@ -320,6 +325,12 @@ internal override ValueTask CloseAsync(long errorCode, CancellationToken cancell
320325
state._serverErrorCode = errorCode;
321326
DrainAcceptQueue(errorCode, -1);
322327
}
328+
329+
foreach (KeyValuePair<long, MockStream.StreamState> kvp in state._streams)
330+
{
331+
kvp.Value._outboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
332+
kvp.Value._inboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
333+
}
323334
}
324335

325336
Dispose();
@@ -474,8 +485,9 @@ public PeerStreamLimit(int maxUnidirectional, int maxBidirectional)
474485
internal sealed class ConnectionState
475486
{
476487
public readonly SslApplicationProtocol _applicationProtocol;
477-
public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
478-
public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
488+
public readonly Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
489+
public readonly Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
490+
public readonly ConcurrentDictionary<long, MockStream.StreamState> _streams;
479491

480492
public PeerStreamLimit? _clientStreamLimit;
481493
public PeerStreamLimit? _serverStreamLimit;
@@ -490,6 +502,7 @@ public ConnectionState(SslApplicationProtocol applicationProtocol)
490502
_clientInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
491503
_serverInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
492504
_clientErrorCode = _serverErrorCode = -1;
505+
_streams = new ConcurrentDictionary<long, MockStream.StreamState>();
493506
}
494507
}
495508
}

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
164164
if (endStream)
165165
{
166166
streamBuffer.EndWrite();
167+
WritesCompletedTcs.TrySetResult();
167168
}
168169
}
169170

@@ -206,10 +207,12 @@ internal override void AbortRead(long errorCode)
206207
if (_isInitiator)
207208
{
208209
_streamState._outboundWriteErrorCode = errorCode;
210+
_streamState._inboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
209211
}
210212
else
211213
{
212214
_streamState._inboundWriteErrorCode = errorCode;
215+
_streamState._outboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
213216
}
214217

215218
ReadStreamBuffer?.AbortRead();
@@ -220,10 +223,12 @@ internal override void AbortWrite(long errorCode)
220223
if (_isInitiator)
221224
{
222225
_streamState._outboundReadErrorCode = errorCode;
226+
_streamState._outboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
223227
}
224228
else
225229
{
226230
_streamState._inboundReadErrorCode = errorCode;
231+
_streamState._inboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
227232
}
228233

229234
WriteStreamBuffer?.EndWrite();
@@ -251,6 +256,8 @@ internal override void Shutdown()
251256
{
252257
_connection.LocalStreamLimit!.Bidirectional.Decrement();
253258
}
259+
260+
WritesCompletedTcs.TrySetResult();
254261
}
255262

256263
private void CheckDisposed()
@@ -283,6 +290,17 @@ public override ValueTask DisposeAsync()
283290
return default;
284291
}
285292

293+
internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
294+
{
295+
CheckDisposed();
296+
297+
return new ValueTask(WritesCompletedTcs.Task);
298+
}
299+
300+
private TaskCompletionSource WritesCompletedTcs => _isInitiator
301+
? _streamState._outboundWritesCompletedTcs
302+
: _streamState._inboundWritesCompletedTcs;
303+
286304
internal sealed class StreamState
287305
{
288306
public readonly long _streamId;
@@ -292,6 +310,8 @@ internal sealed class StreamState
292310
public long _inboundReadErrorCode;
293311
public long _outboundWriteErrorCode;
294312
public long _inboundWriteErrorCode;
313+
public TaskCompletionSource _outboundWritesCompletedTcs;
314+
public TaskCompletionSource _inboundWritesCompletedTcs;
295315

296316
private const int InitialBufferSize =
297317
#if DEBUG
@@ -310,6 +330,8 @@ public StreamState(long streamId, bool bidirectional)
310330
_streamId = streamId;
311331
_outboundStreamBuffer = new StreamBuffer(initialBufferSize: InitialBufferSize, maxBufferSize: MaxBufferSize);
312332
_inboundStreamBuffer = (bidirectional ? new StreamBuffer() : null);
333+
_outboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
334+
_inboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
313335
}
314336
}
315337
}

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs

Lines changed: 126 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ private sealed class State
6969
// Resettable completions to be used for multiple calls to send.
7070
public readonly ResettableCompletionSource<uint> SendResettableCompletionSource = new ResettableCompletionSource<uint>();
7171

72+
public ShutdownWriteState ShutdownWriteState;
73+
74+
// Set once writes have been shutdown.
75+
public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
76+
7277
public ShutdownState ShutdownState;
7378
// The value makes sure that we release the handles only once.
7479
public int ShutdownDone;
@@ -577,12 +582,26 @@ internal override void AbortWrite(long errorCode)
577582
return;
578583
}
579584

585+
bool shouldComplete = false;
586+
580587
lock (_state)
581588
{
582589
if (_state.SendState < SendState.Aborted)
583590
{
584591
_state.SendState = SendState.Aborted;
585592
}
593+
594+
if (_state.ShutdownWriteState == ShutdownWriteState.None)
595+
{
596+
_state.ShutdownWriteState = ShutdownWriteState.Canceled;
597+
shouldComplete = true;
598+
}
599+
}
600+
601+
if (shouldComplete)
602+
{
603+
_state.ShutdownWriteCompletionSource.SetException(
604+
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
586605
}
587606

588607
StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
@@ -629,6 +648,23 @@ internal override async ValueTask ShutdownCompleted(CancellationToken cancellati
629648
await _state.ShutdownCompletionSource.Task.ConfigureAwait(false);
630649
}
631650

651+
internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
652+
{
653+
// TODO: What should happen if this is called for a unidirectional stream and there are no writes?
654+
655+
ThrowIfDisposed();
656+
657+
lock (_state)
658+
{
659+
if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed)
660+
{
661+
throw GetConnectionAbortedException(_state);
662+
}
663+
}
664+
665+
return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken));
666+
}
667+
632668
internal override void Shutdown()
633669
{
634670
ThrowIfDisposed();
@@ -861,6 +897,11 @@ private static uint HandleEvent(State state, ref StreamEvent evt)
861897
// Peer has stopped receiving data, don't send anymore.
862898
case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED:
863899
return HandleEventPeerRecvAborted(state, ref evt);
900+
// Occurs when shutdown is completed for the send side.
901+
// This only happens for shutdown on sending, not receiving
902+
// Receive shutdown can only be abortive.
903+
case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE:
904+
return HandleEventSendShutdownComplete(state, ref evt);
864905
// Shutdown for both sending and receiving is completed.
865906
case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE:
866907
return HandleEventShutdownComplete(state, ref evt);
@@ -993,23 +1034,37 @@ private static unsafe uint HandleEventRecv(State state, ref StreamEvent evt)
9931034

9941035
private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
9951036
{
996-
bool shouldComplete = false;
1037+
bool shouldSendComplete = false;
1038+
bool shouldShutdownWriteComplete = false;
9971039
lock (state)
9981040
{
9991041
if (state.SendState == SendState.None || state.SendState == SendState.Pending)
10001042
{
1001-
shouldComplete = true;
1043+
shouldSendComplete = true;
1044+
}
1045+
1046+
if (state.ShutdownWriteState == ShutdownWriteState.None)
1047+
{
1048+
state.ShutdownWriteState = ShutdownWriteState.Canceled;
1049+
shouldShutdownWriteComplete = true;
10021050
}
1051+
10031052
state.SendState = SendState.Aborted;
10041053
state.SendErrorCode = (long)evt.Data.PeerReceiveAborted.ErrorCode;
10051054
}
10061055

1007-
if (shouldComplete)
1056+
if (shouldSendComplete)
10081057
{
10091058
state.SendResettableCompletionSource.CompleteException(
10101059
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
10111060
}
10121061

1062+
if (shouldShutdownWriteComplete)
1063+
{
1064+
state.ShutdownWriteCompletionSource.SetException(
1065+
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
1066+
}
1067+
10131068
return MsQuicStatusCodes.Success;
10141069
}
10151070

@@ -1021,6 +1076,38 @@ private static uint HandleEventStartComplete(State state, ref StreamEvent evt)
10211076
return MsQuicStatusCodes.Success;
10221077
}
10231078

1079+
private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt)
1080+
{
1081+
// Graceful will be false in three situations:
1082+
// 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised.
1083+
// ShutdownWriteCompletionSource is already complete with an error.
1084+
// 2. We aborted writes.
1085+
// ShutdownWriteCompletionSource is already complete with an error.
1086+
// 3. The connection was closed.
1087+
// SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error.
1088+
//
1089+
// Only use this event with sends gracefully completed.
1090+
if (evt.Data.SendShutdownComplete.Graceful != 0)
1091+
{
1092+
bool shouldComplete = false;
1093+
lock (state)
1094+
{
1095+
if (state.ShutdownWriteState == ShutdownWriteState.None)
1096+
{
1097+
state.ShutdownWriteState = ShutdownWriteState.Finished;
1098+
shouldComplete = true;
1099+
}
1100+
}
1101+
1102+
if (shouldComplete)
1103+
{
1104+
state.ShutdownWriteCompletionSource.SetResult();
1105+
}
1106+
}
1107+
1108+
return MsQuicStatusCodes.Success;
1109+
}
1110+
10241111
private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt)
10251112
{
10261113
StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete;
@@ -1031,6 +1118,7 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
10311118
}
10321119

10331120
bool shouldReadComplete = false;
1121+
bool shouldShutdownWriteComplete = false;
10341122
bool shouldShutdownComplete = false;
10351123

10361124
lock (state)
@@ -1040,6 +1128,15 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
10401128

10411129
shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);
10421130

1131+
if (state.ShutdownWriteState == ShutdownWriteState.None)
1132+
{
1133+
// TODO: We can get to this point if the stream is unidirectional and there are no writes.
1134+
// Consider what is the best behavior here with write shutdown and the read side of
1135+
// unidirecitonal streams in the future.
1136+
state.ShutdownWriteState = ShutdownWriteState.Finished;
1137+
shouldShutdownWriteComplete = true;
1138+
}
1139+
10431140
if (state.ShutdownState == ShutdownState.None)
10441141
{
10451142
state.ShutdownState = ShutdownState.Finished;
@@ -1052,6 +1149,11 @@ private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt
10521149
state.ReceiveResettableCompletionSource.Complete(0);
10531150
}
10541151

1152+
if (shouldShutdownWriteComplete)
1153+
{
1154+
state.ShutdownWriteCompletionSource.SetResult();
1155+
}
1156+
10551157
if (shouldShutdownComplete)
10561158
{
10571159
state.ShutdownCompletionSource.SetResult();
@@ -1361,6 +1463,7 @@ private static uint HandleEventConnectionClose(State state)
13611463

13621464
bool shouldCompleteRead = false;
13631465
bool shouldCompleteSend = false;
1466+
bool shouldCompleteShutdownWrite = false;
13641467
bool shouldCompleteShutdown = false;
13651468

13661469
lock (state)
@@ -1373,6 +1476,12 @@ private static uint HandleEventConnectionClose(State state)
13731476
}
13741477
state.SendState = SendState.ConnectionClosed;
13751478

1479+
if (state.ShutdownWriteState == ShutdownWriteState.None)
1480+
{
1481+
shouldCompleteShutdownWrite = true;
1482+
}
1483+
state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed;
1484+
13761485
if (state.ShutdownState == ShutdownState.None)
13771486
{
13781487
shouldCompleteShutdown = true;
@@ -1392,6 +1501,12 @@ private static uint HandleEventConnectionClose(State state)
13921501
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
13931502
}
13941503

1504+
if (shouldCompleteShutdownWrite)
1505+
{
1506+
state.ShutdownWriteCompletionSource.SetException(
1507+
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
1508+
}
1509+
13951510
if (shouldCompleteShutdown)
13961511
{
13971512
state.ShutdownCompletionSource.SetException(
@@ -1493,6 +1608,14 @@ private enum ReadState
14931608
Closed
14941609
}
14951610

1611+
private enum ShutdownWriteState
1612+
{
1613+
None = 0,
1614+
Canceled,
1615+
Finished,
1616+
ConnectionClosed
1617+
}
1618+
14961619
private enum ShutdownState
14971620
{
14981621
None = 0,

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ internal abstract class QuicStreamProvider : IDisposable, IAsyncDisposable
4747

4848
internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default);
4949

50+
internal abstract ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default);
51+
5052
internal abstract void Shutdown();
5153

5254
internal abstract void Flush();

src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ public override int WriteTimeout
117117

118118
public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken);
119119

120+
public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken);
121+
120122
public void Shutdown() => _provider.Shutdown();
121123

122124
protected override void Dispose(bool disposing)

0 commit comments

Comments
 (0)