* Implemented resume_gateway_url. * Made the requested changes. * Implemented passing the gateway URL down from DiscordShardedClient. Co-authored-by: Quin Lynch <49576606+quinchs@users.noreply.github.com>pull/2464/head
@@ -20,6 +20,8 @@ namespace Discord.API.Gateway | |||||
public User User { get; set; } | public User User { get; set; } | ||||
[JsonProperty("session_id")] | [JsonProperty("session_id")] | ||||
public string SessionId { get; set; } | public string SessionId { get; set; } | ||||
[JsonProperty("resume_gateway_url")] | |||||
public string ResumeGatewayUrl { get; set; } | |||||
[JsonProperty("read_state")] | [JsonProperty("read_state")] | ||||
public ReadState[] ReadStates { get; set; } | public ReadState[] ReadStates { get; set; } | ||||
[JsonProperty("guilds")] | [JsonProperty("guilds")] | ||||
@@ -139,9 +139,9 @@ namespace Discord.WebSocket | |||||
internal override async Task OnLoginAsync(TokenType tokenType, string token) | internal override async Task OnLoginAsync(TokenType tokenType, string token) | ||||
{ | { | ||||
var botGateway = await GetBotGatewayAsync().ConfigureAwait(false); | |||||
if (_automaticShards) | if (_automaticShards) | ||||
{ | { | ||||
var botGateway = await GetBotGatewayAsync().ConfigureAwait(false); | |||||
_shardIds = Enumerable.Range(0, botGateway.Shards).ToArray(); | _shardIds = Enumerable.Range(0, botGateway.Shards).ToArray(); | ||||
_totalShards = _shardIds.Length; | _totalShards = _shardIds.Length; | ||||
_shards = new DiscordSocketClient[_shardIds.Length]; | _shards = new DiscordSocketClient[_shardIds.Length]; | ||||
@@ -163,7 +163,12 @@ namespace Discord.WebSocket | |||||
//Assume thread safe: already in a connection lock | //Assume thread safe: already in a connection lock | ||||
for (int i = 0; i < _shards.Length; i++) | for (int i = 0; i < _shards.Length; i++) | ||||
{ | |||||
// Set the gateway URL to the one returned by Discord, if a custom one isn't set. | |||||
_shards[i].ApiClient.GatewayUrl = botGateway.Url; | |||||
await _shards[i].LoginAsync(tokenType, token); | await _shards[i].LoginAsync(tokenType, token); | ||||
} | |||||
if(_defaultStickers.Length == 0 && _baseConfig.AlwaysDownloadDefaultStickers) | if(_defaultStickers.Length == 0 && _baseConfig.AlwaysDownloadDefaultStickers) | ||||
await DownloadDefaultStickersAsync().ConfigureAwait(false); | await DownloadDefaultStickersAsync().ConfigureAwait(false); | ||||
@@ -175,7 +180,12 @@ namespace Discord.WebSocket | |||||
if (_shards != null) | if (_shards != null) | ||||
{ | { | ||||
for (int i = 0; i < _shards.Length; i++) | for (int i = 0; i < _shards.Length; i++) | ||||
{ | |||||
// Reset the gateway URL set for the shard. | |||||
_shards[i].ApiClient.GatewayUrl = null; | |||||
await _shards[i].LogoutAsync(); | await _shards[i].LogoutAsync(); | ||||
} | |||||
} | } | ||||
if (_automaticShards) | if (_automaticShards) | ||||
@@ -28,6 +28,7 @@ namespace Discord.API | |||||
private readonly bool _isExplicitUrl; | private readonly bool _isExplicitUrl; | ||||
private CancellationTokenSource _connectCancelToken; | private CancellationTokenSource _connectCancelToken; | ||||
private string _gatewayUrl; | private string _gatewayUrl; | ||||
private string _resumeGatewayUrl; | |||||
//Store our decompression streams for zlib shared state | //Store our decompression streams for zlib shared state | ||||
private MemoryStream _compressed; | private MemoryStream _compressed; | ||||
@@ -37,6 +38,32 @@ namespace Discord.API | |||||
public ConnectionState ConnectionState { get; private set; } | public ConnectionState ConnectionState { get; private set; } | ||||
/// <summary> | |||||
/// Sets the gateway URL used for identifies. | |||||
/// </summary> | |||||
/// <remarks> | |||||
/// If a custom URL is set, setting this property does nothing. | |||||
/// </remarks> | |||||
public string GatewayUrl | |||||
{ | |||||
set | |||||
{ | |||||
// Makes the sharded client not override the custom value. | |||||
if (_isExplicitUrl) | |||||
return; | |||||
_gatewayUrl = FormatGatewayUrl(value); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Sets the gateway URL used for resumes. | |||||
/// </summary> | |||||
public string ResumeGatewayUrl | |||||
{ | |||||
set => _resumeGatewayUrl = FormatGatewayUrl(value); | |||||
} | |||||
public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, | public DiscordSocketApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider, string userAgent, | ||||
string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, | string url = null, RetryMode defaultRetryMode = RetryMode.AlwaysRetry, JsonSerializer serializer = null, | ||||
bool useSystemClock = true, Func<IRateLimitInfo, Task> defaultRatelimitCallback = null) | bool useSystemClock = true, Func<IRateLimitInfo, Task> defaultRatelimitCallback = null) | ||||
@@ -157,6 +184,17 @@ namespace Discord.API | |||||
#endif | #endif | ||||
} | } | ||||
/// <summary> | |||||
/// Appends necessary query parameters to the specified gateway URL. | |||||
/// </summary> | |||||
private static string FormatGatewayUrl(string gatewayUrl) | |||||
{ | |||||
if (gatewayUrl == null) | |||||
return null; | |||||
return $"{gatewayUrl}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream"; | |||||
} | |||||
public async Task ConnectAsync() | public async Task ConnectAsync() | ||||
{ | { | ||||
await _stateLock.WaitAsync().ConfigureAwait(false); | await _stateLock.WaitAsync().ConfigureAwait(false); | ||||
@@ -191,24 +229,32 @@ namespace Discord.API | |||||
if (WebSocketClient != null) | if (WebSocketClient != null) | ||||
WebSocketClient.SetCancelToken(_connectCancelToken.Token); | WebSocketClient.SetCancelToken(_connectCancelToken.Token); | ||||
if (!_isExplicitUrl) | |||||
string gatewayUrl; | |||||
if (_resumeGatewayUrl == null) | |||||
{ | |||||
if (!_isExplicitUrl && _gatewayUrl == null) | |||||
{ | |||||
var gatewayResponse = await GetBotGatewayAsync().ConfigureAwait(false); | |||||
_gatewayUrl = FormatGatewayUrl(gatewayResponse.Url); | |||||
} | |||||
gatewayUrl = _gatewayUrl; | |||||
} | |||||
else | |||||
{ | { | ||||
var gatewayResponse = await GetGatewayAsync().ConfigureAwait(false); | |||||
_gatewayUrl = $"{gatewayResponse.Url}?v={DiscordConfig.APIVersion}&encoding={DiscordSocketConfig.GatewayEncoding}&compress=zlib-stream"; | |||||
gatewayUrl = _resumeGatewayUrl; | |||||
} | } | ||||
#if DEBUG_PACKETS | #if DEBUG_PACKETS | ||||
Console.WriteLine("Connecting to gateway: " + _gatewayUrl); | |||||
Console.WriteLine("Connecting to gateway: " + gatewayUrl); | |||||
#endif | #endif | ||||
await WebSocketClient.ConnectAsync(_gatewayUrl).ConfigureAwait(false); | |||||
await WebSocketClient.ConnectAsync(gatewayUrl).ConfigureAwait(false); | |||||
ConnectionState = ConnectionState.Connected; | ConnectionState = ConnectionState.Connected; | ||||
} | } | ||||
catch | catch | ||||
{ | { | ||||
if (!_isExplicitUrl) | |||||
_gatewayUrl = null; //Uncache in case the gateway url changed | |||||
await DisconnectInternalAsync().ConfigureAwait(false); | await DisconnectInternalAsync().ConfigureAwait(false); | ||||
throw; | throw; | ||||
} | } | ||||
@@ -322,7 +322,6 @@ namespace Discord.WebSocket | |||||
} | } | ||||
private async Task OnDisconnectingAsync(Exception ex) | private async Task OnDisconnectingAsync(Exception ex) | ||||
{ | { | ||||
await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false); | await _gatewayLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false); | ||||
await ApiClient.DisconnectAsync(ex).ConfigureAwait(false); | await ApiClient.DisconnectAsync(ex).ConfigureAwait(false); | ||||
@@ -353,6 +352,10 @@ namespace Discord.WebSocket | |||||
if (guild.IsAvailable) | if (guild.IsAvailable) | ||||
await GuildUnavailableAsync(guild).ConfigureAwait(false); | await GuildUnavailableAsync(guild).ConfigureAwait(false); | ||||
} | } | ||||
_sessionId = null; | |||||
_lastSeq = 0; | |||||
ApiClient.ResumeGatewayUrl = null; | |||||
} | } | ||||
/// <inheritdoc /> | /// <inheritdoc /> | ||||
@@ -834,6 +837,7 @@ namespace Discord.WebSocket | |||||
_sessionId = null; | _sessionId = null; | ||||
_lastSeq = 0; | _lastSeq = 0; | ||||
ApiClient.ResumeGatewayUrl = null; | |||||
if (_shardedClient != null) | if (_shardedClient != null) | ||||
{ | { | ||||
@@ -891,6 +895,7 @@ namespace Discord.WebSocket | |||||
AddPrivateChannel(data.PrivateChannels[i], state); | AddPrivateChannel(data.PrivateChannels[i], state); | ||||
_sessionId = data.SessionId; | _sessionId = data.SessionId; | ||||
ApiClient.ResumeGatewayUrl = data.ResumeGatewayUrl; | |||||
_unavailableGuildCount = unavailableGuilds; | _unavailableGuildCount = unavailableGuilds; | ||||
CurrentUser = currentUser; | CurrentUser = currentUser; | ||||
_previousSessionUser = CurrentUser; | _previousSessionUser = CurrentUser; | ||||