From 1941c3dd679648503ba16d2a5de61efc25a8b6a4 Mon Sep 17 00:00:00 2001 From: Jason Lau <177781338+Jason31569@users.noreply.github.com> Date: Wed, 20 Nov 2024 10:39:09 -0500 Subject: [PATCH] Ability to mock protected methods with and without return value --- src/NSubstitute/Core/IThreadLocalContext.cs | 6 + src/NSubstitute/Core/ThreadLocalContext.cs | 18 +++ .../Extensions/ProtectedExtensions.cs | 59 +++++++++ .../Infrastructure/AnotherClass.cs | 46 +++++++ .../ProtectedExtensionsTests.cs | 123 ++++++++++++++++++ 5 files changed, 252 insertions(+) create mode 100644 src/NSubstitute/Extensions/ProtectedExtensions.cs create mode 100644 tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs create mode 100644 tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs diff --git a/src/NSubstitute/Core/IThreadLocalContext.cs b/src/NSubstitute/Core/IThreadLocalContext.cs index 5f1354d0..209eac16 100644 --- a/src/NSubstitute/Core/IThreadLocalContext.cs +++ b/src/NSubstitute/Core/IThreadLocalContext.cs @@ -24,6 +24,12 @@ public interface IThreadLocalContext void EnqueueArgumentSpecification(IArgumentSpecification spec); IList DequeueAllArgumentSpecifications(); + /// + /// Peeks into the argument specifications + /// + /// Enqueued argument specifications + IList PeekAllArgumentSpecifications(); + void SetPendingRaisingEventArgumentsFactory(Func getArguments); /// /// Returns the previously set arguments factory and resets the stored value. diff --git a/src/NSubstitute/Core/ThreadLocalContext.cs b/src/NSubstitute/Core/ThreadLocalContext.cs index 73e1100b..cad64dc2 100644 --- a/src/NSubstitute/Core/ThreadLocalContext.cs +++ b/src/NSubstitute/Core/ThreadLocalContext.cs @@ -108,6 +108,24 @@ public IList DequeueAllArgumentSpecifications() return queue; } + /// + public IList PeekAllArgumentSpecifications() + { + var queue = _argumentSpecifications.Value; + if (queue == null) { throw new SubstituteInternalException("Argument specification queue is null."); } + + if (queue.Count > 0) + { + var items = new IArgumentSpecification[queue.Count]; + + queue.CopyTo(items, 0); + + return items; + } + + return EmptySpecifications; + } + public void SetPendingRaisingEventArgumentsFactory(Func getArguments) { _getArgumentsForRaisingEvent.Value = getArguments; diff --git a/src/NSubstitute/Extensions/ProtectedExtensions.cs b/src/NSubstitute/Extensions/ProtectedExtensions.cs new file mode 100644 index 00000000..89693fbc --- /dev/null +++ b/src/NSubstitute/Extensions/ProtectedExtensions.cs @@ -0,0 +1,59 @@ +using System.Reflection; +using NSubstitute.Core; +using NSubstitute.Core.Arguments; + +// Disable nullability for client API, so it does not affect clients. +#nullable disable annotations + +namespace NSubstitute.Extensions; + +public static class ProtectedExtensions +{ + /// + /// Configure behavior for a protected method with return value + /// + /// + /// The object. + /// Name of the method. + /// The method arguments. + /// Result object from the method invocation. + /// Substitute - Cannot mock null object + /// Must provide valid protected method name to mock - methodName + public static object Protected(this T obj, string methodName, params object[] args) where T : class + { + if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); } + if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); } + + IList argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications(); + MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null); + + if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); } + if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); } + + return mthdInfo.Invoke(obj, args); + } + + /// + /// Configure behavior for a protected method with no return vlaue + /// + /// + /// The object. + /// Name of the method. + /// The method arguments. + /// WhenCalled<T>. + /// Substitute - Cannot mock null object + /// Must provide valid protected method name to mock - methodName + public static WhenCalled When(this T obj, string methodName, params object[] args) where T : class + { + if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); } + if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); } + + IList argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications(); + MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null); + + if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); } + if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); } + + return new WhenCalled(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall); + } +} \ No newline at end of file diff --git a/tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs b/tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs new file mode 100644 index 00000000..ce97937f --- /dev/null +++ b/tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs @@ -0,0 +1,46 @@ +namespace NSubstitute.Acceptance.Specs.Infrastructure; + +public abstract class AnotherClass +{ + protected abstract string ProtectedMethod(); + + protected abstract string ProtectedMethod(int i); + + protected abstract string ProtectedMethod(string msg, int i, char j); + + protected abstract void ProtectedMethodWithNoReturn(); + + protected abstract void ProtectedMethodWithNoReturn(int i); + + protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j); + + public string DoWork() + { + return ProtectedMethod(); + } + + public string DoWork(int i) + { + return ProtectedMethod(i); + } + + public string DoWork(string msg, int i, char j) + { + return ProtectedMethod(msg, i, j); + } + + public void DoVoidWork() + { + ProtectedMethodWithNoReturn(); + } + + public void DoVoidWork(int i) + { + ProtectedMethodWithNoReturn(i); + } + + public void DoVoidWork(string msg, int i, char j) + { + ProtectedMethodWithNoReturn(msg, i, j); + } +} \ No newline at end of file diff --git a/tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs b/tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs new file mode 100644 index 00000000..a16d213d --- /dev/null +++ b/tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs @@ -0,0 +1,123 @@ +using NSubstitute.Acceptance.Specs.Infrastructure; +using NSubstitute.Extensions; +using NUnit.Framework; + +namespace NSubstitute.Acceptance.Specs; + +public class ProtectedExtensionsTests +{ + [Test] + public void Should_mock_and_verify_protected_method_with_no_args() + { + var expectedMsg = "unit test message"; + var sub = Substitute.For(); + var worker = new Worker(); + + sub.Protected("ProtectedMethod").Returns(expectedMsg); + + Assert.That(worker.DoWork(sub), Is.EqualTo(expectedMsg)); + sub.Received(1).Protected("ProtectedMethod"); + } + + [Test] + public void Should_mock_and_verify_protected_method_with_arg() + { + var expectedMsg = "unit test message"; + var sub = Substitute.For(); + var worker = new Worker(); + + sub.Protected("ProtectedMethod", Arg.Any()).Returns(expectedMsg); + + Assert.That(worker.DoMoreWork(sub, 5), Is.EqualTo(expectedMsg)); + var a = sub.Received(1); + a.Protected("ProtectedMethod", Arg.Any()); + } + + [Test] + public void Should_mock_and_verify_protected_method_with_multiple_args() + { + var expectedMsg = "unit test message"; + var sub = Substitute.For(); + var worker = new Worker(); + + sub.Protected("ProtectedMethod", Arg.Any(), Arg.Any(), Arg.Any()).Returns(expectedMsg); + + Assert.That(worker.DoEvenMoreWork(sub, 3, 'x'), Is.EqualTo(expectedMsg)); + sub.Received(1).Protected("ProtectedMethod", Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Test] + public void Should_mock_and_verify_method_with_no_return_and_no_args() + { + var count = 0; + var sub = Substitute.For(); + var worker = new Worker(); + + sub.When("ProtectedMethodWithNoReturn").Do(x => count++); + + worker.DoVoidWork(sub); + Assert.That(count, Is.EqualTo(1)); + sub.Received(1).Protected("ProtectedMethodWithNoReturn"); + } + + [Test] + public void Should_mock_and_verify_method_with_no_return_with_arg() + { + var count = 0; + var sub = Substitute.For(); + var worker = new Worker(); + + sub.When("ProtectedMethodWithNoReturn", Arg.Any()).Do(x => count++); + + worker.DoVoidWork(sub, 5); + Assert.That(count, Is.EqualTo(1)); + sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any()); + } + + [Test] + public void Should_mock_and_verify_method_with_no_return_with_multiple_args() + { + var count = 0; + var sub = Substitute.For(); + var worker = new Worker(); + + sub.When("ProtectedMethodWithNoReturn", Arg.Any(), Arg.Any(), Arg.Any()).Do(x => count++); + + worker.DoVoidWork(sub, 5, 'x'); + Assert.That(count, Is.EqualTo(1)); + sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any(), Arg.Any(), Arg.Any()); + } + + private class Worker + { + internal string DoWork(AnotherClass worker) + { + return worker.DoWork(); + } + + internal string DoMoreWork(AnotherClass worker, int i) + { + return worker.DoWork(i); + } + + internal string DoEvenMoreWork(AnotherClass worker, int i, char j) + { + return worker.DoWork("worker", i, j); + } + + internal void DoVoidWork(AnotherClass worker) + { + worker.DoVoidWork(); + } + + internal void DoVoidWork(AnotherClass worker, int i) + { + worker.DoVoidWork(i); + } + + internal void DoVoidWork(AnotherClass worker, int i, char j) + { + worker.DoVoidWork("void worker", i, j); + } + } +} \ No newline at end of file