Skip to content

Commit

Permalink
Feature: Extend PartsOf to mock non-virtual methods implementing an i…
Browse files Browse the repository at this point in the history
…nterface.

How to use:
var substitute = Substitute.ForPartsOf<ISomeInterface,SomeImplementation>(argsList);

In this case, it doesn't matter if methods are virtual or not; it will intercept all calls since we will be working with an interface all the time.

Limitations:

Overriding virtual methods effectively replaces its implementation both for internal and external calls. With this implementation Nsubstitute will only intercept calls made by client classes using the interface. Calls made from inside the object itself to it's own method, will hit the actual implementation.
  • Loading branch information
marcoregueira committed Jul 22, 2022
1 parent fed59f6 commit a04dc32
Show file tree
Hide file tree
Showing 9 changed files with 205 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/NSubstitute/Core/IProxyFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ namespace NSubstitute.Core
{
public interface IProxyFactory
{
object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments);
object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments);
}
}
9 changes: 5 additions & 4 deletions src/NSubstitute/Core/SubstituteFactory.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using System.Reflection;

using NSubstitute.Exceptions;

namespace NSubstitute.Core
Expand All @@ -26,7 +27,7 @@ public SubstituteFactory(ISubstituteStateFactory substituteStateFactory, ICallRo
/// <returns></returns>
public object Create(Type[] typesToProxy, object?[] constructorArguments)
{
return Create(typesToProxy, constructorArguments, callBaseByDefault: false);
return Create(typesToProxy, constructorArguments, callBaseByDefault: false, isPartial: false);
}

/// <summary>
Expand All @@ -45,10 +46,10 @@ public object CreatePartial(Type[] typesToProxy, object?[] constructorArguments)
throw new CanNotPartiallySubForInterfaceOrDelegateException(primaryProxyType);
}

return Create(typesToProxy, constructorArguments, callBaseByDefault: true);
return Create(typesToProxy, constructorArguments, callBaseByDefault: true, isPartial: true);
}

private object Create(Type[] typesToProxy, object?[] constructorArguments, bool callBaseByDefault)
private object Create(Type[] typesToProxy, object?[] constructorArguments, bool callBaseByDefault, bool isPartial)
{
var substituteState = _substituteStateFactory.Create(this);
substituteState.CallBaseConfiguration.CallBaseByDefault = callBaseByDefault;
Expand All @@ -58,7 +59,7 @@ private object Create(Type[] typesToProxy, object?[] constructorArguments, bool

var callRouter = _callRouterFactory.Create(substituteState, canConfigureBaseCalls);
var additionalTypes = typesToProxy.Where(x => x != primaryProxyType).ToArray();
var proxy = _proxyFactory.GenerateProxy(callRouter, primaryProxyType, additionalTypes, constructorArguments);
var proxy = _proxyFactory.GenerateProxy(callRouter, primaryProxyType, additionalTypes, isPartial, constructorArguments);
return proxy;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

using Castle.DynamicProxy;

using NSubstitute.Core;
using NSubstitute.Exceptions;

Expand All @@ -22,14 +25,14 @@ public CastleDynamicProxyFactory(ICallFactory callFactory, IArgumentSpecificatio
_allMethodsExceptCallRouterCallsHook = new AllMethodsExceptCallRouterCallsHook();
}

public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments)
public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments)
{
return typeToProxy.IsDelegate()
? GenerateDelegateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments)
: GenerateTypeProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments);
: GenerateTypeProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments);
}

private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments)
private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments)
{
VerifyClassHasNotBeenPassedAsAnAdditionalInterface(additionalInterfaces);

Expand All @@ -42,8 +45,9 @@ private object GenerateTypeProxy(ICallRouter callRouter, Type typeToProxy, Type[
typeToProxy,
additionalInterfaces,
constructorArguments,
new IInterceptor[] {proxyIdInterceptor, forwardingInterceptor},
proxyGenerationOptions);
new IInterceptor[] { proxyIdInterceptor, forwardingInterceptor },
proxyGenerationOptions,
isPartial);

forwardingInterceptor.SwitchToFullDispatchMode();
return proxy;
Expand All @@ -65,8 +69,9 @@ private object GenerateDelegateProxy(ICallRouter callRouter, Type delegateType,
typeToProxy: typeof(object),
additionalInterfaces: null,
constructorArguments: null,
interceptors: new IInterceptor[] {proxyIdInterceptor, forwardingInterceptor},
proxyGenerationOptions);
interceptors: new IInterceptor[] { proxyIdInterceptor, forwardingInterceptor },
proxyGenerationOptions,
isPartial: false);

forwardingInterceptor.SwitchToFullDispatchMode();

Expand All @@ -87,8 +92,13 @@ private CastleForwardingInterceptor CreateForwardingInterceptor(ICallRouter call
private object CreateProxyUsingCastleProxyGenerator(Type typeToProxy, Type[]? additionalInterfaces,
object?[]? constructorArguments,
IInterceptor[] interceptors,
ProxyGenerationOptions proxyGenerationOptions)
ProxyGenerationOptions proxyGenerationOptions,
bool isPartial)
{
if (isPartial)
return CreatePartialProxy(typeToProxy, additionalInterfaces, constructorArguments, interceptors, proxyGenerationOptions, isPartial);


if (typeToProxy.GetTypeInfo().IsInterface)
{
VerifyNoConstructorArgumentsGivenForInterface(constructorArguments);
Expand All @@ -108,13 +118,40 @@ private CastleForwardingInterceptor CreateForwardingInterceptor(ICallRouter call
additionalInterfaces = interfaces;
}


return _proxyGenerator.CreateClassProxy(typeToProxy,
additionalInterfaces,
proxyGenerationOptions,
constructorArguments,
interceptors);
}

private object CreatePartialProxy(Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments, IInterceptor[] interceptors, ProxyGenerationOptions proxyGenerationOptions, bool isPartial)
{
if (typeToProxy.GetTypeInfo().IsClass &&
additionalInterfaces != null &&
additionalInterfaces.Any())
{
VerifyClassIsNotAbstract(typeToProxy);
VerifyClassImplementsAllInterfaces(typeToProxy, additionalInterfaces);

var targetObject = Activator.CreateInstance(typeToProxy, constructorArguments);
typeToProxy = additionalInterfaces.First();

return _proxyGenerator.CreateInterfaceProxyWithTarget(typeToProxy,
additionalInterfaces,
target: targetObject,
options: proxyGenerationOptions,
interceptors: interceptors);
}

return _proxyGenerator.CreateClassProxy(typeToProxy,
additionalInterfaces,
proxyGenerationOptions,
constructorArguments,
interceptors);
}

private ProxyGenerationOptions GetOptionsToMixinCallRouterProvider(ICallRouter callRouter)
{
var options = new ProxyGenerationOptions(_allMethodsExceptCallRouterCallsHook);
Expand All @@ -128,6 +165,22 @@ private ProxyGenerationOptions GetOptionsToMixinCallRouterProvider(ICallRouter c
return options;
}

private static void VerifyClassImplementsAllInterfaces(Type classType, IEnumerable<Type> additionalInterfaces)
{
if (!additionalInterfaces.All(x => x.GetTypeInfo().IsAssignableFrom(classType.GetTypeInfo())))
{
throw new SubstituteException("The provided class doesn't implement all requested interfaces.");
}
}

private static void VerifyClassIsNotAbstract(Type classType)
{
if (classType.GetTypeInfo().IsAbstract)
{
throw new SubstituteException("The provided class is abstract.");
}
}

private static void VerifyNoConstructorArgumentsGivenForInterface(object?[]? constructorArguments)
{
if (HasItems(constructorArguments))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ public virtual ICall Map(IInvocation castleInvocation)
Func<object>? baseMethod = null;
if (castleInvocation.InvocationTarget != null &&
castleInvocation.MethodInvocationTarget.IsVirtual &&
!castleInvocation.MethodInvocationTarget.IsAbstract &&
!castleInvocation.MethodInvocationTarget.IsFinal)
!castleInvocation.MethodInvocationTarget.IsAbstract)
{
baseMethod = CreateBaseResultInvocation(castleInvocation);
}
Expand Down
4 changes: 2 additions & 2 deletions src/NSubstitute/Proxies/DelegateProxy/DelegateProxyFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ public DelegateProxyFactory(CastleDynamicProxyFactory objectProxyFactory)
_castleObjectProxyFactory = objectProxyFactory ?? throw new ArgumentNullException(nameof(objectProxyFactory));
}

public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments)
public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments)
{
// Castle factory can now resolve delegate proxies as well.
return _castleObjectProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments);
return _castleObjectProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments);
}
}
}
9 changes: 5 additions & 4 deletions src/NSubstitute/Proxies/ProxyFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Reflection;

using NSubstitute.Core;

namespace NSubstitute.Proxies
Expand All @@ -16,12 +17,12 @@ public ProxyFactory(IProxyFactory delegateFactory, IProxyFactory dynamicProxyFac
_dynamicProxyFactory = dynamicProxyFactory;
}

public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, object?[]? constructorArguments)
public object GenerateProxy(ICallRouter callRouter, Type typeToProxy, Type[]? additionalInterfaces, bool isPartial, object?[]? constructorArguments)
{
var isDelegate = typeToProxy.IsDelegate();
return isDelegate
? _delegateFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments)
: _dynamicProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, constructorArguments);
return isDelegate
? _delegateFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments)
: _dynamicProxyFactory.GenerateProxy(callRouter, typeToProxy, additionalInterfaces, isPartial, constructorArguments);
}
}
}
7 changes: 7 additions & 0 deletions src/NSubstitute/Substitute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,5 +90,12 @@ public static T ForPartsOf<T>(params object[] constructorArguments)
var substituteFactory = SubstitutionContext.Current.SubstituteFactory;
return (T) substituteFactory.CreatePartial(new[] {typeof (T)}, constructorArguments);
}

public static T ForPartsOf<T,TClass>(params object[] constructorArguments)
where T : class
{
var substituteFactory = SubstitutionContext.Current.SubstituteFactory;
return (T)substituteFactory.CreatePartial(new[] { typeof(T), typeof(TClass) }, constructorArguments);
}
}
}
90 changes: 90 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/PartialSubs.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;

using NSubstitute.Core;
using NSubstitute.Exceptions;
using NSubstitute.Extensions;

using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs
Expand Down Expand Up @@ -87,6 +89,63 @@ public void UseImplementedVirtualMethod()
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
}


[Test]
public void UseImplementedNonVirtualMethod()
{
var testAbstractClass = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>();
Assert.That(testAbstractClass.MethodReturnsSameInt(1), Is.EqualTo(1));
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
testAbstractClass.Received().MethodReturnsSameInt(1);
Assert.That(testAbstractClass.CalledTimes, Is.EqualTo(1));
}

[Test]
public void UseSubstitutedNonVirtualMethod()
{
var testInterface = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>();
testInterface.Configure().MethodReturnsSameInt(1).Returns(2);
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2));
Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(3));
testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default);
Assert.That(testInterface.CalledTimes, Is.EqualTo(1));
}

[Test]
public void UseSubstitutedNonVirtualMethodHonorsDoNotCallBase()
{
var testInterface = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>();
testInterface.Configure().MethodReturnsSameInt(1).Returns(2);
testInterface.WhenForAnyArgs(x => x.MethodReturnsSameInt(default)).DoNotCallBase();
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(2));
Assert.That(testInterface.MethodReturnsSameInt(3), Is.EqualTo(0));
testInterface.ReceivedWithAnyArgs(2).MethodReturnsSameInt(default);
Assert.That(testInterface.CalledTimes, Is.EqualTo(0));
}

[Test]
public void PartialSubstituteCallsConstructorWithParameters()
{
var testInterface = Substitute.ForPartsOf<ITestInterface, TestNonVirtualClass>(50);
Assert.That(testInterface.MethodReturnsSameInt(1), Is.EqualTo(1));
Assert.That(testInterface.CalledTimes, Is.EqualTo(51));
}

[Test]
public void PartialSubstituteFailsIfClassDoesntImplementInterface()
{
Assert.Throws<SubstituteException>(
() => Substitute.ForPartsOf<ITestInterface, TestAbstractClass>());
}


[Test]
public void PartialSubstituteFailsIfClassIsAbstract()
{
Assert.Throws<SubstituteException>(
() => Substitute.ForPartsOf<ITestInterface, TestAbstractClassWithInterface>(), "The provided class is abstract.");
}

[Test]
public void ReturnDefaultForUnimplementedAbstractMethod()
{
Expand Down Expand Up @@ -307,8 +366,39 @@ public void ShouldThrowExceptionIfConfigureGlobalCallBaseForDelegateProxy()

public interface ITestInterface
{
public int CalledTimes { get; set; }

void VoidTestMethod();
int TestMethodReturnsInt();
int MethodReturnsSameInt(int i);
}

public class TestNonVirtualClass : ITestInterface
{
public TestNonVirtualClass() { }
public TestNonVirtualClass(int initialCounter) => CalledTimes = initialCounter;

public int CalledTimes { get; set; }

public int TestMethodReturnsInt() => throw new NotImplementedException();

public void VoidTestMethod() => throw new NotImplementedException();
public int MethodReturnsSameInt(int i)
{
CalledTimes++;
return i;
}
}

public abstract class TestAbstractClassWithInterface : ITestInterface
{
public int CalledTimes { get; set; }

public abstract int MethodReturnsSameInt(int i);

public abstract int TestMethodReturnsInt();

public abstract void VoidTestMethod();
}

public abstract class TestAbstractClass
Expand Down

0 comments on commit a04dc32

Please sign in to comment.