Skip to content

Commit

Permalink
Feature: Enable call forwarding and substitution for non virtual meth…
Browse files Browse the repository at this point in the history
…ods or sealed classes implementing an interface.

How to use:
var substitute = Substitute.ForTypeForwardingTo <ISomeInterface,SomeImplementation>(argsList);

In this case, it doesn't matter if methods are virtual or not; it will intercept all calls since we will be working with an interface all the time.
For
Limitations:

Overriding virtual methods effectively replaces its implementation both for internal and external calls. With this implementation Nsubstitute will only intercept calls made by client classes using the interface. Calls made from inside the object itself to it's own method, will hit the actual implementation.
  • Loading branch information
marcoregueira committed Oct 29, 2022
1 parent a04dc32 commit e88249f
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 93 deletions.
27 changes: 27 additions & 0 deletions src/NSubstitute/Exceptions/TypeForwardingException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using System;

namespace NSubstitute.Exceptions
{
public abstract class TypeForwardingException : SubstituteException
{
protected TypeForwardingException(string message) : base(message) { }
}

public sealed class CanNotForwardCallsToClassNotImplementingInterfaceException : TypeForwardingException
{
public CanNotForwardCallsToClassNotImplementingInterfaceException(Type type) : base(DescribeProblem(type)) { }
private static string DescribeProblem(Type type)
{
return string.Format("The provided class '{0}' doesn't implement all requested interfaces. ", type.Name);
}
}

public sealed class CanNotForwardCallsToAbstractClassException : TypeForwardingException
{
public CanNotForwardCallsToAbstractClassException(Type type) : base(DescribeProblem(type)) { }
private static string DescribeProblem(Type type)
{
return string.Format("The provided class '{0}' is abstract. ", type.Name);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ private static void VerifyClassImplementsAllInterfaces(Type classType, IEnumerab
{
if (!additionalInterfaces.All(x => x.GetTypeInfo().IsAssignableFrom(classType.GetTypeInfo())))
{
throw new SubstituteException("The provided class doesn't implement all requested interfaces.");
throw new CanNotForwardCallsToClassNotImplementingInterfaceException(classType);
}
}

private static void VerifyClassIsNotAbstract(Type classType)
{
if (classType.GetTypeInfo().IsAbstract)
{
throw new SubstituteException("The provided class is abstract.");
throw new CanNotForwardCallsToAbstractClassException(classType);
}
}

Expand Down
10 changes: 9 additions & 1 deletion src/NSubstitute/Substitute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,19 @@ public static T ForPartsOf<T>(params object[] constructorArguments)
return (T) substituteFactory.CreatePartial(new[] {typeof (T)}, constructorArguments);
}

public static T ForPartsOf<T,TClass>(params object[] constructorArguments)
public static T ForTypeForwardingTo<T,TClass>(params object[] constructorArguments)
where T : class
{
var substituteFactory = SubstitutionContext.Current.SubstituteFactory;
return (T)substituteFactory.CreatePartial(new[] { typeof(T), typeof(TClass) }, constructorArguments);
}

//public static T ForTypeForwardingTo<T, T2, T3>(params object[] constructorArguments)
// where T : class
//{
// var substituteFactory = SubstitutionContext.Current.SubstituteFactory;
// return (T)substituteFactory.CreatePartial(new[] { typeof(T), typeof(TClass) }, constructorArguments);
//}

}
}
90 changes: 0 additions & 90 deletions tests/NSubstitute.Acceptance.Specs/PartialSubs.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
using System;

using NSubstitute.Core;
using NSubstitute.Exceptions;
using NSubstitute.Extensions;

using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs
Expand Down Expand Up @@ -89,63 +87,6 @@ public void UseImplementedVirtualMethod()
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
}


[Test]
public void UseImplementedNonVirtualMethod()
{
var testAbstractClass = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>();
Assert.That(testAbstractClass.MethodReturnsSameInt(1), Is.EqualTo(1));
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
testAbstractClass.Received().MethodReturnsSameInt(1);
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
}

[Test]
public void UseSubstitutedNonVirtualMethod()
{
var testInterface = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>();
testInterface.Configure().MethodReturnsSameInt(1).Returns(2);
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2));
Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(3));
testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default);
Assert.That(testInterface.CalledTimes, Is.EqualTo(1));
}

[Test]
public void UseSubstitutedNonVirtualMethodHonorsDoNotCallBase()
{
var testInterface = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>();
testInterface.Configure().MethodReturnsSameInt(1).Returns(2);
testInterface.WhenForAnyArgs(x => x.MethodReturnsSameInt(default)).DoNotCallBase();
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2));
Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(0));
testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default);
Assert.That(testInterface.CalledTimes, Is.EqualTo(0));
}

[Test]
public void PartialSubstituteCallsConstructorWithParameters()
{
var testInterface = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>(50);
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(1));
Assert.That(testInterface.CalledTimes, Is.EqualTo(51));
}

[Test]
public void PartialSubstituteFailsIfClassDoesntImplementInterface()
{
Assert.Throws<SubstituteException>(
() => Substitute.ForPartsOf<ITestInterface, TestAbstractClass>());
}


[Test]
public void PartialSubstituteFailsIfClassIsAbstract()
{
Assert.Throws<SubstituteException>(
() => Substitute.ForPartsOf<ITestInterface, TestAbstractClassWithInterface>(), "The provided class is abstract.");
}

[Test]
public void ReturnDefaultForUnimplementedAbstractMethod()
{
Expand Down Expand Up @@ -366,39 +307,8 @@ public void ShouldThrowExceptionIfConfigureGlobalCallBaseForDelegateProxy()

public interface ITestInterface
{
public int CalledTimes { get; set; }

void VoidTestMethod();
int TestMethodReturnsInt();
int MethodReturnsSameInt(int i);
}

public class TestNonVirtualClass : ITestInterface
{
public TestNonVirtualClass() { }
public TestNonVirtualClass(int initialCounter) => CalledTimes = initialCounter;

public int CalledTimes { get; set; }

public int TestMethodReturnsInt() => throw new NotImplementedException();

public void VoidTestMethod() => throw new NotImplementedException();
public int MethodReturnsSameInt(int i)
{
CalledTimes++;
return i;
}
}

public abstract class TestAbstractClassWithInterface : ITestInterface
{
public int CalledTimes { get; set; }

public abstract int MethodReturnsSameInt(int i);

public abstract int TestMethodReturnsInt();

public abstract void VoidTestMethod();
}

public abstract class TestAbstractClass
Expand Down
109 changes: 109 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/TypeForwarding.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using System;

using NSubstitute.Core;
using NSubstitute.Exceptions;
using NSubstitute.Extensions;

using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs
{
public class TypeForwarding
{
[Test]
public void UseImplementedNonVirtualMethod()
{
var testAbstractClass = Substitute.ForTypeForwardingTo<ITestInterface, TestSealedNonVirtualClass>();
Assert.That(testAbstractClass.MethodReturnsSameInt(1), Is.EqualTo(1));
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
testAbstractClass.Received().MethodReturnsSameInt(1);
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
}

[Test]
public void UseSubstitutedNonVirtualMethod()
{
var testInterface = Substitute.ForTypeForwardingTo<ITestInterface, TestSealedNonVirtualClass>();
testInterface.Configure().MethodReturnsSameInt(1).Returns(2);
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2));
Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(3));
testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default);
Assert.That(testInterface.CalledTimes, Is.EqualTo(1));
}

[Test]
public void UseSubstitutedNonVirtualMethodHonorsDoNotCallBase()
{
var testInterface = Substitute.ForTypeForwardingTo<ITestInterface, TestSealedNonVirtualClass>();
testInterface.Configure().MethodReturnsSameInt(1).Returns(2);
testInterface.WhenForAnyArgs(x => x.MethodReturnsSameInt(default)).DoNotCallBase();
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2));
Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(0));
testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default);
Assert.That(testInterface.CalledTimes, Is.EqualTo(0));
}

[Test]
public void PartialSubstituteCallsConstructorWithParameters()
{
var testInterface = Substitute.ForTypeForwardingTo<ITestInterface, TestSealedNonVirtualClass>(50);
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(1));
Assert.That(testInterface.CalledTimes, Is.EqualTo(51));
}

[Test]
public void PartialSubstituteFailsIfClassDoesntImplementInterface()
{
Assert.Throws<CanNotForwardCallsToClassNotImplementingInterfaceException>(
() => Substitute.ForTypeForwardingTo<ITestInterface, TestRandomConcreteClass>());
}

[Test]
public void PartialSubstituteFailsIfClassIsAbstract()
{
Assert.Throws<CanNotForwardCallsToAbstractClassException>(
() => Substitute.ForTypeForwardingTo<ITestInterface, TestAbstractClassWithInterface>(), "The provided class is abstract.");
}

public interface ITestInterface
{
public int CalledTimes { get; set; }

void VoidTestMethod();
int TestMethodReturnsInt();
int MethodReturnsSameInt(int i);
}

public sealed class TestSealedNonVirtualClass : ITestInterface
{
public TestSealedNonVirtualClass(int initialCounter) => CalledTimes = initialCounter;
public TestSealedNonVirtualClass() { }

public int CalledTimes { get; set; }

public int TestMethodReturnsInt() => throw new NotImplementedException();

public void VoidTestMethod() => throw new NotImplementedException();
public int MethodReturnsSameInt(int i)
{
CalledTimes++;
return i;
}
}

public abstract class TestAbstractClassWithInterface : ITestInterface
{
public int CalledTimes { get; set; }

public abstract int MethodReturnsSameInt(int i);

public abstract int TestMethodReturnsInt();

public abstract void VoidTestMethod();
}

public class TestRandomConcreteClass { }

public abstract class TestAbstractClass { }
}
}

0 comments on commit e88249f

Please sign in to comment.