Browse Source

Validate login state before sending REST requests or RPC msgs

pull/204/head
RogueException 8 years ago
parent
commit
ae867451b5
4 changed files with 41 additions and 26 deletions
  1. +27
    -17
      src/Discord.Net/API/DiscordRestApiClient.cs
  2. +11
    -8
      src/Discord.Net/API/DiscordRpcAPIClient.cs
  3. +2
    -0
      src/Discord.Net/API/DiscordSocketApiClient.cs
  4. +1
    -1
      src/Discord.Net/WebSocket/DiscordSocketClient.cs

+ 27
- 17
src/Discord.Net/API/DiscordRestApiClient.cs View File

@@ -152,30 +152,30 @@ namespace Discord.API

//REST
public Task SendAsync(string method, string endpoint,
GlobalBucket bucket = GlobalBucket.GeneralRest, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Global, (int)bucket, 0, options);
GlobalBucket bucket = GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Global, (int)bucket, 0, ignoreState, options);
public Task SendAsync(string method, string endpoint, object payload,
GlobalBucket bucket = GlobalBucket.GeneralRest, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Global, (int)bucket, 0, options);
GlobalBucket bucket = GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Global, (int)bucket, 0, ignoreState, options);
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint,
GlobalBucket bucket = GlobalBucket.GeneralRest, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Global, (int)bucket, 0, options).ConfigureAwait(false));
GlobalBucket bucket = GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Global, (int)bucket, 0, ignoreState, options).ConfigureAwait(false));
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, object payload, GlobalBucket bucket =
GlobalBucket.GeneralRest, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Global, (int)bucket, 0, options).ConfigureAwait(false));
GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Global, (int)bucket, 0, ignoreState, options).ConfigureAwait(false));

public Task SendAsync(string method, string endpoint,
GuildBucket bucket, ulong guildId, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Guild, (int)bucket, guildId, options);
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options);
public Task SendAsync(string method, string endpoint, object payload,
GuildBucket bucket, ulong guildId, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Guild, (int)bucket, guildId, options);
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options);
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint,
GuildBucket bucket, ulong guildId, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Guild, (int)bucket, guildId, options).ConfigureAwait(false));
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options).ConfigureAwait(false));
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, object payload,
GuildBucket bucket, ulong guildId, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Guild, (int)bucket, guildId, options).ConfigureAwait(false));
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options).ConfigureAwait(false));

//REST - Multipart
public Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs,
@@ -194,8 +194,11 @@ namespace Discord.API

//Core
private async Task<Stream> SendInternalAsync(string method, string endpoint, object payload, bool headerOnly,
BucketGroup group, int bucketId, ulong guildId, RequestOptions options = null)
BucketGroup group, int bucketId, ulong guildId, bool ignoreState, RequestOptions options = null)
{
if (!ignoreState)
CheckState();

var stopwatch = Stopwatch.StartNew();
string json = null;
if (payload != null)
@@ -211,6 +214,8 @@ namespace Discord.API
private async Task<Stream> SendMultipartInternalAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, bool headerOnly,
BucketGroup group, int bucketId, ulong guildId, RequestOptions options = null)
{
CheckState();

var stopwatch = Stopwatch.StartNew();
var responseStream = await RequestQueue.SendAsync(new RestRequest(_restClient, method, endpoint, multipartArgs, headerOnly, options), group, bucketId, guildId).ConfigureAwait(false);
int bytes = headerOnly ? 0 : (int)responseStream.Length;
@@ -1039,6 +1044,11 @@ namespace Discord.API
}

//Helpers
protected void CheckState()
{
if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("Client is not logged in.");
}
protected static double ToMilliseconds(Stopwatch stopwatch) => Math.Round((double)stopwatch.ElapsedTicks / (double)Stopwatch.Frequency * 1000.0, 2);
protected string SerializeJson(object value)
{


+ 11
- 8
src/Discord.Net/API/DiscordRpcAPIClient.cs View File

@@ -140,7 +140,7 @@ namespace Discord.API
internal override async Task ConnectInternalAsync()
{
/*if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("You must log in before connecting.");*/
throw new InvalidOperationException("Client is not logged in.");*/

ConnectionState = ConnectionState.Connecting;
try
@@ -207,17 +207,20 @@ namespace Discord.API

//Core
public Task<TResponse> SendRpcAsync<TResponse>(string cmd, object payload, GlobalBucket bucket = GlobalBucket.GeneralRpc,
Optional<string> evt = default(Optional<string>), RequestOptions options = null)
Optional<string> evt = default(Optional<string>), bool ignoreState = false, RequestOptions options = null)
where TResponse : class
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Global, (int)bucket, 0, evt, options);
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Global, (int)bucket, 0, evt, ignoreState, options);
public Task<TResponse> SendRpcAsync<TResponse>(string cmd, object payload, GuildBucket bucket, ulong guildId,
Optional<string> evt = default(Optional<string>), RequestOptions options = null)
Optional<string> evt = default(Optional<string>), bool ignoreState = false, RequestOptions options = null)
where TResponse : class
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Guild, (int)bucket, guildId, evt, options);
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Guild, (int)bucket, guildId, evt, ignoreState, options);
private async Task<TResponse> SendRpcAsyncInternal<TResponse>(string cmd, object payload, BucketGroup group, int bucketId, ulong guildId,
Optional<string> evt, RequestOptions options)
Optional<string> evt, bool ignoreState, RequestOptions options)
where TResponse : class
{
if (!ignoreState)
CheckState();

byte[] bytes = null;
var guid = Guid.NewGuid();
payload = new API.Rpc.RpcMessage { Cmd = cmd, Event = evt, Args = payload, Nonce = guid };
@@ -242,7 +245,7 @@ namespace Discord.API
{
AccessToken = _authToken
};
return await SendRpcAsync<AuthenticateResponse>("AUTHENTICATE", msg, options: options).ConfigureAwait(false);
return await SendRpcAsync<AuthenticateResponse>("AUTHENTICATE", msg, ignoreState: true, options: options).ConfigureAwait(false);
}
public async Task<AuthorizeResponse> SendAuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null)
{
@@ -256,7 +259,7 @@ namespace Discord.API
options = new RequestOptions();
if (options.Timeout == null)
options.Timeout = 60000; //This requires manual input on the user's end, lets give them more time
return await SendRpcAsync<AuthorizeResponse>("AUTHORIZE", msg, options: options).ConfigureAwait(false);
return await SendRpcAsync<AuthorizeResponse>("AUTHORIZE", msg, ignoreState: true, options: options).ConfigureAwait(false);
}

public async Task<GetGuildsResponse> SendGetGuildsAsync(RequestOptions options = null)


+ 2
- 0
src/Discord.Net/API/DiscordSocketApiClient.cs View File

@@ -159,6 +159,8 @@ namespace Discord.API
private async Task SendGatewayInternalAsync(GatewayOpCode opCode, object payload,
BucketGroup group, int bucketId, ulong guildId, RequestOptions options)
{
CheckState();

//TODO: Add ETF
byte[] bytes = null;
payload = new WebSocketMessage { Operation = (int)opCode, Payload = payload };


+ 1
- 1
src/Discord.Net/WebSocket/DiscordSocketClient.cs View File

@@ -144,7 +144,7 @@ namespace Discord.WebSocket
private async Task ConnectInternalAsync(bool isReconnecting)
{
if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("You must log in before connecting.");
throw new InvalidOperationException("Client is not logged in.");

if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel();


Loading…
Cancel
Save