diff --git a/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageReceiver.cs b/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageReceiver.cs index 9635534..b72316f 100644 --- a/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageReceiver.cs +++ b/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageReceiver.cs @@ -25,7 +25,7 @@ public sealed class RfbMessageReceiver : BackgroundThread, IRfbMessageReceiver /// Initializes a new instance of the . /// /// The connection context. - public RfbMessageReceiver(RfbConnectionContext context) : base("RFB Message Receiver") + public RfbMessageReceiver(RfbConnectionContext context) { _context = context; _state = context.GetState(); @@ -51,7 +51,7 @@ public Task StopReceiveLoopAsync() // This method will not catch exceptions so the BackgroundThread base class will receive them, // raise a "Failure" and trigger a reconnect. - protected override void ThreadWorker(CancellationToken cancellationToken) + protected override async Task ThreadWorker(CancellationToken cancellationToken) { // Get the transport stream so we don't have to call the getter every time Debug.Assert(_context.Transport != null, "_context.Transport != null"); @@ -62,12 +62,12 @@ protected override void ThreadWorker(CancellationToken cancellationToken) ImmutableDictionary incomingMessageLookup = _context.SupportedMessageTypes .OfType().ToImmutableDictionary(mt => mt.Id); - Span messageTypeBuffer = stackalloc byte[1]; + var messageTypeBuffer = new byte[1]; while (!cancellationToken.IsCancellationRequested) { // Read message type - if (transportStream.Read(messageTypeBuffer) == 0) + if (await transportStream.ReadAsync(messageTypeBuffer.AsMemory(), cancellationToken) == 0) { throw new UnexpectedEndOfStreamException("Stream reached its end while reading next message type."); } diff --git a/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageSender.cs b/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageSender.cs index f325660..30f493e 100644 --- a/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageSender.cs +++ b/src/MarcusW.VncClient/Protocol/Implementation/Services/Communication/RfbMessageSender.cs @@ -1,8 +1,8 @@ using System; -using System.Collections.Concurrent; using System.Diagnostics; using System.Linq; using System.Threading; +using System.Threading.Channels; using System.Threading.Tasks; using MarcusW.VncClient.Protocol.Implementation.MessageTypes.Outgoing; using MarcusW.VncClient.Protocol.MessageTypes; @@ -20,7 +20,7 @@ public class RfbMessageSender : BackgroundThread, IRfbMessageSender private readonly RfbConnectionContext _context; private readonly ILogger _logger; - private readonly BlockingCollection _queue = new(new ConcurrentQueue()); + private readonly Channel _queue = Channel.CreateUnbounded(); private readonly ProtocolState _state; @@ -30,7 +30,7 @@ public class RfbMessageSender : BackgroundThread, IRfbMessageSender /// Initializes a new instance of the . /// /// The connection context. - public RfbMessageSender(RfbConnectionContext context) : base("RFB Message Sender") + public RfbMessageSender(RfbConnectionContext context) { _context = context; _state = context.GetState(); @@ -79,7 +79,7 @@ public void EnqueueMessage(IOutgoingMessage message, var messageType = GetAndCheckMessageType(); // Add message to queue - _queue.Add(new(message, messageType), cancellationToken); + _queue.Writer.TryWrite(new QueueItem(message, messageType)); } /// @@ -104,7 +104,7 @@ public Task SendMessageAndWaitAsync(IOutgoingMessage TaskCompletionSource completionSource = new(TaskCreationOptions.RunContinuationsAsynchronously); // Add message to queue - _queue.Add(new(message, messageType, completionSource), cancellationToken); + _queue.Writer.TryWrite(new QueueItem(message, messageType, completionSource)); return completionSource.Task; } @@ -120,7 +120,7 @@ protected override void Dispose(bool disposing) if (disposing) { SetQueueCancelled(); - _queue.Dispose(); + _queue.Writer.TryComplete(); } _disposed = true; @@ -130,7 +130,7 @@ protected override void Dispose(bool disposing) // This method will not catch exceptions so the BackgroundThread base class will receive them, // raise a "Failure" and trigger a reconnect. - protected override void ThreadWorker(CancellationToken cancellationToken) + protected override async Task ThreadWorker(CancellationToken cancellationToken) { try { @@ -138,8 +138,9 @@ protected override void ThreadWorker(CancellationToken cancellationToken) ITransport transport = _context.Transport; // Iterate over all queued items (will block if the queue is empty) - foreach (QueueItem queueItem in _queue.GetConsumingEnumerable(cancellationToken)) + while (await _queue.Reader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) { + QueueItem queueItem = await _queue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); IOutgoingMessage message = queueItem.Message; IOutgoingMessageType messageType = queueItem.MessageType; @@ -192,8 +193,7 @@ private TMessageType GetAndCheckMessageType() where TMessageType : private void SetQueueCancelled() { - _queue.CompleteAdding(); - foreach (QueueItem queueItem in _queue) + while (_queue.Reader.TryRead(out QueueItem? queueItem)) queueItem.CompletionSource?.TrySetCanceled(); } diff --git a/src/MarcusW.VncClient/Utils/BackgroundThread.cs b/src/MarcusW.VncClient/Utils/BackgroundThread.cs index fd9ce83..cecbc74 100644 --- a/src/MarcusW.VncClient/Utils/BackgroundThread.cs +++ b/src/MarcusW.VncClient/Utils/BackgroundThread.cs @@ -1,36 +1,33 @@ using System; -using System.Diagnostics; using System.Threading; using System.Threading.Tasks; +using JetBrains.Annotations; namespace MarcusW.VncClient.Utils; /// /// Base class for easier creation and clean cancellation of a background thread. /// +[PublicAPI] public abstract class BackgroundThread : IBackgroundThread { - private readonly TaskCompletionSource _completedTcs = new(); - private readonly object _startLock = new(); - + private readonly object _lock = new(); private readonly CancellationTokenSource _stopCts = new(); - private readonly Thread _thread; private volatile bool _disposed; - - private bool _started; + private Task? _task; /// /// Initializes a new instance of the . /// /// The thread name. - protected BackgroundThread(string name) - { - _thread = new(ThreadStart) { - Name = name, - IsBackground = true, - }; - } + [Obsolete("The name field is no longer used")] + protected BackgroundThread(string name) : this() { } + + /// + /// Initializes a new instance of the . + /// + protected BackgroundThread() { } /// public event EventHandler? Failed; @@ -47,27 +44,7 @@ protected virtual void Dispose(bool disposing) if (disposing) { - try - { - // Ensure the thread is stopped - _stopCts.Cancel(); - if (_thread.IsAlive) - { - // Block and wait for completion or hard-kill the thread after 1 second - if (!_thread.Join(TimeSpan.FromSeconds(1))) - { - // _thread.Abort(); -- This is obsolete and not supported - } - } - } - catch - { - // Ignore - } - - // Just to be sure... - _completedTcs.TrySetResult(null); - + _stopCts.Cancel(); _stopCts.Dispose(); } @@ -84,15 +61,15 @@ protected void Start() { ObjectDisposedException.ThrowIf(_disposed, typeof(BackgroundThread)); - lock (_startLock) + // Do your work... + try { - if (_started) - { - throw new InvalidOperationException("Thread already started."); - } - - _thread.Start(_stopCts.Token); - _started = true; + lock (_lock) + _task ??= ThreadWorker(_stopCts.Token); + } + catch (Exception exception) when (exception is not (OperationCanceledException or ThreadAbortException)) + { + Failed?.Invoke(this, new BackgroundThreadFailedEventArgs(exception)); } } @@ -102,49 +79,30 @@ protected void Start() /// /// It is safe to call this method multiple times. /// - protected Task StopAndWaitAsync() + protected async Task StopAndWaitAsync() { ObjectDisposedException.ThrowIf(_disposed, typeof(BackgroundThread)); - lock (_startLock) + // Tell the thread to stop + await _stopCts.CancelAsync(); + + // Wait for completion + if (_task is not null) { - if (!_started) + try + { + await _task.ConfigureAwait(false); + } + catch (Exception exception) when (exception is not (OperationCanceledException or ThreadAbortException)) { - throw new InvalidOperationException("Thread has not been started."); + Failed?.Invoke(this, new BackgroundThreadFailedEventArgs(exception)); } } - - // Tell the thread to stop - _stopCts.Cancel(); - - // Wait for completion - return _completedTcs.Task; } /// /// Executes the work that should happen in the background. /// /// The cancellation token that tells the method implementation when to complete. - protected abstract void ThreadWorker(CancellationToken cancellationToken); - - private void ThreadStart(object? parameter) - { - Debug.Assert(parameter != null, nameof(parameter) + " != null"); - var cancellationToken = (CancellationToken)parameter; - - try - { - // Do your work... - ThreadWorker(cancellationToken); - } - catch (Exception exception) when (!(exception is OperationCanceledException or ThreadAbortException)) - { - Failed?.Invoke(this, new(exception)); - } - finally - { - // Notify stop method that thread has completed - _completedTcs.TrySetResult(null); - } - } + protected abstract Task ThreadWorker(CancellationToken cancellationToken); } diff --git a/tests/MarcusW.VncClient.Tests/Utils/BackgroundThreadTests.cs b/tests/MarcusW.VncClient.Tests/Utils/BackgroundThreadTests.cs index b20d5dd..d5b812d 100644 --- a/tests/MarcusW.VncClient.Tests/Utils/BackgroundThreadTests.cs +++ b/tests/MarcusW.VncClient.Tests/Utils/BackgroundThreadTests.cs @@ -56,16 +56,16 @@ public void Starts_ThreadWorker() mock.Protected().Verify("ThreadWorker", Times.Exactly(1), ItExpr.IsAny()); } - private class CancellableThread() : BackgroundThread("Cancellable Thread") + private class CancellableThread : BackgroundThread { public new void Start() => base.Start(); public new Task StopAndWaitAsync() => base.StopAndWaitAsync(); - protected override void ThreadWorker(CancellationToken cancellationToken) + protected override async Task ThreadWorker(CancellationToken cancellationToken) { while (!cancellationToken.IsCancellationRequested) - Thread.Sleep(10); + await Task.Delay(10); } } }