diff --git a/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs b/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs index d9c518874..67754843c 100644 --- a/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs +++ b/src/Discord.Net/Net/WebSockets/DefaultWebsocketClient.cs @@ -51,8 +51,19 @@ namespace Discord.Net.WebSockets public async Task ConnectAsync(string host) { - //Assume locked - await DisconnectAsync().ConfigureAwait(false); + await _sendLock.WaitAsync(_cancelToken).ConfigureAwait(false); + try + { + await ConnectInternalAsync(host); + } + finally + { + _sendLock.Release(); + } + } + private async Task ConnectInternalAsync(string host) + { + await DisconnectInternalAsync().ConfigureAwait(false); _cancelTokenSource = new CancellationTokenSource(); _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token; @@ -69,19 +80,30 @@ namespace Discord.Net.WebSockets await _client.ConnectAsync(new Uri(host), _cancelToken).ConfigureAwait(false); _task = RunAsync(_cancelToken); } + public async Task DisconnectAsync() { - //Assume locked + await _sendLock.WaitAsync(_cancelToken).ConfigureAwait(false); + try + { + await DisconnectInternalAsync(); + } + finally + { + _sendLock.Release(); + } + } + private async Task DisconnectInternalAsync() + { try { _cancelTokenSource.Cancel(false); } catch { } - + + await (_task ?? Task.CompletedTask).ConfigureAwait(false); + if (_client != null && _client.State == WebSocketState.Open) { - var task = _client?.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - if (task != null) - await task.ConfigureAwait(false); + _client.Dispose(); + _client = null; } - - await (_task ?? Task.CompletedTask).ConfigureAwait(false); } public void SetHeader(string key, string value) @@ -99,7 +121,8 @@ namespace Discord.Net.WebSockets await _sendLock.WaitAsync(_cancelToken).ConfigureAwait(false); try { - //TODO: If connection is temporarily down, retry? + if (_client == null) return; + int frameCount = (int)Math.Ceiling((double)count / SendChunkSize); for (int i = 0; i < frameCount; i++, index += SendChunkSize) @@ -111,16 +134,9 @@ namespace Discord.Net.WebSockets frameSize = count - (i * SendChunkSize); else frameSize = SendChunkSize; - - try - { - var type = isText ? WebSocketMessageType.Text : WebSocketMessageType.Binary; - await _client.SendAsync(new ArraySegment(data, index, count), type, isLast, _cancelToken).ConfigureAwait(false); - } - catch (Win32Exception ex) when (ex.HResult == HR_TIMEOUT) - { - return; - } + + var type = isText ? WebSocketMessageType.Text : WebSocketMessageType.Binary; + await _client.SendAsync(new ArraySegment(data, index, count), type, isLast, _cancelToken).ConfigureAwait(false); } } finally