diff --git a/src/NSubstitute/Core/WhenCalled.cs b/src/NSubstitute/Core/WhenCalled.cs index c15b9691..ecd373b3 100644 --- a/src/NSubstitute/Core/WhenCalled.cs +++ b/src/NSubstitute/Core/WhenCalled.cs @@ -1,3 +1,4 @@ +using System.Threading.Tasks; using NSubstitute.Routing; // Disable nullability for entry-point API @@ -35,6 +36,15 @@ public void Do(Action callbackWithArguments) } /// + /// Perform this action when called. + /// + /// + public void Do(Func callbackWithArguments) + { + Do(callInfo => callbackWithArguments(callInfo).GetAwaiter().GetResult()); + } + + /// /// Perform this configured callback when called. /// /// diff --git a/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs b/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs index 8b001a0e..05528b46 100644 --- a/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs +++ b/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs @@ -1,3 +1,6 @@ +using System; +using System.Threading; +using System.Threading.Tasks; using NSubstitute.Acceptance.Specs.Infrastructure; using NSubstitute.Core; using NUnit.Framework; @@ -9,11 +12,31 @@ public class WhenCalledDo { private ISomething _something; + + [Test] + public void Execute_when_called_async() + { + var called = false; + _something.When(substitute => substitute.Echo(1)).Do(async info => + { + await Task.Delay(100); + called = true; + }); + + Assert.That(called, Is.False, "Called"); + _something.Echo(1); + Assert.That(called, Is.True, "Called"); + } + [Test] public void Execute_when_called() { var called = false; - _something.When(substitute => substitute.Echo(1)).Do(info => called = true); + _something.When(substitute => substitute.Echo(1)).Do(info => + { + Thread.Sleep(100); + called = true; + }); Assert.That(called, Is.False, "Called"); _something.Echo(1);