Skip to content

Commit

Permalink
Merge pull request #1566 from rabbitmq/rabbitmq-dotnet-client-1356
Browse files Browse the repository at this point in the history
Enforce max message size with mutiple content frames
  • Loading branch information
lukebakken committed May 16, 2024
2 parents bf9a35a + c63959f commit e52d703
Show file tree
Hide file tree
Showing 16 changed files with 120 additions and 84 deletions.
12 changes: 6 additions & 6 deletions projects/RabbitMQ.Client/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ const RabbitMQ.Client.AmqpTcpEndpoint.DefaultAmqpSslPort = 5671 -> int
const RabbitMQ.Client.AmqpTcpEndpoint.UseDefaultPort = -1 -> int
const RabbitMQ.Client.ConnectionFactory.DefaultChannelMax = 2047 -> ushort
const RabbitMQ.Client.ConnectionFactory.DefaultFrameMax = 0 -> uint
const RabbitMQ.Client.ConnectionFactory.DefaultMaxMessageSize = 134217728 -> uint
const RabbitMQ.Client.ConnectionFactory.DefaultMaxInboundMessageBodySize = 67108864 -> uint
const RabbitMQ.Client.ConnectionFactory.DefaultPass = "guest" -> string
const RabbitMQ.Client.ConnectionFactory.DefaultUser = "guest" -> string
const RabbitMQ.Client.ConnectionFactory.DefaultVHost = "/" -> string
const RabbitMQ.Client.ConnectionFactory.MaximumMaxMessageSize = 536870912 -> uint
const RabbitMQ.Client.Constants.AccessRefused = 403 -> int
const RabbitMQ.Client.Constants.ChannelError = 504 -> int
const RabbitMQ.Client.Constants.CommandInvalid = 503 -> int
Expand Down Expand Up @@ -82,14 +81,13 @@ RabbitMQ.Client.AmqpTcpEndpoint.AddressFamily.set -> void
RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint() -> void
RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint(string hostName, int portOrMinusOne = -1) -> void
RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint(string hostName, int portOrMinusOne, RabbitMQ.Client.SslOption ssl) -> void
RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint(string hostName, int portOrMinusOne, RabbitMQ.Client.SslOption ssl, uint maxMessageSize) -> void
RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint(System.Uri uri) -> void
RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint(System.Uri uri, RabbitMQ.Client.SslOption ssl) -> void
RabbitMQ.Client.AmqpTcpEndpoint.Clone() -> object
RabbitMQ.Client.AmqpTcpEndpoint.CloneWithHostname(string hostname) -> RabbitMQ.Client.AmqpTcpEndpoint
RabbitMQ.Client.AmqpTcpEndpoint.HostName.get -> string
RabbitMQ.Client.AmqpTcpEndpoint.HostName.set -> void
RabbitMQ.Client.AmqpTcpEndpoint.MaxMessageSize.get -> uint
RabbitMQ.Client.AmqpTcpEndpoint.MaxInboundMessageBodySize.get -> uint
RabbitMQ.Client.AmqpTcpEndpoint.Port.get -> int
RabbitMQ.Client.AmqpTcpEndpoint.Port.set -> void
RabbitMQ.Client.AmqpTcpEndpoint.Protocol.get -> RabbitMQ.Client.IProtocol
Expand Down Expand Up @@ -225,8 +223,8 @@ RabbitMQ.Client.ConnectionFactory.HandshakeContinuationTimeout.get -> System.Tim
RabbitMQ.Client.ConnectionFactory.HandshakeContinuationTimeout.set -> void
RabbitMQ.Client.ConnectionFactory.HostName.get -> string
RabbitMQ.Client.ConnectionFactory.HostName.set -> void
RabbitMQ.Client.ConnectionFactory.MaxMessageSize.get -> uint
RabbitMQ.Client.ConnectionFactory.MaxMessageSize.set -> void
RabbitMQ.Client.ConnectionFactory.MaxInboundMessageBodySize.get -> uint
RabbitMQ.Client.ConnectionFactory.MaxInboundMessageBodySize.set -> void
RabbitMQ.Client.ConnectionFactory.NetworkRecoveryInterval.get -> System.TimeSpan
RabbitMQ.Client.ConnectionFactory.NetworkRecoveryInterval.set -> void
RabbitMQ.Client.ConnectionFactory.Password.get -> string
Expand Down Expand Up @@ -787,6 +785,7 @@ readonly RabbitMQ.Client.ConnectionConfig.HandshakeContinuationTimeout -> System
readonly RabbitMQ.Client.ConnectionConfig.HeartbeatInterval -> System.TimeSpan
readonly RabbitMQ.Client.ConnectionConfig.MaxChannelCount -> ushort
readonly RabbitMQ.Client.ConnectionConfig.MaxFrameSize -> uint
readonly RabbitMQ.Client.ConnectionConfig.MaxInboundMessageBodySize -> uint
readonly RabbitMQ.Client.ConnectionConfig.NetworkRecoveryInterval -> System.TimeSpan
readonly RabbitMQ.Client.ConnectionConfig.Password -> string
readonly RabbitMQ.Client.ConnectionConfig.RequestedConnectionTimeout -> System.TimeSpan
Expand Down Expand Up @@ -884,6 +883,7 @@ virtual RabbitMQ.Client.TcpClientAdapter.ReceiveTimeout.set -> void
~const RabbitMQ.Client.RabbitMQActivitySource.PublisherSourceName = "RabbitMQ.Client.Publisher" -> string
~const RabbitMQ.Client.RabbitMQActivitySource.SubscriberSourceName = "RabbitMQ.Client.Subscriber" -> string
~override RabbitMQ.Client.Events.EventingBasicConsumer.HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag, bool redelivered, string exchange, string routingKey, RabbitMQ.Client.ReadOnlyBasicProperties properties, System.ReadOnlyMemory<byte> body) -> System.Threading.Tasks.Task
~RabbitMQ.Client.AmqpTcpEndpoint.AmqpTcpEndpoint(string hostName, int portOrMinusOne, RabbitMQ.Client.SslOption ssl, uint maxInboundMessageBodySize) -> void
~RabbitMQ.Client.ConnectionFactory.CreateConnectionAsync(RabbitMQ.Client.IEndpointResolver endpointResolver, string clientProvidedName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.IConnection>
~RabbitMQ.Client.ConnectionFactory.CreateConnectionAsync(string clientProvidedName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.IConnection>
~RabbitMQ.Client.ConnectionFactory.CreateConnectionAsync(System.Collections.Generic.IEnumerable<RabbitMQ.Client.AmqpTcpEndpoint> endpoints, string clientProvidedName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<RabbitMQ.Client.IConnection>
Expand Down
23 changes: 12 additions & 11 deletions projects/RabbitMQ.Client/client/api/AmqpTcpEndpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,23 @@ public class AmqpTcpEndpoint

private int _port;

private readonly uint _maxMessageSize;
private readonly uint _maxInboundMessageBodySize;

/// <summary>
/// Creates a new instance of the <see cref="AmqpTcpEndpoint"/>.
/// </summary>
/// <param name="hostName">Hostname.</param>
/// <param name="portOrMinusOne"> Port number. If the port number is -1, the default port number will be used.</param>
/// <param name="ssl">Ssl option.</param>
/// <param name="maxMessageSize">Maximum message size from RabbitMQ. <see cref="ConnectionFactory.MaximumMaxMessageSize"/>. It defaults to
/// MaximumMaxMessageSize if the parameter is greater than MaximumMaxMessageSize.</param>
public AmqpTcpEndpoint(string hostName, int portOrMinusOne, SslOption ssl, uint maxMessageSize)
/// <param name="maxInboundMessageBodySize">Maximum message size from RabbitMQ.</param>
public AmqpTcpEndpoint(string hostName, int portOrMinusOne, SslOption ssl,
uint maxInboundMessageBodySize)
{
HostName = hostName;
_port = portOrMinusOne;
Ssl = ssl;
_maxMessageSize = Math.Min(maxMessageSize, ConnectionFactory.MaximumMaxMessageSize);
_maxInboundMessageBodySize = Math.Min(maxInboundMessageBodySize,
InternalConstants.DefaultRabbitMqMaxInboundMessageBodySize);
}

/// <summary>
Expand All @@ -87,7 +88,7 @@ public AmqpTcpEndpoint(string hostName, int portOrMinusOne, SslOption ssl, uint
/// <param name="portOrMinusOne"> Port number. If the port number is -1, the default port number will be used.</param>
/// <param name="ssl">Ssl option.</param>
public AmqpTcpEndpoint(string hostName, int portOrMinusOne, SslOption ssl) :
this(hostName, portOrMinusOne, ssl, ConnectionFactory.DefaultMaxMessageSize)
this(hostName, portOrMinusOne, ssl, ConnectionFactory.DefaultMaxInboundMessageBodySize)
{
}

Expand Down Expand Up @@ -134,7 +135,7 @@ public AmqpTcpEndpoint(Uri uri) : this(uri.Host, uri.Port)
/// <returns>A copy with the same hostname, port, and TLS settings</returns>
public object Clone()
{
return new AmqpTcpEndpoint(HostName, _port, Ssl, _maxMessageSize);
return new AmqpTcpEndpoint(HostName, _port, Ssl, _maxInboundMessageBodySize);
}

/// <summary>
Expand All @@ -144,7 +145,7 @@ public object Clone()
/// <returns>A copy with the provided hostname and port/TLS settings of this endpoint</returns>
public AmqpTcpEndpoint CloneWithHostname(string hostname)
{
return new AmqpTcpEndpoint(hostname, _port, Ssl, _maxMessageSize);
return new AmqpTcpEndpoint(hostname, _port, Ssl, _maxInboundMessageBodySize);
}

/// <summary>
Expand Down Expand Up @@ -195,11 +196,11 @@ public IProtocol Protocol

/// <summary>
/// Get the maximum size for a message in bytes.
/// The default value is defined in <see cref="ConnectionFactory.DefaultMaxMessageSize"/>.
/// The default value is defined in <see cref="ConnectionFactory.DefaultMaxInboundMessageBodySize"/>.
/// </summary>
public uint MaxMessageSize
public uint MaxInboundMessageBodySize
{
get { return _maxMessageSize; }
get { return _maxInboundMessageBodySize; }
}

/// <summary>
Expand Down
8 changes: 7 additions & 1 deletion projects/RabbitMQ.Client/client/api/ConnectionConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ public sealed class ConnectionConfig
/// </summary>
public readonly uint MaxFrameSize;

/// <summary>
/// Maximum body size of a message (in bytes).
/// </summary>
public readonly uint MaxInboundMessageBodySize;

/// <summary>
/// Set to false to make automatic connection recovery not recover topology (exchanges, queues, bindings, etc).
/// </summary>
Expand Down Expand Up @@ -149,7 +154,7 @@ public sealed class ConnectionConfig
ICredentialsProvider credentialsProvider, ICredentialsRefresher credentialsRefresher,
IEnumerable<IAuthMechanismFactory> authMechanisms,
IDictionary<string, object?> clientProperties, string? clientProvidedName,
ushort maxChannelCount, uint maxFrameSize, bool topologyRecoveryEnabled,
ushort maxChannelCount, uint maxFrameSize, uint maxInboundMessageBodySize, bool topologyRecoveryEnabled,
TopologyRecoveryFilter topologyRecoveryFilter, TopologyRecoveryExceptionHandler topologyRecoveryExceptionHandler,
TimeSpan networkRecoveryInterval, TimeSpan heartbeatInterval, TimeSpan continuationTimeout, TimeSpan handshakeContinuationTimeout, TimeSpan requestedConnectionTimeout,
bool dispatchConsumersAsync, int dispatchConsumerConcurrency,
Expand All @@ -165,6 +170,7 @@ public sealed class ConnectionConfig
ClientProvidedName = clientProvidedName;
MaxChannelCount = maxChannelCount;
MaxFrameSize = maxFrameSize;
MaxInboundMessageBodySize = maxInboundMessageBodySize;
TopologyRecoveryEnabled = topologyRecoveryEnabled;
TopologyRecoveryFilter = topologyRecoveryFilter;
TopologyRecoveryExceptionHandler = topologyRecoveryExceptionHandler;
Expand Down
21 changes: 8 additions & 13 deletions projects/RabbitMQ.Client/client/api/ConnectionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace RabbitMQ.Client
/// factory.VirtualHost = ConnectionFactory.DefaultVHost;
/// factory.HostName = hostName;
/// factory.Port = AmqpTcpEndpoint.UseDefaultPort;
/// factory.MaxMessageSize = 512 * 1024 * 1024;
/// factory.MaxInboundMessageBodySize = 512 * 1024 * 1024;
/// //
/// IConnection conn = factory.CreateConnection();
/// //
Expand Down Expand Up @@ -107,15 +107,9 @@ public sealed class ConnectionFactory : ConnectionFactoryBase, IConnectionFactor
public const uint DefaultFrameMax = 0;

/// <summary>
/// Default value for <code>ConnectionFactory</code>'s <code>MaxMessageSize</code>.
/// Default value for <code>ConnectionFactory</code>'s <code>MaxInboundMessageBodySize</code>.
/// </summary>
public const uint DefaultMaxMessageSize = 134217728;
/// <summary>
/// Largest message size, in bytes, allowed in RabbitMQ.
/// Note: <code>rabbit.max_message_size</code> setting (https://www.rabbitmq.com/configure.html)
/// configures the largest message size which should be lower than this maximum of 536 Mbs.
/// </summary>
public const uint MaximumMaxMessageSize = 536870912;
public const uint DefaultMaxInboundMessageBodySize = 1_048_576 * 64;

/// <summary>
/// Default value for desired heartbeat interval. Default is 60 seconds,
Expand Down Expand Up @@ -291,13 +285,13 @@ public ConnectionFactory()
/// </summary>
public AmqpTcpEndpoint Endpoint
{
get { return new AmqpTcpEndpoint(HostName, Port, Ssl, MaxMessageSize); }
get { return new AmqpTcpEndpoint(HostName, Port, Ssl, MaxInboundMessageBodySize); }
set
{
Port = value.Port;
HostName = value.HostName;
Ssl = value.Ssl;
MaxMessageSize = value.MaxMessageSize;
MaxInboundMessageBodySize = value.MaxInboundMessageBodySize;
}
}

Expand Down Expand Up @@ -359,7 +353,7 @@ public AmqpTcpEndpoint Endpoint
/// Maximum allowed message size, in bytes, from RabbitMQ.
/// Corresponds to the <code>ConnectionFactory.DefaultMaxMessageSize</code> setting.
/// </summary>
public uint MaxMessageSize { get; set; } = DefaultMaxMessageSize;
public uint MaxInboundMessageBodySize { get; set; } = DefaultMaxInboundMessageBodySize;

/// <summary>
/// The uri to use for the connection.
Expand Down Expand Up @@ -484,7 +478,7 @@ public IAuthMechanismFactory AuthMechanismFactory(IEnumerable<string> argServerM
public Task<IConnection> CreateConnectionAsync(IEnumerable<string> hostnames, string clientProvidedName,
CancellationToken cancellationToken = default)
{
IEnumerable<AmqpTcpEndpoint> endpoints = hostnames.Select(h => new AmqpTcpEndpoint(h, Port, Ssl, MaxMessageSize));
IEnumerable<AmqpTcpEndpoint> endpoints = hostnames.Select(h => new AmqpTcpEndpoint(h, Port, Ssl, MaxInboundMessageBodySize));
return CreateConnectionAsync(EndpointResolverFactory(endpoints), clientProvidedName, cancellationToken);
}

Expand Down Expand Up @@ -602,6 +596,7 @@ private ConnectionConfig CreateConfig(string clientProvidedName)
clientProvidedName,
RequestedChannelMax,
RequestedFrameMax,
MaxInboundMessageBodySize,
TopologyRecoveryEnabled,
TopologyRecoveryFilter,
TopologyRecoveryExceptionHandler,
Expand Down
7 changes: 7 additions & 0 deletions projects/RabbitMQ.Client/client/api/InternalConstants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,12 @@ internal static class InternalConstants
{
internal static readonly TimeSpan DefaultConnectionAbortTimeout = TimeSpan.FromSeconds(5);
internal static readonly TimeSpan DefaultConnectionCloseTimeout = TimeSpan.FromSeconds(30);

/// <summary>
/// Largest message size, in bytes, allowed in RabbitMQ.
/// Note: <code>rabbit.max_message_size</code> setting (https://www.rabbitmq.com/configure.html)
/// configures the largest message size which should be lower than this maximum of 128MiB.
/// </summary>
internal const uint DefaultRabbitMqMaxInboundMessageBodySize = 1_048_576 * 128;
}
}
12 changes: 11 additions & 1 deletion projects/RabbitMQ.Client/client/impl/CommandAssembler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ internal sealed class CommandAssembler
private int _offset;
private AssemblyState _state;

public CommandAssembler()
private readonly uint _maxBodyLength;

public CommandAssembler(uint maxBodyLength)
{
_maxBodyLength = maxBodyLength;
Reset();
}

Expand Down Expand Up @@ -152,6 +155,13 @@ private bool ParseHeaderFrame(in InboundFrame frame)
{
throw new UnexpectedFrameException(frame.Type);
}

if (totalBodyBytes > _maxBodyLength)
{
string msg = $"Frame body size '{totalBodyBytes}' exceeds maximum of '{_maxBodyLength}' bytes";
throw new MalformedFrameException(message: msg, canShutdownCleanly: false);
}

_rentedHeaderArray = totalBodyBytes != 0 ? frame.TakeoverPayload() : Array.Empty<byte>();

_headerMemory = frame.Payload.Slice(12);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ await FinishCloseAsync(cancellationToken)
}

ushort channelMax = (ushort)NegotiatedMaxValue(_config.MaxChannelCount, connectionTune.m_channelMax);
_sessionManager = new SessionManager(this, channelMax);
_sessionManager = new SessionManager(this, channelMax, _config.MaxInboundMessageBodySize);

uint frameMax = NegotiatedMaxValue(_config.MaxFrameSize, connectionTune.m_frameMax);
FrameMax = frameMax;
Expand Down
4 changes: 2 additions & 2 deletions projects/RabbitMQ.Client/client/impl/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ internal Connection(ConnectionConfig config, IFrameHandler frameHandler)
_connectionUnblockedWrapper = new EventingWrapper<EventArgs>("OnConnectionUnblocked", onException);
_connectionShutdownWrapper = new EventingWrapper<ShutdownEventArgs>("OnShutdown", onException);

_sessionManager = new SessionManager(this, 0);
_session0 = new MainSession(this);
_sessionManager = new SessionManager(this, 0, config.MaxInboundMessageBodySize);
_session0 = new MainSession(this, config.MaxInboundMessageBodySize);
_channel0 = new Channel(_config, _session0); ;

ClientProperties = new Dictionary<string, object?>(_config.ClientProperties)
Expand Down
17 changes: 10 additions & 7 deletions projects/RabbitMQ.Client/client/impl/Frame.cs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ private static void ProcessProtocolHeader(ReadOnlySequence<byte> buffer)
}
}

internal static async ValueTask<InboundFrame> ReadFromPipeAsync(PipeReader reader, uint maxMessageSize,
internal static async ValueTask<InboundFrame> ReadFromPipeAsync(PipeReader reader,
uint maxInboundMessageBodySize,
CancellationToken mainLoopCancellationToken)
{
ReadResult result = await reader.ReadAsync(mainLoopCancellationToken)
Expand All @@ -266,7 +267,7 @@ private static void ProcessProtocolHeader(ReadOnlySequence<byte> buffer)

InboundFrame frame;
// Loop until we have enough data to read an entire frame, or until the pipe is completed.
while (!TryReadFrame(ref buffer, maxMessageSize, out frame))
while (!TryReadFrame(ref buffer, maxInboundMessageBodySize, out frame))
{
reader.AdvanceTo(buffer.Start, buffer.End);

Expand All @@ -283,15 +284,16 @@ private static void ProcessProtocolHeader(ReadOnlySequence<byte> buffer)
return frame;
}

internal static bool TryReadFrameFromPipe(PipeReader reader, uint maxMessageSize, out InboundFrame frame)
internal static bool TryReadFrameFromPipe(PipeReader reader,
uint maxInboundMessageBodySize, out InboundFrame frame)
{
if (reader.TryRead(out ReadResult result))
{
ReadOnlySequence<byte> buffer = result.Buffer;

MaybeThrowEndOfStream(result, buffer);

if (TryReadFrame(ref buffer, maxMessageSize, out frame))
if (TryReadFrame(ref buffer, maxInboundMessageBodySize, out frame))
{
reader.AdvanceTo(buffer.Start);
return true;
Expand All @@ -306,7 +308,8 @@ internal static bool TryReadFrameFromPipe(PipeReader reader, uint maxMessageSize
return false;
}

internal static bool TryReadFrame(ref ReadOnlySequence<byte> buffer, uint maxMessageSize, out InboundFrame frame)
internal static bool TryReadFrame(ref ReadOnlySequence<byte> buffer,
uint maxInboundMessageBodySize, out InboundFrame frame)
{
if (buffer.Length < 7)
{
Expand All @@ -332,9 +335,9 @@ internal static bool TryReadFrame(ref ReadOnlySequence<byte> buffer, uint maxMes
FrameType type = (FrameType)firstByte;
int channel = NetworkOrderDeserializer.ReadUInt16(buffer.Slice(1));
int payloadSize = NetworkOrderDeserializer.ReadInt32(buffer.Slice(3));
if ((maxMessageSize > 0) && (payloadSize > maxMessageSize))
if ((maxInboundMessageBodySize > 0) && (payloadSize > maxInboundMessageBodySize))
{
string msg = $"Frame payload size '{payloadSize}' exceeds maximum of '{maxMessageSize}' bytes";
string msg = $"Frame payload size '{payloadSize}' exceeds maximum of '{maxInboundMessageBodySize}' bytes";
throw new MalformedFrameException(message: msg, canShutdownCleanly: false);
}

Expand Down

0 comments on commit e52d703

Please sign in to comment.