Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for cluster keyspace notifications to subscriber #1536

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/StackExchange.Redis/RedisChannel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Linq;
using System.Text;

namespace StackExchange.Redis
Expand All @@ -8,8 +9,12 @@ namespace StackExchange.Redis
/// </summary>
public readonly struct RedisChannel : IEquatable<RedisChannel>
{
private static readonly byte[] __keyBytes = Encoding.UTF8.GetBytes("__key");
private static readonly byte[] __KEYBytes = Encoding.UTF8.GetBytes("__KEY");

internal readonly byte[] Value;
internal readonly bool IsPatternBased;
internal readonly bool IsKeyspaceChannel;

/// <summary>
/// Indicates whether the channel-name is either null or a zero-length value
Expand All @@ -36,6 +41,15 @@ private RedisChannel(byte[] value, bool isPatternBased)
{
Value = value;
IsPatternBased = isPatternBased;
if (value != null && value.Length >= __keyBytes.Length)
{
var prefix = new ArraySegment<byte>(value, 0, 5);
IsKeyspaceChannel = prefix.SequenceEqual(__keyBytes) || prefix.SequenceEqual(__KEYBytes);
}
else
{
IsKeyspaceChannel = false;
}
}

private static bool DeterminePatternBased(byte[] value, PatternMode mode)
Expand Down
131 changes: 103 additions & 28 deletions src/StackExchange.Redis/RedisSubscriber.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading;
Expand Down Expand Up @@ -191,7 +192,7 @@ internal sealed class Subscription
{
private Action<RedisChannel, RedisValue> _handlers;
private ChannelMessageQueue _queues;
private ServerEndPoint owner;
private readonly HashSet<ServerEndPoint> owners = new HashSet<ServerEndPoint>();

public void Add(Action<RedisChannel, RedisValue> handler, ChannelMessageQueue queue)
{
Expand Down Expand Up @@ -221,14 +222,51 @@ public bool Remove(Action<RedisChannel, RedisValue> handler, ChannelMessageQueue

public Task SubscribeToServer(ConnectionMultiplexer multiplexer, in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{
var selected = multiplexer.SelectServer(RedisCommand.SUBSCRIBE, flags, default(RedisKey));
// subscribe to all masters in cluster for keyspace/keyevent notifications
if (channel.IsKeyspaceChannel)
{
return SubscribeToMasters(multiplexer, channel, flags, asyncState, internalCall);
}
return SubscribeToSelectedEndpoint(multiplexer, channel, flags, asyncState, internalCall);
}

private Task SubscribeToMasters(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{
var subscribeTasks = new List<Task>();
var masters = multiplexer.GetServerSnapshot().ToArray()
.Where(s => !s.IsReplica && s.ClusterConfiguration != null && s.EndPoint.Equals(s.ClusterConfiguration.Origin));

lock (owners)
{
foreach (var master in masters)
{
if (owners.Contains(master)) continue;
owners.Add(master);

var state = SubscribeToSelectedEndpoint(multiplexer, channel, flags, asyncState, internalCall);
subscribeTasks.Add(state ?? CompletedTask<bool>.Default(asyncState));
}
}

return Task.WhenAll(subscribeTasks);
}

private Task SubscribeToSelectedEndpoint(ConnectionMultiplexer multiplexer, RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{
var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
var selected = multiplexer.SelectServer(cmd, flags, default(RedisKey));
var bridge = selected?.GetBridge(ConnectionType.Subscription, true);
if (bridge == null) return null;

// note: check we can create the message validly *before* we swap the owner over (Interlocked)
var state = PendingSubscriptionState.Create(channel, this, flags, true, internalCall, asyncState, selected.IsReplica);

if (Interlocked.CompareExchange(ref owner, selected, null) != null) return null;
lock (owners)
{
if (!owners.Add(selected))
{
return null;
}
}
try
{
if (!bridge.TryEnqueueBackgroundSubscriptionWrite(state))
Expand All @@ -241,25 +279,48 @@ public Task SubscribeToServer(ConnectionMultiplexer multiplexer, in RedisChannel
catch
{
// clear the owner if it is still us
Interlocked.CompareExchange(ref owner, null, selected);
lock (owners)
{
owners.Remove(selected);
}
throw;
}
}

public Task UnsubscribeFromServer(in RedisChannel channel, CommandFlags flags, object asyncState, bool internalCall)
{
var oldOwner = Interlocked.Exchange(ref owner, null);
var bridge = oldOwner?.GetBridge(ConnectionType.Subscription, false);
if (bridge == null) return null;
lock (owners)
{
if (owners.Count == 0) return null;
}

var state = PendingSubscriptionState.Create(channel, this, flags, false, internalCall, asyncState, oldOwner.IsReplica);
var queuedTasks = new List<Task>();

if (!bridge.TryEnqueueBackgroundSubscriptionWrite(state))
var cmd = channel.IsPatternBased ? RedisCommand.PUNSUBSCRIBE : RedisCommand.UNSUBSCRIBE;

var msg = Message.Create(-1, flags, cmd, channel);
if (internalCall) msg.SetInternalCall();

lock (owners)
{
state.Abort();
return null;
foreach (var owner in owners)
{
var bridge = owner?.GetBridge(ConnectionType.Subscription, false);
if (bridge == null) return null;

var state = PendingSubscriptionState.Create(channel, this, flags, false, internalCall, asyncState, owner.IsReplica);

if (!bridge.TryEnqueueBackgroundSubscriptionWrite(state))
{
state.Abort();
return null;
}
queuedTasks.Add(state.Task);
}

owners.Clear();
}
return state.Task;
return Task.WhenAll(queuedTasks);
}

internal readonly struct PendingSubscriptionState
Expand Down Expand Up @@ -294,36 +355,50 @@ private PendingSubscriptionState(object asyncState, RedisChannel channel, Subscr
}
}

internal ServerEndPoint GetOwner() => Volatile.Read(ref owner);
internal ServerEndPoint GetOwner()
{
lock (owners)
{
return owners.FirstOrDefault();
}
}

internal void Resubscribe(in RedisChannel channel, ServerEndPoint server)
{
if (server != null && Interlocked.CompareExchange(ref owner, server, server) == server)
bool hasOwner;
lock (owners)
{
var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
var msg = Message.Create(-1, CommandFlags.FireAndForget, cmd, channel);
msg.SetInternalCall();
hasOwner = owners.Contains(server);
}
if (server == null || !hasOwner)
{
return;
}
var cmd = channel.IsPatternBased ? RedisCommand.PSUBSCRIBE : RedisCommand.SUBSCRIBE;
var msg = Message.Create(-1, CommandFlags.FireAndForget, cmd, channel);
msg.SetInternalCall();
#pragma warning disable CS0618
server.WriteDirectFireAndForgetSync(msg, ResultProcessor.TrackSubscriptions);
server.WriteDirectFireAndForgetSync(msg, ResultProcessor.TrackSubscriptions);
#pragma warning restore CS0618
}
}

internal bool Validate(ConnectionMultiplexer multiplexer, in RedisChannel channel)
{
bool changed = false;
var oldOwner = Volatile.Read(ref owner);
if (oldOwner != null && !oldOwner.IsSelectable(RedisCommand.PSUBSCRIBE))
lock (owners)
{
if (UnsubscribeFromServer(channel, CommandFlags.FireAndForget, null, true) != null)
if (owners.Count != 0 && !owners.All(o => o.IsSelectable(RedisCommand.PSUBSCRIBE)))
{
if (UnsubscribeFromServer(channel, CommandFlags.FireAndForget, null, true) != null)
{
changed = true;
}
owners.Clear();
}
if (owners.Count == 0 && SubscribeToServer(multiplexer, channel, CommandFlags.FireAndForget, null, true) != null)
{
changed = true;
}
oldOwner = null;
}
if (oldOwner == null && SubscribeToServer(multiplexer, channel, CommandFlags.FireAndForget, null, true) != null)
{
changed = true;
}
return changed;
}
Expand Down