Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to mock protected methods with and without return value #845

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/NSubstitute/Core/IThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ public interface IThreadLocalContext
void EnqueueArgumentSpecification(IArgumentSpecification spec);
IList<IArgumentSpecification> DequeueAllArgumentSpecifications();

/// <summary>
/// Peeks into the argument specifications
/// </summary>
/// <returns>Enqueued argument specifications</returns>
IList<IArgumentSpecification> PeekAllArgumentSpecifications();

void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments);
/// <summary>
/// Returns the previously set arguments factory and resets the stored value.
Expand Down
18 changes: 18 additions & 0 deletions src/NSubstitute/Core/ThreadLocalContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,24 @@ public IList<IArgumentSpecification> DequeueAllArgumentSpecifications()
return queue;
}

/// <inheritdoc/>
public IList<IArgumentSpecification> PeekAllArgumentSpecifications()
{
var queue = _argumentSpecifications.Value;
if (queue == null) { throw new SubstituteInternalException("Argument specification queue is null."); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: do you know under what circumstances this occurs? If it is expected can we just return EmptySpecifications in that case?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh it doesn't look like it could ever be null. I decided to have/keep this in line with enqueue and dequeue methods in case it is something I failed to see

Should we change all 3 (along with enqueue, dequeue) methods to provide a consistent behavior?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only changed PeekAllArgumentSpecifications, but happy to update DequeueAllArgumentSpecifications if you think it makes sense


if (queue.Count > 0)
{
var items = new IArgumentSpecification[queue.Count];

queue.CopyTo(items, 0);

return items;
}

return EmptySpecifications;
}

public void SetPendingRaisingEventArgumentsFactory(Func<ICall, object?[]> getArguments)
{
_getArgumentsForRaisingEvent.Value = getArguments;
Expand Down
59 changes: 59 additions & 0 deletions src/NSubstitute/Extensions/ProtectedExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using System.Reflection;
using NSubstitute.Core;
using NSubstitute.Core.Arguments;

// Disable nullability for client API, so it does not affect clients.
#nullable disable annotations

namespace NSubstitute.Extensions;

public static class ProtectedExtensions
{
/// <summary>
/// Configure behavior for a protected method with return value
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The method arguments.</param>
/// <returns>Result object from the method invocation.</returns>
/// <exception cref="System.ArgumentNullException">Substitute - Cannot mock null object</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static object Protected<T>(this T obj, string methodName, params object[] args) where T : class
{
if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(x => x.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (nitpick): ideally should use a custom exception(subclassing SubstituteException) with details on what they should do to fix the issue. (see NSubstitute.Exceptions for examples.)

For obj == null, we have NullSubstituteReferenceException. Should also ensure receiver is a substitute otherwise throw NotASubstituteException.

For mthdInfo == null, maybe include arg types checked, something like "No method found with signature Foo(Int, String) on IMySubstitute. Check the method name and arguments are correct."

For non-virtual, can probably use a summary of the warning given in the documentation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, this makes sense

Copy link
Author

@Jason31569 Jason31569 Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the suggested changes and added additional unit tests to check the error conditions...which helped me find a problem

If you run these tests (in sequence) the second and third test will fail. Removing the first test, the other tests run successfully:

  1. Should_throw_on_mock_method_arg_mismatch
  2. Should_throw_on_mock_non_virtual
  3. Should_throw_on_mock_non_virtual_void_method

I tracked it down to orphaned arg spec from test 1 because ThreadLocalContext.DequeueAllArgumentSpecifications is not invoked when the method is invalid (even if I did not throw exception when method is not virtual). I am now dequeuing all arg specs before I throw exception. Not sure if this is the best approach?


return mthdInfo.Invoke(obj, args);
}

/// <summary>
/// Configure behavior for a protected method with no return vlaue
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="obj">The object.</param>
/// <param name="methodName">Name of the method.</param>
/// <param name="args">The method arguments.</param>
/// <returns>WhenCalled&lt;T&gt;.</returns>
/// <exception cref="System.ArgumentNullException">Substitute - Cannot mock null object</exception>
/// <exception cref="System.ArgumentException">Must provide valid protected method name to mock - methodName</exception>
public static WhenCalled<T> When<T>(this T obj, string methodName, params object[] args) where T : class
Jason31569 marked this conversation as resolved.
Show resolved Hide resolved
{
if (obj == null) { throw new ArgumentNullException(nameof(obj), "Cannot mock null object"); }
if (string.IsNullOrWhiteSpace(methodName)) { throw new ArgumentException("Must provide valid protected method name to mock", nameof(methodName)); }

IList<IArgumentSpecification> argTypes = SubstitutionContext.Current.ThreadContext.PeekAllArgumentSpecifications();
MethodInfo mthdInfo = obj.GetType().GetMethod(methodName, BindingFlags.NonPublic | BindingFlags.Instance, Type.DefaultBinder, argTypes.Select(y => y.ForType).ToArray(), null);

if (mthdInfo == null) { throw new Exception($"Method {methodName} not found"); }
if (!mthdInfo.IsVirtual) { throw new Exception($"Method {methodName} is not virtual"); }

return new WhenCalled<T>(SubstitutionContext.Current, obj, x => mthdInfo.Invoke(x, args), MatchArgs.AsSpecifiedInCall);
}
}
46 changes: 46 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/Infrastructure/AnotherClass.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
namespace NSubstitute.Acceptance.Specs.Infrastructure;

public abstract class AnotherClass
{
protected abstract string ProtectedMethod();

protected abstract string ProtectedMethod(int i);

protected abstract string ProtectedMethod(string msg, int i, char j);

protected abstract void ProtectedMethodWithNoReturn();

protected abstract void ProtectedMethodWithNoReturn(int i);

protected abstract void ProtectedMethodWithNoReturn(string msg, int i, char j);

public string DoWork()
{
return ProtectedMethod();
}

public string DoWork(int i)
{
return ProtectedMethod(i);
}

public string DoWork(string msg, int i, char j)
{
return ProtectedMethod(msg, i, j);
}

public void DoVoidWork()
{
ProtectedMethodWithNoReturn();
}

public void DoVoidWork(int i)
{
ProtectedMethodWithNoReturn(i);
}

public void DoVoidWork(string msg, int i, char j)
{
ProtectedMethodWithNoReturn(msg, i, j);
}
}
123 changes: 123 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/ProtectedExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using NSubstitute.Acceptance.Specs.Infrastructure;
using NSubstitute.Extensions;
using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs;

public class ProtectedExtensionsTests
{
[Test]
public void Should_mock_and_verify_protected_method_with_no_args()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod").Returns(expectedMsg);

Assert.That(worker.DoWork(sub), Is.EqualTo(expectedMsg));
sub.Received(1).Protected("ProtectedMethod");
}

[Test]
public void Should_mock_and_verify_protected_method_with_arg()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod", Arg.Any<int>()).Returns(expectedMsg);

Assert.That(worker.DoMoreWork(sub, 5), Is.EqualTo(expectedMsg));
var a = sub.Received(1);
a.Protected("ProtectedMethod", Arg.Any<int>());
}

[Test]
public void Should_mock_and_verify_protected_method_with_multiple_args()
{
var expectedMsg = "unit test message";
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.Protected("ProtectedMethod", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>()).Returns(expectedMsg);

Assert.That(worker.DoEvenMoreWork(sub, 3, 'x'), Is.EqualTo(expectedMsg));
sub.Received(1).Protected("ProtectedMethod", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

[Test]
public void Should_mock_and_verify_method_with_no_return_and_no_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn").Do(x => count++);

worker.DoVoidWork(sub);
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn");
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_arg()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn", Arg.Any<int>()).Do(x => count++);

worker.DoVoidWork(sub, 5);
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<int>());
}

[Test]
public void Should_mock_and_verify_method_with_no_return_with_multiple_args()
{
var count = 0;
var sub = Substitute.For<AnotherClass>();
var worker = new Worker();

sub.When("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>()).Do(x => count++);

worker.DoVoidWork(sub, 5, 'x');
Assert.That(count, Is.EqualTo(1));
sub.Received(1).Protected("ProtectedMethodWithNoReturn", Arg.Any<string>(), Arg.Any<int>(), Arg.Any<char>());
}

private class Worker
{
internal string DoWork(AnotherClass worker)
{
return worker.DoWork();
}

internal string DoMoreWork(AnotherClass worker, int i)
{
return worker.DoWork(i);
}

internal string DoEvenMoreWork(AnotherClass worker, int i, char j)
{
return worker.DoWork("worker", i, j);
}

internal void DoVoidWork(AnotherClass worker)
{
worker.DoVoidWork();
}

internal void DoVoidWork(AnotherClass worker, int i)
{
worker.DoVoidWork(i);
}

internal void DoVoidWork(AnotherClass worker, int i, char j)
{
worker.DoVoidWork("void worker", i, j);
}
}
}
Loading