diff --git a/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs b/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs
new file mode 100644
index 000000000..c9be0ac1f
--- /dev/null
+++ b/src/Discord.Net.Core/Entities/Gateway/BotGateway.cs
@@ -0,0 +1,22 @@
+namespace Discord
+{
+ ///
+ /// Stores the gateway information related to the current bot.
+ ///
+ public class BotGateway
+ {
+ ///
+ /// Gets the WSS URL that can be used for connecting to the gateway.
+ ///
+ public string Url { get; internal set; }
+ ///
+ /// Gets the recommended number of shards to use when connecting.
+ ///
+ public int Shards { get; internal set; }
+ ///
+ /// Gets the that contains the information
+ /// about the current session start limit.
+ ///
+ public SessionStartLimit SessionStartLimit { get; internal set; }
+ }
+}
diff --git a/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs b/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs
new file mode 100644
index 000000000..74ae96af1
--- /dev/null
+++ b/src/Discord.Net.Core/Entities/Gateway/SessionStartLimit.cs
@@ -0,0 +1,38 @@
+namespace Discord
+{
+ ///
+ /// Stores the information related to the gateway identify request.
+ ///
+ public class SessionStartLimit
+ {
+ ///
+ /// Gets the total number of session starts the current user is allowed.
+ ///
+ ///
+ /// The maximum amount of session starts the current user is allowed.
+ ///
+ public int Total { get; internal set; }
+ ///
+ /// Gets the remaining number of session starts the current user is allowed.
+ ///
+ ///
+ /// The remaining amount of session starts the current user is allowed.
+ ///
+ public int Remaining { get; internal set; }
+ ///
+ /// Gets the number of milliseconds after which the limit resets.
+ ///
+ ///
+ /// The milliseconds until the limit resets back to the .
+ ///
+ public int ResetAfter { get; internal set; }
+ ///
+ /// Gets the maximum concurrent identify requests in a time window.
+ ///
+ ///
+ /// The maximum concurrent identify requests in a time window,
+ /// limited to the same rate limit key.
+ ///
+ public int MaxConcurrency { get; internal set; }
+ }
+}
diff --git a/src/Discord.Net.Core/IDiscordClient.cs b/src/Discord.Net.Core/IDiscordClient.cs
index f972cd71d..d7d6d2856 100644
--- a/src/Discord.Net.Core/IDiscordClient.cs
+++ b/src/Discord.Net.Core/IDiscordClient.cs
@@ -274,5 +274,15 @@ namespace Discord
/// that represents the number of shards that should be used with this account.
///
Task GetRecommendedShardCountAsync(RequestOptions options = null);
+
+ ///
+ /// Gets the gateway information related to the bot.
+ ///
+ /// The options to be used when sending the request.
+ ///
+ /// A task that represents the asynchronous get operation. The task result contains a
+ /// that represents the gateway information related to the bot.
+ ///
+ Task GetBotGatewayAsync(RequestOptions options = null);
}
}
diff --git a/src/Discord.Net.Core/RequestOptions.cs b/src/Discord.Net.Core/RequestOptions.cs
index ad0a4e33f..dbb240273 100644
--- a/src/Discord.Net.Core/RequestOptions.cs
+++ b/src/Discord.Net.Core/RequestOptions.cs
@@ -61,6 +61,7 @@ namespace Discord
internal BucketId BucketId { get; set; }
internal bool IsClientBucket { get; set; }
internal bool IsReactionBucket { get; set; }
+ internal bool IsGatewayBucket { get; set; }
internal static RequestOptions CreateOrClone(RequestOptions options)
{
diff --git a/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs b/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs
new file mode 100644
index 000000000..29d5ddf85
--- /dev/null
+++ b/src/Discord.Net.Rest/API/Common/SessionStartLimit.cs
@@ -0,0 +1,16 @@
+using Newtonsoft.Json;
+
+namespace Discord.API.Rest
+{
+ internal class SessionStartLimit
+ {
+ [JsonProperty("total")]
+ public int Total { get; set; }
+ [JsonProperty("remaining")]
+ public int Remaining { get; set; }
+ [JsonProperty("reset_after")]
+ public int ResetAfter { get; set; }
+ [JsonProperty("max_concurrency")]
+ public int MaxConcurrency { get; set; }
+ }
+}
diff --git a/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs b/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs
index 111fcf3db..d3285051b 100644
--- a/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs
+++ b/src/Discord.Net.Rest/API/Rest/GetBotGatewayResponse.cs
@@ -1,4 +1,4 @@
-#pragma warning disable CS1591
+#pragma warning disable CS1591
using Newtonsoft.Json;
namespace Discord.API.Rest
@@ -9,5 +9,7 @@ namespace Discord.API.Rest
public string Url { get; set; }
[JsonProperty("shards")]
public int Shards { get; set; }
+ [JsonProperty("session_start_limit")]
+ public SessionStartLimit SessionStartLimit { get; set; }
}
}
diff --git a/src/Discord.Net.Rest/BaseDiscordClient.cs b/src/Discord.Net.Rest/BaseDiscordClient.cs
index b641fa1c3..68589a4f1 100644
--- a/src/Discord.Net.Rest/BaseDiscordClient.cs
+++ b/src/Discord.Net.Rest/BaseDiscordClient.cs
@@ -152,6 +152,10 @@ namespace Discord.Rest
public Task GetRecommendedShardCountAsync(RequestOptions options = null)
=> ClientHelper.GetRecommendShardCountAsync(this, options);
+ ///
+ public Task GetBotGatewayAsync(RequestOptions options = null)
+ => ClientHelper.GetBotGatewayAsync(this, options);
+
//IDiscordClient
///
ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected;
diff --git a/src/Discord.Net.Rest/ClientHelper.cs b/src/Discord.Net.Rest/ClientHelper.cs
index 6ebdbcacb..8910e999a 100644
--- a/src/Discord.Net.Rest/ClientHelper.cs
+++ b/src/Discord.Net.Rest/ClientHelper.cs
@@ -184,5 +184,22 @@ namespace Discord.Rest
var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false);
return response.Shards;
}
+
+ public static async Task GetBotGatewayAsync(BaseDiscordClient client, RequestOptions options)
+ {
+ var response = await client.ApiClient.GetBotGatewayAsync(options).ConfigureAwait(false);
+ return new BotGateway
+ {
+ Url = response.Url,
+ Shards = response.Shards,
+ SessionStartLimit = new SessionStartLimit
+ {
+ Total = response.SessionStartLimit.Total,
+ Remaining = response.SessionStartLimit.Remaining,
+ ResetAfter = response.SessionStartLimit.ResetAfter,
+ MaxConcurrency = response.SessionStartLimit.MaxConcurrency
+ }
+ };
+ }
}
}
diff --git a/src/Discord.Net.Rest/DiscordRestClient.cs b/src/Discord.Net.Rest/DiscordRestClient.cs
index bef4e6b2a..48c40fdfa 100644
--- a/src/Discord.Net.Rest/DiscordRestClient.cs
+++ b/src/Discord.Net.Rest/DiscordRestClient.cs
@@ -29,10 +29,10 @@ namespace Discord.Rest
internal DiscordRestClient(DiscordRestConfig config, API.DiscordRestApiClient api) : base(config, api) { }
private static API.DiscordRestApiClient CreateApiClient(DiscordRestConfig config)
- => new API.DiscordRestApiClient(config.RestClientProvider,
- DiscordRestConfig.UserAgent,
- rateLimitPrecision: config.RateLimitPrecision,
- useSystemClock: config.UseSystemClock);
+ => new API.DiscordRestApiClient(config.RestClientProvider,
+ DiscordRestConfig.UserAgent,
+ rateLimitPrecision: config.RateLimitPrecision,
+ useSystemClock: config.UseSystemClock);
internal override void Dispose(bool disposing)
{
diff --git a/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs
new file mode 100644
index 000000000..aa849018a
--- /dev/null
+++ b/src/Discord.Net.Rest/Net/Queue/GatewayBucket.cs
@@ -0,0 +1,53 @@
+using System.Collections.Immutable;
+
+namespace Discord.Net.Queue
+{
+ public enum GatewayBucketType
+ {
+ Unbucketed = 0,
+ Identify = 1,
+ PresenceUpdate = 2,
+ }
+ internal struct GatewayBucket
+ {
+ private static readonly ImmutableDictionary DefsByType;
+ private static readonly ImmutableDictionary DefsById;
+
+ static GatewayBucket()
+ {
+ var buckets = new[]
+ {
+ // Limit is 120/60s, but 3 will be reserved for heartbeats (2 for possible heartbeats in the same timeframe and a possible failure)
+ new GatewayBucket(GatewayBucketType.Unbucketed, BucketId.Create(null, "", null), 117, 60),
+ new GatewayBucket(GatewayBucketType.Identify, BucketId.Create(null, "", null), 1, 5),
+ new GatewayBucket(GatewayBucketType.PresenceUpdate, BucketId.Create(null, "", null), 5, 60),
+ };
+
+ var builder = ImmutableDictionary.CreateBuilder();
+ foreach (var bucket in buckets)
+ builder.Add(bucket.Type, bucket);
+ DefsByType = builder.ToImmutable();
+
+ var builder2 = ImmutableDictionary.CreateBuilder();
+ foreach (var bucket in buckets)
+ builder2.Add(bucket.Id, bucket);
+ DefsById = builder2.ToImmutable();
+ }
+
+ public static GatewayBucket Get(GatewayBucketType type) => DefsByType[type];
+ public static GatewayBucket Get(BucketId id) => DefsById[id];
+
+ public GatewayBucketType Type { get; }
+ public BucketId Id { get; }
+ public int WindowCount { get; set; }
+ public int WindowSeconds { get; set; }
+
+ public GatewayBucket(GatewayBucketType type, BucketId id, int count, int seconds)
+ {
+ Type = type;
+ Id = id;
+ WindowCount = count;
+ WindowSeconds = seconds;
+ }
+ }
+}
diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
index 127a48cf3..2bf8e20b0 100644
--- a/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
+++ b/src/Discord.Net.Rest/Net/Queue/RequestQueue.cs
@@ -89,9 +89,18 @@ namespace Discord.Net.Queue
}
public async Task SendAsync(WebSocketRequest request)
{
- //TODO: Re-impl websocket buckets
- request.CancelToken = _requestCancelToken;
- await request.SendAsync().ConfigureAwait(false);
+ CancellationTokenSource createdTokenSource = null;
+ if (request.Options.CancelToken.CanBeCanceled)
+ {
+ createdTokenSource = CancellationTokenSource.CreateLinkedTokenSource(_requestCancelToken, request.Options.CancelToken);
+ request.Options.CancelToken = createdTokenSource.Token;
+ }
+ else
+ request.Options.CancelToken = _requestCancelToken;
+
+ var bucket = GetOrCreateBucket(request.Options, request);
+ await bucket.SendAsync(request).ConfigureAwait(false);
+ createdTokenSource?.Dispose();
}
internal async Task EnterGlobalAsync(int id, RestRequest request)
@@ -109,8 +118,23 @@ namespace Discord.Net.Queue
{
_waitUntil = DateTimeOffset.UtcNow.AddMilliseconds(info.RetryAfter.Value + (info.Lag?.TotalMilliseconds ?? 0.0));
}
+ internal async Task EnterGlobalAsync(int id, WebSocketRequest request)
+ {
+ //If this is a global request (unbucketed), it'll be dealt in EnterAsync
+ var requestBucket = GatewayBucket.Get(request.Options.BucketId);
+ if (requestBucket.Type == GatewayBucketType.Unbucketed)
+ return;
+
+ //It's not a global request, so need to remove one from global (per-session)
+ var globalBucketType = GatewayBucket.Get(GatewayBucketType.Unbucketed);
+ var options = RequestOptions.CreateOrClone(request.Options);
+ options.BucketId = globalBucketType.Id;
+ var globalRequest = new WebSocketRequest(null, null, false, false, options);
+ var globalBucket = GetOrCreateBucket(options, globalRequest);
+ await globalBucket.TriggerAsync(id, globalRequest);
+ }
- private RequestBucket GetOrCreateBucket(RequestOptions options, RestRequest request)
+ private RequestBucket GetOrCreateBucket(RequestOptions options, IRequest request)
{
var bucketId = options.BucketId;
object obj = _buckets.GetOrAdd(bucketId, x => new RequestBucket(this, request, x));
@@ -137,6 +161,12 @@ namespace Discord.Net.Queue
return (null, null);
}
+ public void ClearGatewayBuckets()
+ {
+ foreach (var gwBucket in (GatewayBucketType[])Enum.GetValues(typeof(GatewayBucketType)))
+ _buckets.TryRemove(GatewayBucket.Get(gwBucket).Id, out _);
+ }
+
private async Task RunCleanup()
{
try
diff --git a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
index edd55f158..3fb45e55d 100644
--- a/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
+++ b/src/Discord.Net.Rest/Net/Queue/RequestQueueBucket.cs
@@ -25,7 +25,7 @@ namespace Discord.Net.Queue
public int WindowCount { get; private set; }
public DateTimeOffset LastAttemptAt { get; private set; }
- public RequestBucket(RequestQueue queue, RestRequest request, BucketId id)
+ public RequestBucket(RequestQueue queue, IRequest request, BucketId id)
{
_queue = queue;
Id = id;
@@ -33,7 +33,9 @@ namespace Discord.Net.Queue
_lock = new object();
if (request.Options.IsClientBucket)
- WindowCount = ClientBucket.Get(Id).WindowCount;
+ WindowCount = ClientBucket.Get(request.Options.BucketId).WindowCount;
+ else if (request.Options.IsGatewayBucket)
+ WindowCount = GatewayBucket.Get(request.Options.BucketId).WindowCount;
else
WindowCount = 1; //Only allow one request until we get a header back
_semaphore = WindowCount;
@@ -154,8 +156,68 @@ namespace Discord.Net.Queue
}
}
}
+ public async Task SendAsync(WebSocketRequest request)
+ {
+ int id = Interlocked.Increment(ref nextId);
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Start");
+#endif
+ LastAttemptAt = DateTimeOffset.UtcNow;
+ while (true)
+ {
+ await _queue.EnterGlobalAsync(id, request).ConfigureAwait(false);
+ await EnterAsync(id, request).ConfigureAwait(false);
- private async Task EnterAsync(int id, RestRequest request)
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Sending...");
+#endif
+ try
+ {
+ await request.SendAsync().ConfigureAwait(false);
+ return;
+ }
+ catch (TimeoutException)
+ {
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Timeout");
+#endif
+ if ((request.Options.RetryMode & RetryMode.RetryTimeouts) == 0)
+ throw;
+
+ await Task.Delay(500).ConfigureAwait(false);
+ continue; //Retry
+ }
+ /*catch (Exception)
+ {
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Error");
+#endif
+ if ((request.Options.RetryMode & RetryMode.RetryErrors) == 0)
+ throw;
+
+ await Task.Delay(500);
+ continue; //Retry
+ }*/
+ finally
+ {
+ UpdateRateLimit(id, request, default(RateLimitInfo), false);
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Stop");
+#endif
+ }
+ }
+ }
+
+ internal async Task TriggerAsync(int id, IRequest request)
+ {
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Trigger Bucket");
+#endif
+ await EnterAsync(id, request).ConfigureAwait(false);
+ UpdateRateLimit(id, request, default(RateLimitInfo), false);
+ }
+
+ private async Task EnterAsync(int id, IRequest request)
{
int windowCount;
DateTimeOffset? resetAt;
@@ -186,8 +248,31 @@ namespace Discord.Net.Queue
{
if (!isRateLimited)
{
+ bool ignoreRatelimit = false;
isRateLimited = true;
- await _queue.RaiseRateLimitTriggered(Id, null, $"{request.Method} {request.Endpoint}").ConfigureAwait(false);
+ switch (request)
+ {
+ case RestRequest restRequest:
+ await _queue.RaiseRateLimitTriggered(Id, null, $"{restRequest.Method} {restRequest.Endpoint}").ConfigureAwait(false);
+ break;
+ case WebSocketRequest webSocketRequest:
+ if (webSocketRequest.IgnoreLimit)
+ {
+ ignoreRatelimit = true;
+ break;
+ }
+ await _queue.RaiseRateLimitTriggered(Id, null, Id.Endpoint).ConfigureAwait(false);
+ break;
+ default:
+ throw new InvalidOperationException("Unknown request type");
+ }
+ if (ignoreRatelimit)
+ {
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Ignoring ratelimit");
+#endif
+ break;
+ }
}
ThrowRetryLimit(request);
@@ -223,7 +308,7 @@ namespace Discord.Net.Queue
}
}
- private void UpdateRateLimit(int id, RestRequest request, RateLimitInfo info, bool is429, bool redirected = false)
+ private void UpdateRateLimit(int id, IRequest request, RateLimitInfo info, bool is429, bool redirected = false)
{
if (WindowCount == 0)
return;
@@ -316,6 +401,23 @@ namespace Discord.Net.Queue
Debug.WriteLine($"[{id}] Client Bucket ({ClientBucket.Get(Id).WindowSeconds * 1000} ms)");
#endif
}
+ else if (request.Options.IsGatewayBucket && request.Options.BucketId != null)
+ {
+ resetTick = DateTimeOffset.UtcNow.AddSeconds(GatewayBucket.Get(request.Options.BucketId).WindowSeconds);
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Gateway Bucket ({GatewayBucket.Get(request.Options.BucketId).WindowSeconds * 1000} ms)");
+#endif
+ if (!hasQueuedReset)
+ {
+ _resetTick = resetTick;
+ LastAttemptAt = resetTick.Value;
+#if DEBUG_LIMITS
+ Debug.WriteLine($"[{id}] Reset in {(int)Math.Ceiling((resetTick - DateTimeOffset.UtcNow).Value.TotalMilliseconds)} ms");
+#endif
+ var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds), request);
+ }
+ return;
+ }
if (resetTick == null)
{
@@ -336,12 +438,12 @@ namespace Discord.Net.Queue
if (!hasQueuedReset)
{
- var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds));
+ var _ = QueueReset(id, (int)Math.Ceiling((_resetTick.Value - DateTimeOffset.UtcNow).TotalMilliseconds), request);
}
}
}
}
- private async Task QueueReset(int id, int millis)
+ private async Task QueueReset(int id, int millis, IRequest request)
{
while (true)
{
@@ -363,7 +465,7 @@ namespace Discord.Net.Queue
}
}
- private void ThrowRetryLimit(RestRequest request)
+ private void ThrowRetryLimit(IRequest request)
{
if ((request.Options.RetryMode & RetryMode.RetryRatelimit) == 0)
throw new RateLimitedException(request);
diff --git a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs
index 81eb40b31..ebebd7bef 100644
--- a/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs
+++ b/src/Discord.Net.Rest/Net/Queue/Requests/WebSocketRequest.cs
@@ -9,22 +9,22 @@ namespace Discord.Net.Queue
public class WebSocketRequest : IRequest
{
public IWebSocketClient Client { get; }
- public string BucketId { get; }
public byte[] Data { get; }
public bool IsText { get; }
+ public bool IgnoreLimit { get; }
public DateTimeOffset? TimeoutAt { get; }
public TaskCompletionSource Promise { get; }
public RequestOptions Options { get; }
public CancellationToken CancelToken { get; internal set; }
- public WebSocketRequest(IWebSocketClient client, string bucketId, byte[] data, bool isText, RequestOptions options)
+ public WebSocketRequest(IWebSocketClient client, byte[] data, bool isText, bool ignoreLimit, RequestOptions options)
{
Preconditions.NotNull(options, nameof(options));
Client = client;
- BucketId = bucketId;
Data = data;
IsText = isText;
+ IgnoreLimit = ignoreLimit;
Options = options;
TimeoutAt = options.Timeout.HasValue ? DateTimeOffset.UtcNow.AddMilliseconds(options.Timeout.Value) : (DateTimeOffset?)null;
Promise = new TaskCompletionSource();
diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
index a8780a7b0..a2c89d4e5 100644
--- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
@@ -12,12 +12,14 @@ namespace Discord.WebSocket
public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient
{
private readonly DiscordSocketConfig _baseConfig;
- private readonly SemaphoreSlim _connectionGroupLock;
private readonly Dictionary _shardIdsToIndex;
private readonly bool _automaticShards;
private int[] _shardIds;
private DiscordSocketClient[] _shards;
private int _totalShards;
+ private SemaphoreSlim[] _identifySemaphores;
+ private object _semaphoreResetLock;
+ private Task _semaphoreResetTask;
private bool _isDisposed;
@@ -62,10 +64,10 @@ namespace Discord.WebSocket
if (ids != null && config.TotalShards == null)
throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");
+ _semaphoreResetLock = new object();
_shardIdsToIndex = new Dictionary();
config.DisplayInitialLog = false;
_baseConfig = config;
- _connectionGroupLock = new SemaphoreSlim(1, 1);
if (config.TotalShards == null)
_automaticShards = true;
@@ -74,12 +76,15 @@ namespace Discord.WebSocket
_totalShards = config.TotalShards.Value;
_shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
_shards = new DiscordSocketClient[_shardIds.Length];
+ _identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency];
+ for (int i = 0; i < config.IdentifyMaxConcurrency; i++)
+ _identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
{
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = config.Clone();
newConfig.ShardId = _shardIds[i];
- _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
+ _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
}
}
@@ -88,21 +93,53 @@ namespace Discord.WebSocket
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
rateLimitPrecision: config.RateLimitPrecision);
+ internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token)
+ {
+ int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency;
+ await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false);
+ }
+
+ internal void ReleaseIdentifyLock()
+ {
+ lock (_semaphoreResetLock)
+ {
+ if (_semaphoreResetTask == null)
+ _semaphoreResetTask = ResetSemaphoresAsync();
+ }
+ }
+
+ private async Task ResetSemaphoresAsync()
+ {
+ await Task.Delay(5000).ConfigureAwait(false);
+ lock (_semaphoreResetLock)
+ {
+ foreach (var semaphore in _identifySemaphores)
+ if (semaphore.CurrentCount == 0)
+ semaphore.Release();
+ _semaphoreResetTask = null;
+ }
+ }
+
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
if (_automaticShards)
{
- var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false);
- _shardIds = Enumerable.Range(0, shardCount).ToArray();
+ var botGateway = await GetBotGatewayAsync().ConfigureAwait(false);
+ _shardIds = Enumerable.Range(0, botGateway.Shards).ToArray();
_totalShards = _shardIds.Length;
_shards = new DiscordSocketClient[_shardIds.Length];
+ int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency;
+ _baseConfig.IdentifyMaxConcurrency = maxConcurrency;
+ _identifySemaphores = new SemaphoreSlim[maxConcurrency];
+ for (int i = 0; i < maxConcurrency; i++)
+ _identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
{
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = _baseConfig.Clone();
newConfig.ShardId = _shardIds[i];
newConfig.TotalShards = _totalShards;
- _shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
+ _shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
}
}
@@ -398,7 +435,6 @@ namespace Discord.WebSocket
foreach (var client in _shards)
client?.Dispose();
}
- _connectionGroupLock?.Dispose();
}
_isDisposed = true;
diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
index 1b21bd666..07ebc87ec 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
@@ -132,6 +132,8 @@ namespace Discord.API
if (WebSocketClient == null)
throw new NotSupportedException("This client is not configured with WebSocket support.");
+ RequestQueue.ClearGatewayBuckets();
+
//Re-create streams to reset the zlib state
_compressed?.Dispose();
_decompressor?.Dispose();
@@ -205,7 +207,11 @@ namespace Discord.API
payload = new SocketFrame { Operation = (int)opCode, Payload = payload };
if (payload != null)
bytes = Encoding.UTF8.GetBytes(SerializeJson(payload));
- await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, null, bytes, true, options)).ConfigureAwait(false);
+
+ options.IsGatewayBucket = true;
+ if (options.BucketId == null)
+ options.BucketId = GatewayBucket.Get(GatewayBucketType.Unbucketed).Id;
+ await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, opCode == GatewayOpCode.Heartbeat, options)).ConfigureAwait(false);
await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false);
}
@@ -225,6 +231,8 @@ namespace Discord.API
if (totalShards > 1)
msg.ShardingParams = new int[] { shardID, totalShards };
+ options.BucketId = GatewayBucket.Get(GatewayBucketType.Identify).Id;
+
if (gatewayIntents.HasValue)
msg.Intents = (int)gatewayIntents.Value;
else
@@ -258,6 +266,7 @@ namespace Discord.API
IsAFK = isAFK,
Game = game
};
+ options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id;
await SendGatewayAsync(GatewayOpCode.StatusUpdate, args, options: options).ConfigureAwait(false);
}
public async Task SendRequestMembersAsync(IEnumerable guildIds, RequestOptions options = null)
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index dfdad99fc..d53387afc 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -26,7 +26,7 @@ namespace Discord.WebSocket
{
private readonly ConcurrentQueue _largeGuilds;
private readonly JsonSerializer _serializer;
- private readonly SemaphoreSlim _connectionGroupLock;
+ private readonly DiscordShardedClient _shardedClient;
private readonly DiscordSocketClient _parentClient;
private readonly ConcurrentQueue _heartbeatTimes;
private readonly ConnectionManager _connection;
@@ -120,9 +120,9 @@ namespace Discord.WebSocket
/// The configuration to be used with the client.
#pragma warning disable IDISP004
public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { }
- internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), groupLock, parentClient) { }
+ internal DiscordSocketClient(DiscordSocketConfig config, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), shardedClient, parentClient) { }
#pragma warning restore IDISP004
- private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient)
+ private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient)
: base(config, client)
{
ShardId = config.ShardId ?? 0;
@@ -148,7 +148,7 @@ namespace Discord.WebSocket
_connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex);
_nextAudioId = 1;
- _connectionGroupLock = groupLock;
+ _shardedClient = shardedClient;
_parentClient = parentClient;
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
@@ -229,8 +229,12 @@ namespace Discord.WebSocket
private async Task OnConnectingAsync()
{
- if (_connectionGroupLock != null)
- await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false);
+ bool locked = false;
+ if (_shardedClient != null && _sessionId == null)
+ {
+ await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false);
+ locked = true;
+ }
try
{
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
@@ -255,11 +259,8 @@ namespace Discord.WebSocket
}
finally
{
- if (_connectionGroupLock != null)
- {
- await Task.Delay(5000).ConfigureAwait(false);
- _connectionGroupLock.Release();
- }
+ if (locked)
+ _shardedClient.ReleaseIdentifyLock();
}
}
private async Task OnDisconnectingAsync(Exception ex)
@@ -519,7 +520,15 @@ namespace Discord.WebSocket
_sessionId = null;
_lastSeq = 0;
- await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false);
+ await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false);
+ try
+ {
+ await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false);
+ }
+ finally
+ {
+ _shardedClient.ReleaseIdentifyLock();
+ }
}
break;
case GatewayOpCode.Reconnect:
diff --git a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs
index 0e8fbe73f..6b0c5ebc4 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketConfig.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketConfig.cs
@@ -126,6 +126,14 @@ namespace Discord.WebSocket
public bool GuildSubscriptions { get; set; } = true;
///
+ /// Gets or sets the maximum identify concurrency.
+ ///
+ ///
+ /// This information is provided by Discord.
+ /// It is only used when using a and auto-sharding is disabled.
+ ///
+ public int IdentifyMaxConcurrency { get; set; } = 1;
+
/// Gets or sets the maximum wait time in milliseconds between GUILD_AVAILABLE events before firing READY.
///
/// If zero, READY will fire as soon as it is received and all guilds will be unavailable.