From 3fe7124f4decd988247de55539f5c945164660fe Mon Sep 17 00:00:00 2001 From: RogueException Date: Tue, 8 Dec 2015 09:35:18 -0400 Subject: [PATCH] Started WebSocket cleanup --- src/Discord.Net.Audio/DiscordAudioClient.cs | 2 +- .../Net/WebSockets/VoiceWebSocket.cs | 56 ++++------- src/Discord.Net/DiscordClient.cs | 8 +- src/Discord.Net/Net/WebSockets/GatewayWebSocket.cs | 35 ++++--- src/Discord.Net/Net/WebSockets/WebSocket.cs | 103 ++++++++++----------- .../Net/WebSockets/WebSocketSharpEngine.cs | 6 +- 6 files changed, 100 insertions(+), 110 deletions(-) diff --git a/src/Discord.Net.Audio/DiscordAudioClient.cs b/src/Discord.Net.Audio/DiscordAudioClient.cs index 773a02761..791321188 100644 --- a/src/Discord.Net.Audio/DiscordAudioClient.cs +++ b/src/Discord.Net.Audio/DiscordAudioClient.cs @@ -74,7 +74,7 @@ namespace Discord.Audio var client = _service.Client; string token = e.Payload.Value("token"); _voiceSocket.Host = "wss://" + e.Payload.Value("endpoint").Split(':')[0]; - await _voiceSocket.Login(client.CurrentUser.Id, _gatewaySocket.SessionId, token, client.CancelToken).ConfigureAwait(false); + await _voiceSocket.Connect(client.CurrentUser.Id, _gatewaySocket.SessionId, token/*, client.CancelToken*/).ConfigureAwait(false); } } break; diff --git a/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs b/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs index 0a44f773b..93c16572c 100644 --- a/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs +++ b/src/Discord.Net.Audio/Net/WebSockets/VoiceWebSocket.cs @@ -58,20 +58,13 @@ namespace Discord.Net.WebSockets _sendBuffer = new VoiceBuffer((int)Math.Ceiling(_audioConfig.BufferLength / (double)_encoder.FrameLength), _encoder.FrameSize); } - public async Task Login(long userId, string sessionId, string token, CancellationToken cancelToken) - { - if ((WebSocketState)_state == WebSocketState.Connected) - { - //Adjust the host and tell the system to reconnect - await DisconnectInternal(new Exception("Server transfer occurred."), isUnexpected: false).ConfigureAwait(false); - return; - } - + public async Task Connect(long userId, string sessionId, string token) + { _userId = userId; _sessionId = sessionId; _token = token; - - await Start().ConfigureAwait(false); + + await BeginConnect().ConfigureAwait(false); } public async Task Reconnect() { @@ -83,9 +76,7 @@ namespace Discord.Net.WebSockets { try { - //This check is needed in case we start a reconnect before the initial login completes - if (_state != (int)WebSocketState.Disconnected) - await Start().ConfigureAwait(false); + await Connect(_userId.Value, _sessionId, _token).ConfigureAwait(false); break; } catch (OperationCanceledException) { throw; } @@ -99,48 +90,39 @@ namespace Discord.Net.WebSockets } catch (OperationCanceledException) { } } + public Task Disconnect() + { + return SignalDisconnect(wait: true); + } - protected override IEnumerable GetTasks() + protected override async Task Run() { _udp = new UdpClient(new IPEndPoint(IPAddress.Any, 0)); - SendIdentify(); - List tasks = new List(); if ((_audioConfig.Mode & AudioMode.Outgoing) != 0) { _sendThread = new Thread(new ThreadStart(() => SendVoiceAsync(_cancelToken))); _sendThread.IsBackground = true; _sendThread.Start(); - } - - //This thread is required to establish a connection even if we're outgoing only + } if ((_audioConfig.Mode & AudioMode.Incoming) != 0) { _receiveThread = new Thread(new ThreadStart(() => ReceiveVoiceAsync(_cancelToken))); _receiveThread.IsBackground = true; _receiveThread.Start(); } - else //Dont make an OS thread if we only want to capture one packet... - tasks.Add(Task.Run(() => ReceiveVoiceAsync(_cancelToken))); + + SendIdentify(); #if !DOTNET5_4 tasks.Add(WatcherAsync()); #endif - if (tasks.Count > 0) - { - // We need to combine tasks into one because receiveThread is - // supposed to exit early if it's an outgoing-only client - // and we dont want the main thread to think we errored - var task = Task.WhenAll(tasks); - tasks.Clear(); - tasks.Add(task); - } - tasks.AddRange(base.GetTasks()); - - return new Task[] { Task.WhenAll(tasks.ToArray()) }; + await RunTasks(tasks.ToArray()); + + await Cleanup(); } - protected override Task Stop() + protected override Task Cleanup() { if (_sendThread != null) _sendThread.Join(); @@ -165,7 +147,7 @@ namespace Discord.Net.WebSockets } _udp = null; - return base.Stop(); + return base.Cleanup(); } private void ReceiveVoiceAsync(CancellationToken cancelToken) @@ -474,7 +456,7 @@ namespace Discord.Net.WebSockets var payload = (msg.Payload as JToken).ToObject(_serializer); _secretKey = payload.SecretKey; SendIsTalking(true); - EndConnect(); + await EndConnect(); } break; case VoiceOpCodes.Speaking: diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index edb4fa660..4c4c82866 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -266,11 +266,9 @@ namespace Discord if (_state == (int)DiscordClientState.Connecting) CompleteConnect(); }; - socket.Disconnected += async (s, e) => + socket.Disconnected += (s, e) => { RaiseDisconnected(e); - if (e.WasUnexpected) - await socket.Reconnect(_token).ConfigureAwait(false); }; socket.ReceivedDispatch += async (s, e) => await OnReceivedEvent(e).ConfigureAwait(false); @@ -329,7 +327,7 @@ namespace Discord _webSocket.Host = gateway; _webSocket.ParentCancelToken = _cancelToken; - await _webSocket.Login(token).ConfigureAwait(false); + await _webSocket.Connect(token).ConfigureAwait(false); _runTask = RunTasks(); @@ -422,7 +420,7 @@ namespace Discord var wasDisconnectUnexpected = _wasDisconnectUnexpected; _wasDisconnectUnexpected = false; - await _webSocket.Disconnect().ConfigureAwait(false); + await _webSocket.SignalDisconnect().ConfigureAwait(false); _userId = null; _gateway = null; diff --git a/src/Discord.Net/Net/WebSockets/GatewayWebSocket.cs b/src/Discord.Net/Net/WebSockets/GatewayWebSocket.cs index b6c292268..441de8f40 100644 --- a/src/Discord.Net/Net/WebSockets/GatewayWebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/GatewayWebSocket.cs @@ -8,7 +8,11 @@ namespace Discord.Net.WebSockets { public partial class GatewayWebSocket : WebSocket { - private int _lastSeq; + public int LastSequence => _lastSeq; + private int _lastSeq; + + public string Token => _token; + private string _token; public string SessionId => _sessionId; private string _sessionId; @@ -16,25 +20,25 @@ namespace Discord.Net.WebSockets public GatewayWebSocket(DiscordConfig config, Logger logger) : base(config, logger) { + Disconnected += async (s, e) => + { + if (e.WasUnexpected) + await Reconnect().ConfigureAwait(false); + }; } - public async Task Login(string token) + public async Task Connect(string token) { + _token = token; await BeginConnect().ConfigureAwait(false); - await Start().ConfigureAwait(false); - SendIdentify(token); } private async Task Redirect(string server) { - await DisconnectInternal(isUnexpected: false).ConfigureAwait(false); - await BeginConnect().ConfigureAwait(false); - await Start().ConfigureAwait(false); - SendResume(); } - public async Task Reconnect(string token) + private async Task Reconnect() { try { @@ -44,7 +48,7 @@ namespace Discord.Net.WebSockets { try { - await Login(token).ConfigureAwait(false); + await Connect(_token).ConfigureAwait(false); break; } catch (OperationCanceledException) { throw; } @@ -58,6 +62,15 @@ namespace Discord.Net.WebSockets } catch (OperationCanceledException) { } } + public Task Disconnect() + { + return SignalDisconnect(wait: true); + } + + protected override async Task Run() + { + await RunTasks(); + } protected override async Task ProcessMessage(string json) { @@ -85,7 +98,7 @@ namespace Discord.Net.WebSockets } RaiseReceivedDispatch(msg.Type, token); if (msg.Type == "READY" || msg.Type == "RESUMED") - EndConnect(); + await EndConnect(); //Complete the connect } break; case GatewayOpCodes.Redirect: diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index e6a269b29..b0352c067 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -114,38 +114,50 @@ namespace Discord.Net.WebSockets { try { - await Disconnect().ConfigureAwait(false); + await SignalDisconnect(wait: true).ConfigureAwait(false); + _state = (int)WebSocketState.Connecting; if (ParentCancelToken == null) throw new InvalidOperationException("Parent cancel token was never set."); _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token; - _state = (int)WebSocketState.Connecting; + if (_state != (int)WebSocketState.Connecting) + throw new InvalidOperationException("Socket is in the wrong state."); + + _lastHeartbeat = DateTime.UtcNow; + await _engine.Connect(Host, _cancelToken).ConfigureAwait(false); + + _runTask = Run(); } catch (Exception ex) { - await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false); + await SignalDisconnect(ex, isUnexpected: false).ConfigureAwait(false); throw; } } - protected void EndConnect() + protected async Task EndConnect() { - _state = (int)WebSocketState.Connected; - _connectedEvent.Set(); - RaiseConnected(); - } + try + { + _state = (int)WebSocketState.Connected; - public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false); - protected internal async Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false) + _connectedEvent.Set(); + RaiseConnected(); + } + catch (Exception ex) + { + await SignalDisconnect(ex, isUnexpected: false).ConfigureAwait(false); + throw; + } + } + + protected internal async Task SignalDisconnect(Exception ex = null, bool isUnexpected = false, bool wait = false) { - int oldState; - bool hasWriterLock; - //If in either connecting or connected state, get a lock by being the first to switch to disconnecting - oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting); + int oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting); if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected - hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change + bool hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change if (!hasWriterLock) { oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected); @@ -155,70 +167,55 @@ namespace Discord.Net.WebSockets if (hasWriterLock) { - _wasDisconnectUnexpected = isUnexpected; - _disconnectState = (WebSocketState)oldState; - _disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null; - + CaptureError(ex ?? new Exception("Disconnect was requested."), isUnexpected); _cancelTokenSource.Cancel(); if (_disconnectState == WebSocketState.Connecting) //_runTask was never made - await Stop().ConfigureAwait(false); + await Cleanup().ConfigureAwait(false); } - if (!skipAwait) + if (!wait) { Task task = _runTask; if (_runTask != null) await task.ConfigureAwait(false); } } - - protected virtual async Task Start() + private void CaptureError(Exception ex, bool isUnexpected) { - try - { - if (_state != (int)WebSocketState.Connecting) - throw new InvalidOperationException("Socket is in the wrong state."); - - _lastHeartbeat = DateTime.UtcNow; - await _engine.Connect(Host, _cancelToken).ConfigureAwait(false); - - _runTask = RunTasks(); - } - catch (Exception ex) - { - await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false); - throw; - } + _disconnectReason = ExceptionDispatchInfo.Capture(ex); + _wasDisconnectUnexpected = isUnexpected; } - protected virtual async Task RunTasks() + protected abstract Task Run(); + protected async Task RunTasks(params Task[] tasks) { - Task[] tasks = GetTasks().ToArray(); + //Get all async tasks + tasks = tasks + .Concat(_engine.GetTasks(_cancelToken)) + .Concat(new Task[] { HeartbeatAsync(_cancelToken) }) + .ToArray(); + + //Create group tasks Task firstTask = Task.WhenAny(tasks); Task allTasks = Task.WhenAll(tasks); //Wait until the first task ends/errors and capture the error - try { await firstTask.ConfigureAwait(false); } - catch (Exception ex) { await DisconnectInternal(ex: ex, skipAwait: true).ConfigureAwait(false); } + Exception ex = null; + try { await firstTask.ConfigureAwait(false); } + catch (Exception ex2) { ex = ex2; } //Ensure all other tasks are signaled to end. - await DisconnectInternal(skipAwait: true).ConfigureAwait(false); + await SignalDisconnect(ex, ex != null, true).ConfigureAwait(false); //Wait for the remaining tasks to complete try { await allTasks.ConfigureAwait(false); } catch { } //Start cleanup - await Stop().ConfigureAwait(false); - } - protected virtual IEnumerable GetTasks() - { - var cancelToken = _cancelToken; - return _engine.GetTasks(cancelToken) - .Concat(new Task[] { HeartbeatAsync(cancelToken) }); + await Cleanup().ConfigureAwait(false); } - protected virtual async Task Stop() + protected virtual async Task Cleanup() { var disconnectState = _disconnectState; _disconnectState = WebSocketState.Disconnected; @@ -254,7 +251,7 @@ namespace Discord.Net.WebSockets private Task HeartbeatAsync(CancellationToken cancelToken) { - return Task.Run((Func)(async () => + return Task.Run(async () => { try { @@ -270,7 +267,7 @@ namespace Discord.Net.WebSockets } } catch (OperationCanceledException) { } - })); + }); } protected internal void ThrowError() diff --git a/src/Discord.Net/Net/WebSockets/WebSocketSharpEngine.cs b/src/Discord.Net/Net/WebSockets/WebSocketSharpEngine.cs index 7640a71a6..e4933cec6 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocketSharpEngine.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocketSharpEngine.cs @@ -55,14 +55,14 @@ namespace Discord.Net.WebSockets _webSocket.OnError += async (s, e) => { _logger.Log(LogSeverity.Error, "WebSocket Error", e.Exception); - await _parent.DisconnectInternal(e.Exception, skipAwait: true).ConfigureAwait(false); + await _parent.SignalDisconnect(e.Exception, isUnexpected: true).ConfigureAwait(false); }; _webSocket.OnClose += async (s, e) => { string code = e.WasClean ? e.Code.ToString() : "Unexpected"; string reason = e.Reason != "" ? e.Reason : "No Reason"; - Exception ex = new Exception($"Got Close Message ({code}): {reason}"); - await _parent.DisconnectInternal(ex, skipAwait: true).ConfigureAwait(false); + var ex = new Exception($"Got Close Message ({code}): {reason}"); + await _parent.SignalDisconnect(ex, isUnexpected: true).ConfigureAwait(false); }; _webSocket.Log.Output = (e, m) => { }; //Dont let websocket-sharp print to console directly _webSocket.Connect();