diff --git a/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs b/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs index da03bff35..96c5b962e 100644 --- a/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs +++ b/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs @@ -177,7 +177,7 @@ namespace Discord.Net.WebSockets if (packetLength > 0 && endpoint.Equals(_endpoint)) { - if (_state != (int)ConnectionState.Connected) + if (_state != ConnectionState.Connected) { if (packetLength != 70) return; @@ -247,7 +247,7 @@ namespace Discord.Net.WebSockets { try { - while (!cancelToken.IsCancellationRequested && _state != (int)ConnectionState.Connected) + while (!cancelToken.IsCancellationRequested && _state != ConnectionState.Connected) Thread.Sleep(1); if (cancelToken.IsCancellationRequested) @@ -399,7 +399,7 @@ namespace Discord.Net.WebSockets { case VoiceOpCodes.Ready: { - if (_state != (int)ConnectionState.Connected) + if (_state != ConnectionState.Connected) { var payload = (msg.Payload as JToken).ToObject(_serializer); _heartbeatInterval = payload.HeartbeatInterval; @@ -488,7 +488,7 @@ namespace Discord.Net.WebSockets } catch (OperationCanceledException) { - ThrowError(); + _taskManager.ThrowException(); } }); } diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index 5fb6450d4..ad68f805a 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -68,8 +68,8 @@ namespace Discord private readonly DiscordConfig _config; /// Returns the current connection state of this client. - public ConnectionState State => (ConnectionState)_state; - private int _state; + public ConnectionState State => _state; + private ConnectionState _state; /// Gives direct access to the underlying DiscordAPIClient. This can be used to modify objects not in cache. public DiscordAPIClient APIClient => _api; @@ -145,7 +145,7 @@ namespace Discord var settings = new JsonSerializerSettings(); _webSocket.Connected += (s, e) => { - if (_state == (int)ConnectionState.Connecting) + if (_state == ConnectionState.Connecting) EndConnect(); }; _webSocket.Disconnected += (s, e) => @@ -277,7 +277,7 @@ namespace Discord if (!_sentInitialLog) SendInitialLog(); - if (State != (int)ConnectionState.Disconnected) + if (State != ConnectionState.Disconnected) await Disconnect().ConfigureAwait(false); _token = token; @@ -293,7 +293,8 @@ namespace Discord try { await _taskManager.Stop().ConfigureAwait(false); - _state = (int)ConnectionState.Connecting; + _taskManager.ClearException(); + _state = ConnectionState.Connecting; var gatewayResponse = await _api.Gateway().ConfigureAwait(false); string gateway = gatewayResponse.Url; @@ -326,7 +327,7 @@ namespace Discord } catch (OperationCanceledException) { - _webSocket.ThrowError(); //Throws data socket's internal error if any occured + _webSocket.TaskManager.ThrowException(); //Throws data socket's internal error if any occured throw; } } @@ -335,15 +336,15 @@ namespace Discord _lock.Release(); } } - catch + catch (Exception ex) { - await Disconnect().ConfigureAwait(false); + _taskManager.SignalError(ex, true); throw; } } private void EndConnect() { - _state = (int)ConnectionState.Connected; + _state = ConnectionState.Connected; _connectedEvent.Set(); RaiseConnected(); } @@ -352,8 +353,9 @@ namespace Discord public Task Disconnect() => _taskManager.Stop(); private async Task Cleanup() - { - if (Config.UseMessageQueue) + { + _state = ConnectionState.Disconnecting; + if (Config.UseMessageQueue) { MessageQueueItem ignored; while (_pendingMessages.TryDequeue(out ignored)) { } @@ -373,8 +375,8 @@ namespace Discord _token = null; _state = (int)ConnectionState.Disconnected; - _disconnectedEvent.Set(); _connectedEvent.Reset(); + _disconnectedEvent.Set(); } private void OnReceivedEvent(WebSocketEventEventArgs e) @@ -822,11 +824,11 @@ namespace Discord { switch (_state) { - case (int)ConnectionState.Disconnecting: + case ConnectionState.Disconnecting: throw new InvalidOperationException("The client is disconnecting."); - case (int)ConnectionState.Disconnected: + case ConnectionState.Disconnected: throw new InvalidOperationException("The client is not connected to Discord"); - case (int)ConnectionState.Connecting: + case ConnectionState.Connecting: throw new InvalidOperationException("The client is connecting."); } } diff --git a/src/Discord.Net/Helpers/TaskManager.cs b/src/Discord.Net/Helpers/TaskManager.cs index 33d9264e2..4f87ebc28 100644 --- a/src/Discord.Net/Helpers/TaskManager.cs +++ b/src/Discord.Net/Helpers/TaskManager.cs @@ -16,8 +16,8 @@ namespace Discord private CancellationTokenSource _cancelSource; private Task _task; - public bool WasUnexpected => _wasUnexpected; - private bool _wasUnexpected; + public bool WasUnexpected => _wasStopUnexpected; + private bool _wasStopUnexpected; public Exception Exception => _stopReason.SourceException; private ExceptionDispatchInfo _stopReason; @@ -53,7 +53,7 @@ namespace Discord continue; //Another thread sneaked in and started this manager before we got a lock, loop and try again _stopReason = null; - _wasUnexpected = false; + _wasStopUnexpected = false; Task[] tasksArray = tasks.ToArray(); Task anyTask = Task.WhenAny(tasksArray); @@ -74,8 +74,10 @@ namespace Discord await allTasks.ConfigureAwait(false); //Run the cleanup function within our lock - await _stopAction().ConfigureAwait(false); + if (_stopAction != null) + await _stopAction().ConfigureAwait(false); _task = null; + _cancelSource = null; }); return; } @@ -89,7 +91,8 @@ namespace Discord if (_task == null) return; //Are we running? if (_cancelSource.IsCancellationRequested) return; - _cancelSource.Cancel(); + if (_cancelSource != null) + _cancelSource.Cancel(); } } public Task Stop() @@ -102,7 +105,8 @@ namespace Discord if (task == null) return TaskHelper.CompletedTask; //Are we running? if (_cancelSource.IsCancellationRequested) return task; - _cancelSource.Cancel(); + if (_cancelSource != null) + _cancelSource.Cancel(); } return task; } @@ -111,11 +115,12 @@ namespace Discord { lock (_lock) { - if (_task == null) return; //Are we running? + if (_stopReason != null) return; _stopReason = ExceptionDispatchInfo.Capture(ex); - _wasUnexpected = isUnexpected; - _cancelSource.Cancel(); + _wasStopUnexpected = isUnexpected; + if (_cancelSource != null) + _cancelSource.Cancel(); } } public Task Error(Exception ex, bool isUnexpected = true) @@ -123,20 +128,22 @@ namespace Discord Task task; lock (_lock) { + if (_stopReason != null) return TaskHelper.CompletedTask; + //Cache the task so we still have something to await if Cleanup is run really quickly - task = _task; - if (task == null) return TaskHelper.CompletedTask; //Are we running? + task = _task ?? TaskHelper.CompletedTask; if (_cancelSource.IsCancellationRequested) return task; _stopReason = ExceptionDispatchInfo.Capture(ex); - _wasUnexpected = isUnexpected; - _cancelSource.Cancel(); + _wasStopUnexpected = isUnexpected; + if (_cancelSource != null) + _cancelSource.Cancel(); } return task; } /// Throws an exception if one was captured. - public void Throw() + public void ThrowException() { lock (_lock) { @@ -144,5 +151,13 @@ namespace Discord _stopReason.Throw(); } } + public void ClearException() + { + lock (_lock) + { + _stopReason = null; + _wasStopUnexpected = false; + } + } } } diff --git a/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs b/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs index a8e625191..c9f323d3f 100644 --- a/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs +++ b/src/Discord.Net/Net/WebSockets/WS4NetEngine.cs @@ -57,6 +57,7 @@ namespace Discord.Net.WebSockets _waitUntilConnect.Reset(); _webSocket.Open(); _waitUntilConnect.Wait(cancelToken); + _parent.TaskManager.ThrowException(); //In case our connection failed return TaskHelper.CompletedTask; } diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index f932e343b..e4708dbb7 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -34,8 +34,8 @@ namespace Discord.Net.WebSockets public string Host { get { return _host; } set { _host = value; } } private string _host; - public ConnectionState State => (ConnectionState)_state; - protected int _state; + public ConnectionState State => _state; + protected ConnectionState _state; public event EventHandler Connected; private void RaiseConnected() @@ -104,8 +104,9 @@ namespace Discord.Net.WebSockets _lock.WaitOne(); try { - await _taskManager.Stop().ConfigureAwait(false); - _state = (int)ConnectionState.Connecting; + await _taskManager.Stop().ConfigureAwait(false); + _taskManager.ClearException(); + _state = ConnectionState.Connecting; _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token; @@ -122,13 +123,14 @@ namespace Discord.Net.WebSockets catch (Exception ex) { _taskManager.SignalError(ex, true); + throw; } } protected void EndConnect() { try { - _state = (int)ConnectionState.Connected; + _state = ConnectionState.Connected; _connectedEvent.Set(); RaiseConnected(); @@ -142,17 +144,17 @@ namespace Discord.Net.WebSockets protected abstract Task Run(); protected virtual async Task Cleanup() { - await _engine.Disconnect().ConfigureAwait(false); + var oldState = _state; + _state = ConnectionState.Disconnecting; + + await _engine.Disconnect().ConfigureAwait(false); _cancelTokenSource = null; - var oldState = _state; _connectedEvent.Reset(); - if (oldState == (int)ConnectionState.Connected) - { - _state = (int)ConnectionState.Disconnected; + if (oldState == ConnectionState.Connected) RaiseDisconnected(_taskManager.WasUnexpected, _taskManager.Exception); - } - } + _state = ConnectionState.Disconnected; + } protected virtual Task ProcessMessage(string json) { @@ -176,23 +178,18 @@ namespace Discord.Net.WebSockets { while (!cancelToken.IsCancellationRequested) { - if (_state == (int)ConnectionState.Connected) + if (_state == ConnectionState.Connected) { SendHeartbeat(); await Task.Delay(_heartbeatInterval, cancelToken).ConfigureAwait(false); } else - await Task.Delay(100, cancelToken).ConfigureAwait(false); + await Task.Delay(1000, cancelToken).ConfigureAwait(false); } } catch (OperationCanceledException) { } }); } public abstract void SendHeartbeat(); - - protected internal void ThrowError() - { - _taskManager.Throw(); - } } }