Skip to content

Commit

Permalink
Add protected BaseResult() method to CallInfo.
Browse files Browse the repository at this point in the history
Create CallInfo<T> to calls that return results and expose `BaseResult`.
This gets messy to push the generic all the way through the code, so am
just using a cast in `Returns` extensions to handle this. This should be
safe as if we are in `Returns<T>` then the return value should be safe to
cast to a `T`.

Based off discussion here:
#622 (comment)
  • Loading branch information
dtchepak committed Jan 11, 2021
1 parent a4e2e6d commit 102eccc
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 26 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,6 @@ docs/_site/*

# Ignore VIM tmp files
*.swp

# kdiff/merge files
*.orig
17 changes: 16 additions & 1 deletion src/NSubstitute/Core/CallInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,25 @@ namespace NSubstitute.Core
public class CallInfo
{
private readonly Argument[] _callArguments;
private readonly Func<Maybe<object>> _baseResult;

public CallInfo(Argument[] callArguments)
public CallInfo(Argument[] callArguments, Func<Maybe<object>> baseResult)
{
_callArguments = callArguments;
_baseResult = baseResult;
}

protected CallInfo(CallInfo info) : this(info._callArguments, info._baseResult) {
}

/// <summary>
/// Call and returns the result from the base implementation of a substitute for a class.
/// Will throw an exception if no base implementation exists.
/// </summary>
/// <returns>Result from base implementation</returns>
/// <exception cref="NoBaseImplementationException">Throws in no base implementation exists</exception>
protected object GetBaseResult() {
return _baseResult().ValueOr(() => throw new NoBaseImplementationException());
}

/// <summary>
Expand Down
2 changes: 1 addition & 1 deletion src/NSubstitute/Core/CallInfoFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ public class CallInfoFactory : ICallInfoFactory
public CallInfo Create(ICall call)
{
var arguments = GetArgumentsFromCall(call);
return new CallInfo(arguments);
return new CallInfo(arguments, () => call.TryCallBase());
}

private static Argument[] GetArgumentsFromCall(ICall call)
Expand Down
16 changes: 16 additions & 0 deletions src/NSubstitute/Core/CallInfoWithReturns.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
namespace NSubstitute.Core
{
/// <summary>
/// Information for a call that returns a value of type <c>T</c>.
/// </summary>
/// <typeparam name="T"></typeparam>
public class CallInfo<T> : CallInfo
{
internal CallInfo(CallInfo info) : base(info) {
}

public T BaseResult() {
return (T)GetBaseResult();
}
}
}
31 changes: 16 additions & 15 deletions src/NSubstitute/Core/IReturn.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ public ReturnValue(object? value)

public class ReturnValueFromFunc<T> : IReturn
{
private readonly Func<CallInfo, T?> _funcToReturnValue;
private readonly Func<CallInfo<T>, T?> _funcToReturnValue;

public ReturnValueFromFunc(Func<CallInfo, T?>? funcToReturnValue)
public ReturnValueFromFunc(Func<CallInfo<T>, T?>? funcToReturnValue)
{
_funcToReturnValue = funcToReturnValue ?? ReturnNull();
}

public object? ReturnFor(CallInfo info) => _funcToReturnValue(info);
public Type TypeOrNull() => typeof (T);
public bool CanBeAssignedTo(Type t) => typeof (T).IsAssignableFrom(t);
public object? ReturnFor(CallInfo info) => _funcToReturnValue(new CallInfo<T>(info));
public Type TypeOrNull() => typeof(T);
public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t);

private static Func<CallInfo, T?> ReturnNull()
{
if (typeof(T).GetTypeInfo().IsValueType) throw new CannotReturnNullForValueType(typeof(T));
return x => default(T);
return x => default;
}
}

Expand All @@ -70,27 +70,28 @@ public ReturnMultipleValues(T?[] values)

public object? GetReturnValue() => GetNext();
public object? ReturnFor(CallInfo info) => GetReturnValue();
public Type TypeOrNull() => typeof (T);
public bool CanBeAssignedTo(Type t) => typeof (T).IsAssignableFrom(t);
public Type TypeOrNull() => typeof(T);
public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t);

private T? GetNext() => _valuesToReturn.TryDequeue(out var nextResult) ? nextResult : _lastValue;
}

public class ReturnMultipleFuncsValues<T> : IReturn
{
private readonly ConcurrentQueue<Func<CallInfo, T?>> _funcsToReturn;
private readonly Func<CallInfo, T?> _lastFunc;
private readonly ConcurrentQueue<Func<CallInfo<T>, T?>> _funcsToReturn;
private readonly Func<CallInfo<T>, T?> _lastFunc;

public ReturnMultipleFuncsValues(Func<CallInfo, T?>[] funcs)
public ReturnMultipleFuncsValues(Func<CallInfo<T>, T?>[] funcs)
{
_funcsToReturn = new ConcurrentQueue<Func<CallInfo, T?>>(funcs);
_funcsToReturn = new ConcurrentQueue<Func<CallInfo<T>, T?>>(funcs);
_lastFunc = funcs.Last();
}

public object? ReturnFor(CallInfo info) => GetNext(info);
public Type TypeOrNull() => typeof (T);
public bool CanBeAssignedTo(Type t) => typeof (T).IsAssignableFrom(t);
public Type TypeOrNull() => typeof(T);
public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t);

private T? GetNext(CallInfo info) => _funcsToReturn.TryDequeue(out var nextFunc) ? nextFunc(info) : _lastFunc(info);
private T? GetNext(CallInfo info) =>
_funcsToReturn.TryDequeue(out var nextFunc) ? nextFunc(new CallInfo<T>(info)) : _lastFunc(new CallInfo<T>(info));
}
}
11 changes: 11 additions & 0 deletions src/NSubstitute/Exceptions/NoBaseImplementationException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace NSubstitute.Exceptions
{
public class NoBaseImplementationException : SubstituteException
{
private const string Explanation =
"Cannot call the base method as the base method implementation is missing. " +
"You can call base method only if you create a class substitute and the method is not abstract.";

public NoBaseImplementationException() : base(Explanation) { }
}
}
4 changes: 2 additions & 2 deletions src/NSubstitute/SubstituteExtensions.Returns.Task.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static ConfiguredCall Returns<T>(this Task<T> value, Func<CallInfo, T> re
var wrappedFunc = WrapFuncInTask(returnThis);
var wrappedReturnThese = returnThese.Length > 0 ? returnThese.Select(WrapFuncInTask).ToArray() : null;

return ConfigureReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese);
return ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese);
}

/// <summary>
Expand Down Expand Up @@ -72,7 +72,7 @@ public static ConfiguredCall ReturnsForAnyArgs<T>(this Task<T> value, Func<CallI
var wrappedFunc = WrapFuncInTask(returnThis);
var wrappedReturnThese = returnThese.Length > 0 ? returnThese.Select(WrapFuncInTask).ToArray() : null;

return ConfigureReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese);
return ConfigureFuncReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese);
}

#nullable restore
Expand Down
4 changes: 2 additions & 2 deletions src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public static ConfiguredCall Returns<T>(this ValueTask<T> value, Func<CallInfo,
var wrappedFunc = WrapFuncInValueTask(returnThis);
var wrappedReturnThese = returnThese.Length > 0 ? returnThese.Select(WrapFuncInValueTask).ToArray() : null;

return ConfigureReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese);
return ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese);
}

/// <summary>
Expand Down Expand Up @@ -72,7 +72,7 @@ public static ConfiguredCall ReturnsForAnyArgs<T>(this ValueTask<T> value, Func<
var wrappedFunc = WrapFuncInValueTask(returnThis);
var wrappedReturnThese = returnThese.Length > 0 ? returnThese.Select(WrapFuncInValueTask).ToArray() : null;

return ConfigureReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese);
return ConfigureFuncReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese);
}

#nullable restore
Expand Down
10 changes: 5 additions & 5 deletions src/NSubstitute/SubstituteExtensions.Returns.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ public static partial class SubstituteExtensions
/// <param name="value"></param>
/// <param name="returnThis">Function to calculate the return value</param>
/// <param name="returnThese">Optionally use these functions next</param>
public static ConfiguredCall Returns<T>(this T value, Func<CallInfo, T> returnThis, params Func<CallInfo, T>[] returnThese) =>
ConfigureReturn(MatchArgs.AsSpecifiedInCall, returnThis, returnThese);
public static ConfiguredCall Returns<T>(this T value, Func<CallInfo<T>, T> returnThis, params Func<CallInfo<T>, T>[] returnThese) =>
ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, returnThis, returnThese);

/// <summary>
/// Set a return value for this call made with any arguments.
Expand All @@ -43,8 +43,8 @@ public static partial class SubstituteExtensions
/// <param name="returnThis">Function to calculate the return value</param>
/// <param name="returnThese">Optionally use these functions next</param>
/// <returns></returns>
public static ConfiguredCall ReturnsForAnyArgs<T>(this T value, Func<CallInfo, T> returnThis, params Func<CallInfo, T>[] returnThese) =>
ConfigureReturn(MatchArgs.Any, returnThis, returnThese);
public static ConfiguredCall ReturnsForAnyArgs<T>(this T value, Func<CallInfo<T>, T> returnThis, params Func<CallInfo<T>, T>[] returnThese) =>
ConfigureFuncReturn(MatchArgs.Any, returnThis, returnThese);

#nullable restore
private static ConfiguredCall ConfigureReturn<T>(MatchArgs matchArgs, T? returnThis, T?[]? returnThese)
Expand All @@ -64,7 +64,7 @@ private static ConfiguredCall ConfigureReturn<T>(MatchArgs matchArgs, T? returnT
.LastCallShouldReturn(returnValue, matchArgs);
}

private static ConfiguredCall ConfigureReturn<T>(MatchArgs matchArgs, Func<CallInfo, T?> returnThis, Func<CallInfo, T?>[]? returnThese)
private static ConfiguredCall ConfigureFuncReturn<T>(MatchArgs matchArgs, Func<CallInfo<T>, T?> returnThis, Func<CallInfo<T>, T?>[]? returnThese)
{
IReturn returnValue;
if (returnThese == null || returnThese.Length == 0)
Expand Down
52 changes: 52 additions & 0 deletions tests/NSubstitute.Acceptance.Specs/ReturnFromBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System;
using NSubstitute.Exceptions;
using NUnit.Framework;

namespace NSubstitute.Acceptance.Specs
{
public class ReturnFromBase
{
public class Sample
{
public virtual string RepeatButLouder(string s) => s + "!";
public virtual void VoidMethod() { }
}

public abstract class SampleWithAbstractMethod
{
public abstract string NoBaseImplementation();
}

public interface ISample
{
string InterfaceMethod();
}

[Test]
public void UseBaseInReturn() {
var sub = Substitute.For<Sample>();
sub.RepeatButLouder(Arg.Any<string>()).Returns(x => x.BaseResult() + "?");

Assert.AreEqual("Hi!?", sub.RepeatButLouder("Hi"));
}

[Test]
public void CallWithNoBaseImplementation() {
var sub = Substitute.For<SampleWithAbstractMethod>();
sub.NoBaseImplementation().Returns(x => x.BaseResult());

Assert.Throws<NoBaseImplementationException>(() =>
sub.NoBaseImplementation()
);
}

[Test]
public void CallBaseForInterface() {
var sub = Substitute.For<ISample>();
sub.InterfaceMethod().Returns(x => x.BaseResult());
Assert.Throws<NoBaseImplementationException>(() =>
sub.InterfaceMethod()
);
}
}
}

0 comments on commit 102eccc

Please sign in to comment.