Skip to content

Commit

Permalink
* Add more use of CancellationToken in Async methods. (#1468)
Browse files Browse the repository at this point in the history
* Correctly dispose of `CancellationTokenSource` and `CancellationTokenRegistration` instances.
* Refactor `Connection.Close` to use async internally.
* Fix test by adding `WaitAsync` that also takes a timeout.
* Add `ConfigureAwait` where it was missing.
* Always create `CancellationTokenSource` for recovery, and dispose it.
* Modify `WaitAsync` `Task` extension to see if `Task` has already completed.
* Don't swallow exceptions unless `abort` is specified.
* Add `TaskCreationOptions` to two spots.
* Add `SetSessionClosingAsync`
* Use `CancellationToken` to stop receieve loop.
* Pass the main loop `CancellationToken` into `HardProtocolExceptionHandlerAsync`.
* Pass `CancellationToken` to `IFrameHandler.CloseAsync`.
* Remove remaining usage of `ThreadPool`
  • Loading branch information
lukebakken committed Feb 1, 2024
1 parent 220f5a5 commit 1fa0562
Show file tree
Hide file tree
Showing 42 changed files with 949 additions and 710 deletions.
29 changes: 0 additions & 29 deletions projects/RabbitMQ.Client/FrameworkExtension/Interlocked.cs

This file was deleted.

2 changes: 1 addition & 1 deletion projects/RabbitMQ.Client/PublicAPI.Unshipped.txt
Expand Up @@ -918,7 +918,7 @@ virtual RabbitMQ.Client.TcpClientAdapter.ReceiveTimeout.set -> void
~RabbitMQ.Client.IChannel.TxCommitAsync() -> System.Threading.Tasks.Task
~RabbitMQ.Client.IChannel.TxRollbackAsync() -> System.Threading.Tasks.Task
~RabbitMQ.Client.IChannel.TxSelectAsync() -> System.Threading.Tasks.Task
~RabbitMQ.Client.IConnection.CloseAsync(ushort reasonCode, string reasonText, System.TimeSpan timeout, bool abort) -> System.Threading.Tasks.Task
~RabbitMQ.Client.IConnection.CloseAsync(ushort reasonCode, string reasonText, System.TimeSpan timeout, bool abort, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task
~RabbitMQ.Client.IConnection.CreateChannelAsync() -> System.Threading.Tasks.Task<RabbitMQ.Client.IChannel>
~RabbitMQ.Client.IConnection.UpdateSecretAsync(string newSecret, string reason) -> System.Threading.Tasks.Task
~RabbitMQ.Client.IConnectionFactory.CreateConnectionAsync(string clientProvidedName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.IConnection>
Expand Down
70 changes: 63 additions & 7 deletions projects/RabbitMQ.Client/client/TaskExtensions.cs
Expand Up @@ -53,14 +53,67 @@ public static bool IsCompletedSuccessfully(this Task task)
private static readonly TaskContinuationOptions s_tco = TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously;
private static void IgnoreTaskContinuation(Task t, object s) => t.Exception.Handle(e => true);

public static async Task WithCancellation(this Task task, CancellationToken cancellationToken)
// https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/
public static Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
if (task.IsCompletedSuccessfully())
{
return task;
}
else
{
return DoWaitWithTimeoutAsync(task, timeout, cancellationToken);
}
}

private static async Task DoWaitWithTimeoutAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
using var timeoutTokenCts = new CancellationTokenSource(timeout);
CancellationToken timeoutToken = timeoutTokenCts.Token;

var linkedTokenTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, cancellationToken);
using CancellationTokenRegistration cancellationTokenRegistration =
linkedCts.Token.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true),
state: linkedTokenTcs, useSynchronizationContext: false);

if (task != await Task.WhenAny(task, linkedTokenTcs.Task).ConfigureAwait(false))
{
task.Ignore();
if (timeoutToken.IsCancellationRequested)
{
throw new OperationCanceledException($"Operation timed out after {timeout}");
}
else
{
throw new OperationCanceledException(cancellationToken);
}
}

// https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
await task.ConfigureAwait(false);
}

// https://devblogs.microsoft.com/pfxteam/how-do-i-cancel-non-cancelable-async-operations/
public static Task WaitAsync(this Task task, CancellationToken cancellationToken)
{
if (task.IsCompletedSuccessfully())
{
if (task != await Task.WhenAny(task, tcs.Task))
return task;
}
else
{
return DoWaitAsync(task, cancellationToken);
}
}

private static async Task DoWaitAsync(this Task task, CancellationToken cancellationToken)
{
var cancellationTokenTcs = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true),
state: cancellationTokenTcs, useSynchronizationContext: false))
{
if (task != await Task.WhenAny(task, cancellationTokenTcs.Task).ConfigureAwait(false))
{
task.Ignore();
throw new OperationCanceledException(cancellationToken);
Expand Down Expand Up @@ -172,10 +225,13 @@ public static T EnsureCompleted<T>(this ValueTask<T> task)

public static void EnsureCompleted(this ValueTask task)
{
task.GetAwaiter().GetResult();
if (false == task.IsCompletedSuccessfully)
{
task.GetAwaiter().GetResult();
}
}

#if !NET6_0_OR_GREATER
#if NETSTANDARD
// https://github.com/dotnet/runtime/issues/23878
// https://github.com/dotnet/runtime/issues/23878#issuecomment-1398958645
public static void Ignore(this Task task)
Expand Down
1 change: 1 addition & 0 deletions projects/RabbitMQ.Client/client/api/ConnectionFactory.cs
Expand Up @@ -618,6 +618,7 @@ private ConnectionConfig CreateConfig(string clientProvidedName)
internal async Task<IFrameHandler> CreateFrameHandlerAsync(
AmqpTcpEndpoint endpoint, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
IFrameHandler fh = new SocketFrameHandler(endpoint, SocketFactory, RequestedConnectionTimeout, SocketReadTimeout, SocketWriteTimeout);
await fh.ConnectAsync(cancellationToken)
.ConfigureAwait(false);
Expand Down
6 changes: 4 additions & 2 deletions projects/RabbitMQ.Client/client/api/IConnection.cs
Expand Up @@ -31,6 +31,7 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using RabbitMQ.Client.Events;
using RabbitMQ.Client.Exceptions;
Expand Down Expand Up @@ -222,9 +223,10 @@ public interface IConnection : INetworkConnection, IDisposable
/// </summary>
/// <param name="reasonCode">The close code (See under "Reply Codes" in the AMQP 0-9-1 specification).</param>
/// <param name="reasonText">A message indicating the reason for closing the connection.</param>
/// <param name="timeout">Operation timeout.</param>
/// <param name="timeout"></param>
/// <param name="abort">Whether or not this close is an abort (ignores certain exceptions).</param>
Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort);
/// <param name="cancellationToken"></param>
Task CloseAsync(ushort reasonCode, string reasonText, TimeSpan timeout, bool abort, CancellationToken cancellationToken = default);

/// <summary>
/// Asynchronously create and return a fresh channel, session, and channel.
Expand Down
24 changes: 16 additions & 8 deletions projects/RabbitMQ.Client/client/api/IConnectionExtensions.cs
Expand Up @@ -20,7 +20,8 @@ public static class IConnectionExtensions
/// </remarks>
public static Task CloseAsync(this IConnection connection)
{
return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", InternalConstants.DefaultConnectionCloseTimeout, false);
return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", InternalConstants.DefaultConnectionCloseTimeout, false,
CancellationToken.None);
}

/// <summary>
Expand All @@ -38,7 +39,8 @@ public static Task CloseAsync(this IConnection connection)
/// </remarks>
public static Task CloseAsync(this IConnection connection, ushort reasonCode, string reasonText)
{
return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionCloseTimeout, false);
return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionCloseTimeout, false,
CancellationToken.None);
}

/// <summary>
Expand All @@ -58,7 +60,8 @@ public static Task CloseAsync(this IConnection connection, ushort reasonCode, st
/// </remarks>
public static Task CloseAsync(this IConnection connection, TimeSpan timeout)
{
return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", timeout, false);
return connection.CloseAsync(Constants.ReplySuccess, "Goodbye", timeout, false,
CancellationToken.None);
}

/// <summary>
Expand All @@ -80,7 +83,8 @@ public static Task CloseAsync(this IConnection connection, TimeSpan timeout)
/// </remarks>
public static Task CloseAsync(this IConnection connection, ushort reasonCode, string reasonText, TimeSpan timeout)
{
return connection.CloseAsync(reasonCode, reasonText, timeout, false);
return connection.CloseAsync(reasonCode, reasonText, timeout, false,
CancellationToken.None);
}

/// <summary>
Expand All @@ -94,7 +98,8 @@ public static Task CloseAsync(this IConnection connection, ushort reasonCode, st
/// </remarks>
public static Task AbortAsync(this IConnection connection)
{
return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", InternalConstants.DefaultConnectionAbortTimeout, true);
return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", InternalConstants.DefaultConnectionAbortTimeout, true,
CancellationToken.None);
}

/// <summary>
Expand All @@ -112,7 +117,8 @@ public static Task AbortAsync(this IConnection connection)
/// </remarks>
public static Task AbortAsync(this IConnection connection, ushort reasonCode, string reasonText)
{
return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionAbortTimeout, true);
return connection.CloseAsync(reasonCode, reasonText, InternalConstants.DefaultConnectionAbortTimeout, true,
CancellationToken.None);
}

/// <summary>
Expand All @@ -130,7 +136,8 @@ public static Task AbortAsync(this IConnection connection, ushort reasonCode, st
/// </remarks>
public static Task AbortAsync(this IConnection connection, TimeSpan timeout)
{
return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", timeout, true);
return connection.CloseAsync(Constants.ReplySuccess, "Connection close forced", timeout, true,
CancellationToken.None);
}

/// <summary>
Expand All @@ -149,7 +156,8 @@ public static Task AbortAsync(this IConnection connection, TimeSpan timeout)
/// </remarks>
public static Task AbortAsync(this IConnection connection, ushort reasonCode, string reasonText, TimeSpan timeout)
{
return connection.CloseAsync(reasonCode, reasonText, timeout, true);
return connection.CloseAsync(reasonCode, reasonText, timeout, true,
CancellationToken.None);
}
}
}
Expand Up @@ -45,6 +45,7 @@ public static class EndpointResolverExtensions
var exceptions = new List<Exception>();
foreach (AmqpTcpEndpoint ep in resolver.All())
{
cancellationToken.ThrowIfCancellationRequested();
try
{
t = await selector(ep, cancellationToken).ConfigureAwait(false);
Expand Down
2 changes: 1 addition & 1 deletion projects/RabbitMQ.Client/client/api/TcpClientAdapter.cs
Expand Up @@ -29,7 +29,7 @@ public virtual Task ConnectAsync(IPAddress ep, int port, CancellationToken cance
#else
public virtual Task ConnectAsync(IPAddress ep, int port, CancellationToken cancellationToken = default)
{
return _sock.ConnectAsync(ep, port).WithCancellation(cancellationToken);
return _sock.ConnectAsync(ep, port).WaitAsync(cancellationToken);
}
#endif

Expand Down
Expand Up @@ -92,7 +92,8 @@ public override void HandleBasicConsumeOk(string consumerTag)
BasicDeliverEventArgs eventArgs = new BasicDeliverEventArgs(consumerTag, deliveryTag, redelivered, exchange, routingKey, properties, body);
using (Activity activity = RabbitMQActivitySource.SubscriberHasListeners ? RabbitMQActivitySource.Deliver(eventArgs) : default)
{
await base.HandleBasicDeliverAsync(consumerTag, deliveryTag, redelivered, exchange, routingKey, properties, body);
await base.HandleBasicDeliverAsync(consumerTag, deliveryTag, redelivered, exchange, routingKey, properties, body)
.ConfigureAwait(false);
Received?.Invoke(this, eventArgs);
}
}
Expand Down
11 changes: 8 additions & 3 deletions projects/RabbitMQ.Client/client/framing/Channel.cs
Expand Up @@ -29,6 +29,8 @@
// Copyright (c) 2007-2020 VMware, Inc. All rights reserved.
//---------------------------------------------------------------------------

using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using RabbitMQ.Client.client.framing;
using RabbitMQ.Client.Impl;
Expand Down Expand Up @@ -69,19 +71,22 @@ public override void _Private_ConnectionCloseOk()
public override ValueTask BasicAckAsync(ulong deliveryTag, bool multiple)
{
var method = new BasicAck(deliveryTag, multiple);
return ModelSendAsync(method);
// TODO cancellation token?
return ModelSendAsync(method, CancellationToken.None);
}

public override ValueTask BasicNackAsync(ulong deliveryTag, bool multiple, bool requeue)
{
var method = new BasicNack(deliveryTag, multiple, requeue);
return ModelSendAsync(method);
// TODO use cancellation token
return ModelSendAsync(method, CancellationToken.None);
}

public override Task BasicRejectAsync(ulong deliveryTag, bool requeue)
{
var method = new BasicReject(deliveryTag, requeue);
return ModelSendAsync(method).AsTask();
// TODO cancellation token?
return ModelSendAsync(method, CancellationToken.None).AsTask();
}

protected override bool DispatchAsynchronous(in IncomingCommand cmd)
Expand Down
31 changes: 27 additions & 4 deletions projects/RabbitMQ.Client/client/impl/AsyncRpcContinuations.cs
Expand Up @@ -44,8 +44,8 @@ namespace RabbitMQ.Client.Impl
internal abstract class AsyncRpcContinuation<T> : IRpcContinuation, IDisposable
{
private readonly CancellationTokenSource _cancellationTokenSource;
private readonly CancellationTokenRegistration _cancellationTokenRegistration;
private readonly ConfiguredTaskAwaitable<T> _tcsConfiguredTaskAwaitable;

protected readonly TaskCompletionSource<T> _tcs = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);

private bool _disposedValue;
Expand All @@ -59,21 +59,43 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout)
*/
_cancellationTokenSource = new CancellationTokenSource(continuationTimeout);

_cancellationTokenSource.Token.Register(() =>
#if NET6_0_OR_GREATER
_cancellationTokenRegistration = _cancellationTokenSource.Token.UnsafeRegister((object state) =>
{
var tcs = (TaskCompletionSource<T>)state;
if (tcs.TrySetCanceled())
{
// TODO LRB rabbitmq/rabbitmq-dotnet-client#1347
// Cancellation was successful, does this mean we should set a TimeoutException
// in the same manner as BlockingCell?
}
}, _tcs);
#else
_cancellationTokenRegistration = _cancellationTokenSource.Token.Register((object state) =>
{
if (_tcs.TrySetCanceled())
var tcs = (TaskCompletionSource<T>)state;
if (tcs.TrySetCanceled())
{
// TODO LRB rabbitmq/rabbitmq-dotnet-client#1347
// Cancellation was successful, does this mean we should set a TimeoutException
// in the same manner as BlockingCell?
}
}, useSynchronizationContext: false);
}, state: _tcs, useSynchronizationContext: false);
#endif

_tcsConfiguredTaskAwaitable = _tcs.Task.ConfigureAwait(false);
}

internal DateTime StartTime { get; } = DateTime.UtcNow;

public CancellationToken CancellationToken
{
get
{
return _cancellationTokenSource.Token;
}
}

public ConfiguredTaskAwaitable<T>.ConfiguredTaskAwaiter GetAwaiter()
{
return _tcsConfiguredTaskAwaitable.GetAwaiter();
Expand All @@ -92,6 +114,7 @@ protected virtual void Dispose(bool disposing)
{
if (disposing)
{
_cancellationTokenRegistration.Dispose();
_cancellationTokenSource.Dispose();
}

Expand Down

0 comments on commit 1fa0562

Please sign in to comment.