Browse Source

Validate login state before sending REST requests or RPC msgs

pull/204/head
RogueException 8 years ago
parent
commit
ae867451b5
4 changed files with 41 additions and 26 deletions
  1. +27
    -17
      src/Discord.Net/API/DiscordRestApiClient.cs
  2. +11
    -8
      src/Discord.Net/API/DiscordRpcAPIClient.cs
  3. +2
    -0
      src/Discord.Net/API/DiscordSocketApiClient.cs
  4. +1
    -1
      src/Discord.Net/WebSocket/DiscordSocketClient.cs

+ 27
- 17
src/Discord.Net/API/DiscordRestApiClient.cs View File

@@ -152,30 +152,30 @@ namespace Discord.API


//REST //REST
public Task SendAsync(string method, string endpoint, public Task SendAsync(string method, string endpoint,
GlobalBucket bucket = GlobalBucket.GeneralRest, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Global, (int)bucket, 0, options);
GlobalBucket bucket = GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Global, (int)bucket, 0, ignoreState, options);
public Task SendAsync(string method, string endpoint, object payload, public Task SendAsync(string method, string endpoint, object payload,
GlobalBucket bucket = GlobalBucket.GeneralRest, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Global, (int)bucket, 0, options);
GlobalBucket bucket = GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Global, (int)bucket, 0, ignoreState, options);
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint,
GlobalBucket bucket = GlobalBucket.GeneralRest, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Global, (int)bucket, 0, options).ConfigureAwait(false));
GlobalBucket bucket = GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Global, (int)bucket, 0, ignoreState, options).ConfigureAwait(false));
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, object payload, GlobalBucket bucket = public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, object payload, GlobalBucket bucket =
GlobalBucket.GeneralRest, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Global, (int)bucket, 0, options).ConfigureAwait(false));
GlobalBucket.GeneralRest, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Global, (int)bucket, 0, ignoreState, options).ConfigureAwait(false));


public Task SendAsync(string method, string endpoint, public Task SendAsync(string method, string endpoint,
GuildBucket bucket, ulong guildId, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Guild, (int)bucket, guildId, options);
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, null, true, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options);
public Task SendAsync(string method, string endpoint, object payload, public Task SendAsync(string method, string endpoint, object payload,
GuildBucket bucket, ulong guildId, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Guild, (int)bucket, guildId, options);
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null)
=> SendInternalAsync(method, endpoint, payload, true, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options);
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint,
GuildBucket bucket, ulong guildId, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Guild, (int)bucket, guildId, options).ConfigureAwait(false));
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, null, false, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options).ConfigureAwait(false));
public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, object payload, public async Task<TResponse> SendAsync<TResponse>(string method, string endpoint, object payload,
GuildBucket bucket, ulong guildId, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Guild, (int)bucket, guildId, options).ConfigureAwait(false));
GuildBucket bucket, ulong guildId, bool ignoreState = false, RequestOptions options = null) where TResponse : class
=> DeserializeJson<TResponse>(await SendInternalAsync(method, endpoint, payload, false, BucketGroup.Guild, (int)bucket, guildId, ignoreState, options).ConfigureAwait(false));


//REST - Multipart //REST - Multipart
public Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, public Task SendMultipartAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs,
@@ -194,8 +194,11 @@ namespace Discord.API


//Core //Core
private async Task<Stream> SendInternalAsync(string method, string endpoint, object payload, bool headerOnly, private async Task<Stream> SendInternalAsync(string method, string endpoint, object payload, bool headerOnly,
BucketGroup group, int bucketId, ulong guildId, RequestOptions options = null)
BucketGroup group, int bucketId, ulong guildId, bool ignoreState, RequestOptions options = null)
{ {
if (!ignoreState)
CheckState();

var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
string json = null; string json = null;
if (payload != null) if (payload != null)
@@ -211,6 +214,8 @@ namespace Discord.API
private async Task<Stream> SendMultipartInternalAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, bool headerOnly, private async Task<Stream> SendMultipartInternalAsync(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, bool headerOnly,
BucketGroup group, int bucketId, ulong guildId, RequestOptions options = null) BucketGroup group, int bucketId, ulong guildId, RequestOptions options = null)
{ {
CheckState();

var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
var responseStream = await RequestQueue.SendAsync(new RestRequest(_restClient, method, endpoint, multipartArgs, headerOnly, options), group, bucketId, guildId).ConfigureAwait(false); var responseStream = await RequestQueue.SendAsync(new RestRequest(_restClient, method, endpoint, multipartArgs, headerOnly, options), group, bucketId, guildId).ConfigureAwait(false);
int bytes = headerOnly ? 0 : (int)responseStream.Length; int bytes = headerOnly ? 0 : (int)responseStream.Length;
@@ -1039,6 +1044,11 @@ namespace Discord.API
} }


//Helpers //Helpers
protected void CheckState()
{
if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("Client is not logged in.");
}
protected static double ToMilliseconds(Stopwatch stopwatch) => Math.Round((double)stopwatch.ElapsedTicks / (double)Stopwatch.Frequency * 1000.0, 2); protected static double ToMilliseconds(Stopwatch stopwatch) => Math.Round((double)stopwatch.ElapsedTicks / (double)Stopwatch.Frequency * 1000.0, 2);
protected string SerializeJson(object value) protected string SerializeJson(object value)
{ {


+ 11
- 8
src/Discord.Net/API/DiscordRpcAPIClient.cs View File

@@ -140,7 +140,7 @@ namespace Discord.API
internal override async Task ConnectInternalAsync() internal override async Task ConnectInternalAsync()
{ {
/*if (LoginState != LoginState.LoggedIn) /*if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("You must log in before connecting.");*/
throw new InvalidOperationException("Client is not logged in.");*/


ConnectionState = ConnectionState.Connecting; ConnectionState = ConnectionState.Connecting;
try try
@@ -207,17 +207,20 @@ namespace Discord.API


//Core //Core
public Task<TResponse> SendRpcAsync<TResponse>(string cmd, object payload, GlobalBucket bucket = GlobalBucket.GeneralRpc, public Task<TResponse> SendRpcAsync<TResponse>(string cmd, object payload, GlobalBucket bucket = GlobalBucket.GeneralRpc,
Optional<string> evt = default(Optional<string>), RequestOptions options = null)
Optional<string> evt = default(Optional<string>), bool ignoreState = false, RequestOptions options = null)
where TResponse : class where TResponse : class
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Global, (int)bucket, 0, evt, options);
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Global, (int)bucket, 0, evt, ignoreState, options);
public Task<TResponse> SendRpcAsync<TResponse>(string cmd, object payload, GuildBucket bucket, ulong guildId, public Task<TResponse> SendRpcAsync<TResponse>(string cmd, object payload, GuildBucket bucket, ulong guildId,
Optional<string> evt = default(Optional<string>), RequestOptions options = null)
Optional<string> evt = default(Optional<string>), bool ignoreState = false, RequestOptions options = null)
where TResponse : class where TResponse : class
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Guild, (int)bucket, guildId, evt, options);
=> SendRpcAsyncInternal<TResponse>(cmd, payload, BucketGroup.Guild, (int)bucket, guildId, evt, ignoreState, options);
private async Task<TResponse> SendRpcAsyncInternal<TResponse>(string cmd, object payload, BucketGroup group, int bucketId, ulong guildId, private async Task<TResponse> SendRpcAsyncInternal<TResponse>(string cmd, object payload, BucketGroup group, int bucketId, ulong guildId,
Optional<string> evt, RequestOptions options)
Optional<string> evt, bool ignoreState, RequestOptions options)
where TResponse : class where TResponse : class
{ {
if (!ignoreState)
CheckState();

byte[] bytes = null; byte[] bytes = null;
var guid = Guid.NewGuid(); var guid = Guid.NewGuid();
payload = new API.Rpc.RpcMessage { Cmd = cmd, Event = evt, Args = payload, Nonce = guid }; payload = new API.Rpc.RpcMessage { Cmd = cmd, Event = evt, Args = payload, Nonce = guid };
@@ -242,7 +245,7 @@ namespace Discord.API
{ {
AccessToken = _authToken AccessToken = _authToken
}; };
return await SendRpcAsync<AuthenticateResponse>("AUTHENTICATE", msg, options: options).ConfigureAwait(false);
return await SendRpcAsync<AuthenticateResponse>("AUTHENTICATE", msg, ignoreState: true, options: options).ConfigureAwait(false);
} }
public async Task<AuthorizeResponse> SendAuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null) public async Task<AuthorizeResponse> SendAuthorizeAsync(string[] scopes, string rpcToken = null, RequestOptions options = null)
{ {
@@ -256,7 +259,7 @@ namespace Discord.API
options = new RequestOptions(); options = new RequestOptions();
if (options.Timeout == null) if (options.Timeout == null)
options.Timeout = 60000; //This requires manual input on the user's end, lets give them more time options.Timeout = 60000; //This requires manual input on the user's end, lets give them more time
return await SendRpcAsync<AuthorizeResponse>("AUTHORIZE", msg, options: options).ConfigureAwait(false);
return await SendRpcAsync<AuthorizeResponse>("AUTHORIZE", msg, ignoreState: true, options: options).ConfigureAwait(false);
} }


public async Task<GetGuildsResponse> SendGetGuildsAsync(RequestOptions options = null) public async Task<GetGuildsResponse> SendGetGuildsAsync(RequestOptions options = null)


+ 2
- 0
src/Discord.Net/API/DiscordSocketApiClient.cs View File

@@ -159,6 +159,8 @@ namespace Discord.API
private async Task SendGatewayInternalAsync(GatewayOpCode opCode, object payload, private async Task SendGatewayInternalAsync(GatewayOpCode opCode, object payload,
BucketGroup group, int bucketId, ulong guildId, RequestOptions options) BucketGroup group, int bucketId, ulong guildId, RequestOptions options)
{ {
CheckState();

//TODO: Add ETF //TODO: Add ETF
byte[] bytes = null; byte[] bytes = null;
payload = new WebSocketMessage { Operation = (int)opCode, Payload = payload }; payload = new WebSocketMessage { Operation = (int)opCode, Payload = payload };


+ 1
- 1
src/Discord.Net/WebSocket/DiscordSocketClient.cs View File

@@ -144,7 +144,7 @@ namespace Discord.WebSocket
private async Task ConnectInternalAsync(bool isReconnecting) private async Task ConnectInternalAsync(bool isReconnecting)
{ {
if (LoginState != LoginState.LoggedIn) if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("You must log in before connecting.");
throw new InvalidOperationException("Client is not logged in.");


if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested) if (!isReconnecting && _reconnectCancelToken != null && !_reconnectCancelToken.IsCancellationRequested)
_reconnectCancelToken.Cancel(); _reconnectCancelToken.Cancel();


Loading…
Cancel
Save