Skip to content

Commit

Permalink
fix(ws): fix #22037 - concurrent subscription requests (#22309)
Browse files Browse the repository at this point in the history
* fix(ws): fix #22037 - concurrent subscription requests

* trigger build

* explitict cast

* fix second race condition

* add changes for watch multiple

* build fix

* build fix

---------

Co-authored-by: carlosmiei <43336371+carlosmiei@users.noreply.github.com>
  • Loading branch information
pcriadoperez and carlosmiei committed Apr 29, 2024
1 parent e12c99d commit 4153572
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 55 deletions.
3 changes: 2 additions & 1 deletion build/dummy.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ used to trigger builds/deploys without changing
the history of some really important file :D
goo builldd
prr!!!!!
deploy!!
deploy!
trigger!
35 changes: 5 additions & 30 deletions cs/ccxt/ws/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ public class WebSocketClient
{
public string url; // Replace with your WebSocket server URL
public ClientWebSocket webSocket = new ClientWebSocket();
// public ClientWebSocket webSocket;

public IDictionary<string, Future> futures = new ConcurrentDictionary<string, Future>();
public IDictionary<string, object> subscriptions = new ConcurrentDictionary<string, object>();
public IDictionary<string, object> rejections = new ConcurrentDictionary<string, object>();
public IDictionary<string, object> rejections = new ConcurrentDictionary<string, object>(); // Currently not being used

public bool verbose = false;
public bool isConnected = false;
Expand Down Expand Up @@ -76,27 +75,7 @@ public WebSocketClient(string url, string proxy, handleMessageDelegate handleMes
public Future future(object messageHash2)
{
var messageHash = messageHash2.ToString();
// var tcs = new TaskCompletionSource<object>();
// this.futures[messageHash] = tcs;
// return tcs.Task;
if (!this.futures.ContainsKey(messageHash))
{
// var tcs = new TaskCompletionSource<object>();
var future = new Future();
lock (this.futures)
{
// Console.WriteLine("Adding future, inside lock");
this.futures[messageHash] = future;
}
// Console.WriteLine("outside lock");
// return future.task;
return future;
}
else
{
// return (Task<object>)this.futures[messageHash].task;
return this.futures[messageHash];
}
return (this.futures as ConcurrentDictionary<string, Future>).GetOrAdd (messageHash, (key) => new Future());
}

public void resolve(object content, object messageHash2)
Expand All @@ -106,10 +85,8 @@ public void resolve(object content, object messageHash2)
Console.WriteLine("resolve received undefined messageHash");
}
var messageHash = messageHash2.ToString();
if (this.futures.ContainsKey(messageHash))
if ((this.futures as ConcurrentDictionary<string, Future>).TryRemove(messageHash, out Future future))
{
var future = this.futures[messageHash];
this.futures.Remove(messageHash); // this order matters
future.resolve(content);
}
}
Expand All @@ -119,10 +96,8 @@ public void reject(object content, object messageHash2 = null)
if (messageHash2 != null)
{
var messageHash = messageHash2.ToString();
if (this.futures.ContainsKey(messageHash))
if ((this.futures as ConcurrentDictionary<string, Future>).TryRemove(messageHash, out Future future))
{
var future = this.futures[messageHash];
this.futures.Remove(messageHash); // this order matters
future.reject(content);
}
}
Expand Down Expand Up @@ -443,4 +418,4 @@ public async Task Close()
}
}

}
}
35 changes: 11 additions & 24 deletions cs/ccxt/ws/Exchange.WsBridge.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ public WebSocketClient client(object url2)
var url = url2.ToString();
var result = this.checkWsProxySettings() as List<object>;
var proxy = this.getWsProxy(result);
if (!this.clients.ContainsKey(url))
return this.clients.GetOrAdd(url, (url) =>
{
object ws = this.safeValue(this.options, "ws", new Dictionary<string, object>() { });
var wsOptions = this.safeValue(ws, "options", new Dictionary<string, object>() { });
var keepAlive = ((Int64)this.safeInteger(wsOptions, "keepAlive", 30000));
this.clients[url] = new WebSocketClient(url, proxy, handleMessage, ping, onClose, onError, this.verbose, keepAlive);
var client = new WebSocketClient(url, proxy, handleMessage, ping, onClose, onError, this.verbose, keepAlive);
var wsHeaders = this.safeValue(wsOptions, "headers", new Dictionary<string, object>() { });
// iterate through headers
Expand All @@ -144,11 +144,11 @@ public WebSocketClient client(object url2)
var headers = wsHeaders as Dictionary<string, object>;
foreach (var key in headers.Keys)
{
this.clients[url].webSocket.Options.SetRequestHeader(key, headers[key].ToString());
client.webSocket.Options.SetRequestHeader(key, headers[key].ToString());
}
}
}
return this.clients[url];
return client;
});
}

public async Task<object> watch(object url2, object messageHash2, object message = null, object subscribeHash2 = null, object subscription = null)
Expand All @@ -158,23 +158,13 @@ public async Task<object> watch(object url2, object messageHash2, object message
var subscribeHash = subscribeHash2?.ToString();
var client = this.client(url);

if ((subscribeHash == null) && (client.futures.ContainsKey(messageHash)))
{
return client.futures[messageHash];
}

var future = client.future(messageHash);

var clientSubscription = (subscribeHash != null && client.subscriptions.ContainsKey(subscribeHash)) ? client.subscriptions[subscribeHash] : null;

if (clientSubscription == null)
{
client.subscriptions[subscribeHash] = subscription ?? true;
var future = (client.futures as ConcurrentDictionary<string, Future>).GetOrAdd (messageHash, (key) => client.future(messageHash));
if (subscribeHash == null) {
return await future;
}

var connected = client.connect(0);

if (clientSubscription == null)
if ((client.subscriptions as ConcurrentDictionary<string, object>).TryAdd(subscribeHash, subscription ?? true))
{
await connected;
if (message != null)
Expand All @@ -192,7 +182,6 @@ public async Task<object> watch(object url2, object messageHash2, object message

}
}

return await future;
}

Expand All @@ -204,7 +193,6 @@ public async Task<object> watchMultiple(object url2, object messageHashes2, obje

var client = this.client(url);


var future = Future.race(messageHashes.Select(subHash => client.future(subHash)).ToArray());

var missingSubscriptions = new List<string>();
Expand All @@ -213,11 +201,10 @@ public async Task<object> watchMultiple(object url2, object messageHashes2, obje
{
foreach (var subscribeHash in subscribeHashes)
{
var clientSubscription = (subscribeHash != null && client.subscriptions.ContainsKey(subscribeHash)) ? client.subscriptions[subscribeHash] : null;
if (subscribeHash == null) continue;

if (clientSubscription == null)
if ((client.subscriptions as ConcurrentDictionary<string, object>).TryAdd (subscribeHash, subscription ?? true))
{
client.subscriptions[subscribeHash] = subscription ?? true;
missingSubscriptions.Add(subscribeHash);
}
}
Expand Down

0 comments on commit 4153572

Please sign in to comment.