From 0bf756dfd178174b6d106cf502e6d9c915734a68 Mon Sep 17 00:00:00 2001 From: marve Date: Thu, 2 May 2024 14:38:55 -0700 Subject: [PATCH 1/3] fix(server): at-least-once messages received during subscription must be delivered --- .../Server/Retained_Messages_Tests.cs | 34 +++++++++++++++++++ .../MqttClientSubscriptionsManager.cs | 26 ++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs index b27741575..5adf3d86f 100644 --- a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; @@ -169,6 +170,39 @@ public async Task Receive_Retained_Message_After_Subscribe() } } + [TestMethod] + public async Task Receive_AtLeastOnce_Retained_Message_Published_During_Subscribe() + { + using (var testEnvironment = CreateTestEnvironment()) + { + var messagePublished = new SemaphoreSlim(0,1); + var subscribeReceived = new SemaphoreSlim(0,1); + await testEnvironment.StartServer(); + testEnvironment.Server.InterceptingSubscriptionAsync += _ => + { + subscribeReceived.Release(); + return messagePublished.WaitAsync(); + }; + + var c1 = await testEnvironment.ConnectClient(); + + var c2 = await testEnvironment.ConnectClient(); + var messageHandler = testEnvironment.CreateApplicationMessageHandler(c2); + + Task subscribeComplete = c2.SubscribeAsync(new MqttTopicFilterBuilder().WithTopic("retained").WithAtLeastOnceQoS().Build()); + await subscribeReceived.WaitAsync(1000); + await c1.PublishAsync(new MqttApplicationMessageBuilder().WithTopic("retained").WithPayload(new byte[3]).WithRetainFlag().WithQualityOfServiceLevel(MqttQualityOfServiceLevel.AtLeastOnce).Build()); + await c1.DisconnectAsync(); + + messagePublished.Release(); + await subscribeComplete; + await Task.Delay(500); + + messageHandler.AssertReceivedCountEquals(1); + Assert.IsTrue(messageHandler.ReceivedEventArgs.First().ApplicationMessage.Retain); + } + } + [TestMethod] public async Task Receive_Retained_Messages_From_Higher_Qos_Level() { diff --git a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs index 8cacf5605..735076d16 100644 --- a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs @@ -10,6 +10,7 @@ using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; +using static MQTTnet.Server.MqttClientSubscriptionsManager; namespace MQTTnet.Server { @@ -177,6 +178,13 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket var addedSubscriptions = new List(); var finalTopicFilters = new List(); + var atLeastOnceSubscriptionResults = new List(); + + IList retainedApplicationMessages = null; + if (subscribePacket.TopicFilters.Any(f => f.QualityOfServiceLevel != MqttQualityOfServiceLevel.AtLeastOnce)) + { + retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false); + } // The topic filters are order by its QoS so that the higher QoS will win over a // lower one. @@ -208,6 +216,24 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket finalTopicFilters.Add(topicFilter); FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result); + if (createSubscriptionResult.Subscription.GrantedQualityOfServiceLevel != MqttQualityOfServiceLevel.AtLeastOnce) + { + FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result); + } + else + { + atLeastOnceSubscriptionResults.Add(createSubscriptionResult); + } + } + + if (atLeastOnceSubscriptionResults.Count != 0) + { + // In order to satisfy at least once, we must query for retained messages after creating the subscription. + retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false); + foreach (var createSubscriptionResult in atLeastOnceSubscriptionResults) + { + FilterRetainedApplicationMessages(retainedApplicationMessages, createSubscriptionResult, result); + } } // This call will add the new subscription to the internal storage. From 3365bf24254fbfa1424d532bd255694222076967 Mon Sep 17 00:00:00 2001 From: marve Date: Wed, 8 May 2024 13:54:42 -0700 Subject: [PATCH 2/3] fix: compile error --- Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs index 735076d16..87012f7d3 100644 --- a/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttClientSubscriptionsManager.cs @@ -173,7 +173,6 @@ public async Task Subscribe(MqttSubscribePacket subscribePacket throw new ArgumentNullException(nameof(subscribePacket)); } - var retainedApplicationMessages = await _retainedMessagesManager.GetMessages().ConfigureAwait(false); var result = new SubscribeResult(subscribePacket.TopicFilters.Count); var addedSubscriptions = new List(); From 1df8a357344f94728e3568e66661e0174398a655 Mon Sep 17 00:00:00 2001 From: marve Date: Wed, 8 May 2024 14:07:33 -0700 Subject: [PATCH 3/3] chore: remove unused using --- .../Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs index ed4596779..2d9ee3147 100644 --- a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs @@ -7,7 +7,6 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; -using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; using MQTTnet.Diagnostics;