Skip to content

Commit

Permalink
Merge pull request #4640 from RenderMichael/awaitadapter
Browse files Browse the repository at this point in the history
Properly handle generic ValueTask await adapter
  • Loading branch information
stevenaw committed Feb 27, 2024
2 parents 1b91bdd + 5af6158 commit 331944e
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/NUnitFramework/framework/Internal/AwaitAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ public static AwaitAdapter FromAwaitable(object? awaitable)
if (awaitable is System.Threading.Tasks.ValueTask valueTask)
return ValueTaskAwaitAdapter.Create(valueTask);

// Await all the (C# and F#) things
var adapter =
CSharpPatternBasedAwaitAdapter.TryCreate(awaitable)
var adapter = ValueTaskAwaitAdapter.TryCreate(awaitable)
// Await all the (C# and F#) things
?? CSharpPatternBasedAwaitAdapter.TryCreate(awaitable)
?? FSharpAsyncAwaitAdapter.TryCreate(awaitable);
if (adapter is not null)
return adapter;
Expand Down
21 changes: 10 additions & 11 deletions src/NUnitFramework/framework/Internal/ValueTaskAwaitAdapter.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Charlie Poole, Rob Prouse and Contributors. MIT License - see LICENSE.txt

using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

Expand All @@ -11,21 +10,21 @@ internal static class ValueTaskAwaitAdapter
{
public static AwaitAdapter Create(ValueTask task)
{
var genericValueTaskType = task
.GetType()
.TypeAndBaseTypes()
.FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ValueTask<>));
return new NonGenericAdapter(task);
}

if (genericValueTaskType is not null)
public static AwaitAdapter? TryCreate(object task)
{
Type taskType = task.GetType();
if (taskType.IsGenericType && taskType.GetGenericTypeDefinition() == typeof(ValueTask<>))
{
var typeArgument = genericValueTaskType.GetGenericArguments()[0];
return (AwaitAdapter)typeof(GenericAdapter<>)
.MakeGenericType(typeArgument)
.GetConstructor(new[] { typeof(ValueTask<>).MakeGenericType(typeArgument) })!
.Invoke(new object[] { task });
.MakeGenericType(taskType.GetGenericArguments()[0])
.GetConstructor(new[] { taskType })!
.Invoke(new object[] { task });
}

return new NonGenericAdapter(task);
return null;
}

private sealed class NonGenericAdapter : AwaitAdapter
Expand Down
56 changes: 56 additions & 0 deletions src/NUnitFramework/testdata/AwaitableReturnTypeFixture.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,34 @@ public void OnCompleted(Action continuation)
}
}

public ValueTask ReturnsValueTask()
{
_workload.BeforeReturningAwaitable();
_workload.BeforeReturningAwaiter();

var source = new TaskCompletionSource<object?>();

var complete = new Action(() =>
{
try
{
_workload.GetResult();
source.SetResult(null);
}
catch (Exception ex)
{
source.SetException(ex);
}
});

if (_workload.IsCompleted)
complete.Invoke();
else
_workload.OnCompleted(complete);

return new ValueTask(source.Task);
}

#endregion

#region Non-void result
Expand Down Expand Up @@ -228,6 +256,34 @@ public Task<object> ReturnsNonVoidResultTask()
return source.Task;
}

[Test(ExpectedResult = 42)]
public ValueTask<object> ReturnsNonVoidResultValueTask()
{
_workload.BeforeReturningAwaitable();
_workload.BeforeReturningAwaiter();

var source = new TaskCompletionSource<object>();

var complete = new Action(() =>
{
try
{
source.SetResult(_workload.GetResult());
}
catch (Exception ex)
{
source.SetException(ex);
}
});

if (_workload.IsCompleted)
complete.Invoke();
else
_workload.OnCompleted(complete);

return new ValueTask<object>(source.Task);
}

[Test(ExpectedResult = 42)]
public NonVoidResultCustomTask ReturnsNonVoidResultCustomTask()
{
Expand Down
1 change: 1 addition & 0 deletions src/NUnitFramework/tests/AwaitableReturnTypeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace NUnit.Framework.Tests
[TestFixture(nameof(F.ReturnsCustomAwaitable))]
[TestFixture(nameof(F.ReturnsCustomAwaitableWithImplicitOnCompleted))]
[TestFixture(nameof(F.ReturnsCustomAwaitableWithImplicitUnsafeOnCompleted))]
[TestFixture(nameof(F.ReturnsValueTask))]
public class AwaitableReturnTypeTests
{
private readonly string _methodName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace NUnit.Framework.Tests
[TestFixture(nameof(F.ReturnsNonVoidResultCustomAwaitable))]
[TestFixture(nameof(F.ReturnsNonVoidResultCustomAwaitableWithImplicitOnCompleted))]
[TestFixture(nameof(F.ReturnsNonVoidResultCustomAwaitableWithImplicitUnsafeOnCompleted))]
[TestFixture(nameof(F.ReturnsNonVoidResultValueTask))]
public sealed class NonVoidResultAwaitableReturnTypeTests : AwaitableReturnTypeTests
{
public NonVoidResultAwaitableReturnTypeTests(string methodName)
Expand Down

0 comments on commit 331944e

Please sign in to comment.