diff --git a/src/Discord.Net.Audio/AudioClient.cs b/src/Discord.Net.Audio/AudioClient.cs index 20aaed2db..b2ca4120c 100644 --- a/src/Discord.Net.Audio/AudioClient.cs +++ b/src/Discord.Net.Audio/AudioClient.cs @@ -59,24 +59,17 @@ namespace Discord.Audio public AudioClient(AudioService service, int clientId, Server server, GatewaySocket gatewaySocket, Logger logger) { Service = service; + _serializer = service.Client.Serializer; Id = clientId; GatewaySocket = gatewaySocket; Logger = logger; OutputStream = new OutStream(this); - _connectionLock = new AsyncLock(); - - _serializer = new JsonSerializer(); - _serializer.DateTimeZoneHandling = DateTimeZoneHandling.Utc; - _serializer.Error += (s, e) => - { - e.ErrorContext.Handled = true; - Logger.Error("Serialization Failed", e.ErrorContext.Error); - }; + _connectionLock = new AsyncLock(); GatewaySocket.ReceivedDispatch += OnReceivedDispatch; - VoiceSocket = new VoiceWebSocket(service.Client, this, _serializer, logger); + VoiceSocket = new VoiceWebSocket(service.Client, this, logger); VoiceSocket.Server = server; /*_voiceSocket.Connected += (s, e) => RaiseVoiceConnected(); @@ -124,27 +117,41 @@ namespace Discord.Audio if (channel == VoiceSocket.Channel) return; if (VoiceSocket.Server == null) throw new InvalidOperationException("This client has been closed."); - - using (await _connectionLock.LockAsync()) + + using (await _connectionLock.LockAsync().ConfigureAwait(false)) { - _cancelTokenSource = new CancellationTokenSource(); - var cancelToken = _cancelTokenSource.Token; - VoiceSocket.ParentCancelToken = cancelToken; VoiceSocket.Channel = channel; await Task.Run(() => { SendVoiceUpdate(); - VoiceSocket.WaitForConnection(cancelToken); + VoiceSocket.WaitForConnection(_cancelTokenSource.Token); }); } } + public async Task Connect(bool connectGateway) + { + using (await _connectionLock.LockAsync().ConfigureAwait(false)) + { + _cancelTokenSource = new CancellationTokenSource(); + var cancelToken = _cancelTokenSource.Token; + VoiceSocket.ParentCancelToken = cancelToken; + + if (connectGateway) + { + GatewaySocket.ParentCancelToken = cancelToken; + await GatewaySocket.Connect().ConfigureAwait(false); + GatewaySocket.WaitForConnection(cancelToken); + } + } + } + public async Task Disconnect() { - using (await _connectionLock.LockAsync()) + using (await _connectionLock.LockAsync().ConfigureAwait(false)) { - Service.RemoveClient(VoiceSocket.Server, this); + await Service.RemoveClient(VoiceSocket.Server, this).ConfigureAwait(false); VoiceSocket.Channel = null; SendVoiceUpdate(); await VoiceSocket.Disconnect(); diff --git a/src/Discord.Net.Audio/AudioService.cs b/src/Discord.Net.Audio/AudioService.cs index daa5dd5b0..199a5e902 100644 --- a/src/Discord.Net.Audio/AudioService.cs +++ b/src/Discord.Net.Audio/AudioService.cs @@ -1,4 +1,5 @@ -using System; +using Discord.Net.WebSockets; +using System; using System.Collections.Concurrent; using System.Linq; using System.Threading.Tasks; @@ -10,7 +11,7 @@ namespace Discord.Audio private AudioClient _defaultClient; private ConcurrentDictionary _voiceClients; private ConcurrentDictionary _talkingUsers; - //private int _nextClientId; + private int _nextClientId; internal DiscordClient Client { get; private set; } public AudioServiceConfig Config { get; } @@ -83,52 +84,71 @@ namespace Discord.Audio return null; } } - private Task CreateClient(Server server) + private async Task CreateClient(Server server) { - throw new NotImplementedException(); - /*var client = _voiceClients.GetOrAdd(server.Id, _ => - { - int id = unchecked(++_nextClientId); - var logger = Client.Log.CreateLogger($"Voice #{id}"); - var voiceClient = new DiscordAudioClient(this, id, logger, Client.GatewaySocket); - voiceClient.SetServerId(server.Id); + var client = _voiceClients.GetOrAdd(server.Id, _ => null); //Placeholder, so we can't have two clients connecting at once - voiceClient.VoiceSocket.OnPacket += (s, e) => - { - RaiseOnPacket(e); - }; - voiceClient.VoiceSocket.IsSpeaking += (s, e) => - { - var user = server.GetUser(e.UserId); - RaiseUserIsSpeakingUpdated(user, e.IsSpeaking); - }; - - return voiceClient; - }); - //await client.Connect(gatewaySocket.Host, _client.Token).ConfigureAwait(false); - return Task.FromResult(client);*/ + if (client == null) + { + int id = unchecked(++_nextClientId); + + var gatewayLogger = Client.Log.CreateLogger($"Gateway #{id}"); + var gatewaySocket = new GatewaySocket(Client, gatewayLogger); + await gatewaySocket.Connect().ConfigureAwait(false); + + var voiceLogger = Client.Log.CreateLogger($"Voice #{id}"); + var voiceClient = new AudioClient(this, id, server, Client.GatewaySocket, voiceLogger); + await voiceClient.Connect(true).ConfigureAwait(false); + + /*voiceClient.VoiceSocket.FrameReceived += (s, e) => + { + OnFrameReceieved(e); + }; + voiceClient.VoiceSocket.UserIsSpeaking += (s, e) => + { + var user = server.GetUser(e.UserId); + OnUserIsSpeakingUpdated(user, e.IsSpeaking); + };*/ + + //Update the placeholder only it still exists (RemoveClient wasnt called) + if (!_voiceClients.TryUpdate(server.Id, voiceClient, null)) + { + //If it was, cleanup + await voiceClient.Disconnect().ConfigureAwait(false); ; + await gatewaySocket.Disconnect().ConfigureAwait(false); ; + } + } + return client; } //TODO: This isn't threadsafe - internal void RemoveClient(Server server, IAudioClient client) + internal async Task RemoveClient(Server server, IAudioClient client) { if (Config.EnableMultiserver && server != null) - _voiceClients.TryRemove(server.Id, out client); + { + if (_voiceClients.TryRemove(server.Id, out client)) + { + await client.Disconnect(); + await (client as AudioClient).GatewaySocket.Disconnect(); + } + } } public async Task Join(Channel channel) { if (channel == null) throw new ArgumentNullException(nameof(channel)); - - IAudioClient client; + if (!Config.EnableMultiserver) - client = await (_defaultClient as SimpleAudioClient).Connect(channel).ConfigureAwait(false); + { + await (_defaultClient as SimpleAudioClient).Join(channel).ConfigureAwait(false); + return _defaultClient; + } else { - client = await CreateClient(channel.Server).ConfigureAwait(false); + var client = await CreateClient(channel.Server).ConfigureAwait(false); await client.Join(channel).ConfigureAwait(false); + return client; } - return client; } public async Task Leave(Server server) diff --git a/src/Discord.Net.Audio/Net/VoiceWebSocket.cs b/src/Discord.Net.Audio/Net/VoiceWebSocket.cs index c8b8ec5bf..561a4cb8a 100644 --- a/src/Discord.Net.Audio/Net/VoiceWebSocket.cs +++ b/src/Discord.Net.Audio/Net/VoiceWebSocket.cs @@ -57,8 +57,8 @@ namespace Discord.Net.WebSockets internal void OnFrameReceived(ulong userId, ulong channelId, byte[] buffer, int offset, int count) => FrameReceived(this, new InternalFrameEventArgs(userId, channelId, buffer, offset, count)); - internal VoiceWebSocket(DiscordClient client, AudioClient audioClient, JsonSerializer serializer, Logger logger) - : base(client, serializer, logger) + internal VoiceWebSocket(DiscordClient client, AudioClient audioClient, Logger logger) + : base(client, logger) { _audioClient = audioClient; _config = client.Audio().Config; diff --git a/src/Discord.Net.Audio/SimpleAudioClient.cs b/src/Discord.Net.Audio/SimpleAudioClient.cs index 222dc691d..b073e2ed3 100644 --- a/src/Discord.Net.Audio/SimpleAudioClient.cs +++ b/src/Discord.Net.Audio/SimpleAudioClient.cs @@ -42,19 +42,19 @@ namespace Discord.Audio //Only disconnects if is current a member of this server public async Task Leave(VirtualClient client) { - using (await _connectionLock.LockAsync()) + using (await _connectionLock.LockAsync().ConfigureAwait(false)) { if (CurrentClient == client) { CurrentClient = null; - await Disconnect(); + await Disconnect().ConfigureAwait(false); } } } - internal async Task Connect(Channel channel) + internal async Task Connect(Channel channel, bool connectGateway) { - using (await _connectionLock.LockAsync()) + using (await _connectionLock.LockAsync().ConfigureAwait(false)) { bool changeServer = channel.Server != VoiceSocket.Server; if (changeServer || CurrentClient == null) @@ -62,8 +62,9 @@ namespace Discord.Audio await Disconnect().ConfigureAwait(false); CurrentClient = new VirtualClient(this); VoiceSocket.Server = channel.Server; + await Connect(connectGateway).ConfigureAwait(false); } - await Join(channel); + await Join(channel).ConfigureAwait(false); return CurrentClient; } } diff --git a/src/Discord.Net/DiscordClient.cs b/src/Discord.Net/DiscordClient.cs index d08784ab0..ee06e0f12 100644 --- a/src/Discord.Net/DiscordClient.cs +++ b/src/Discord.Net/DiscordClient.cs @@ -30,10 +30,11 @@ namespace Discord private readonly ConcurrentDictionary _servers; private readonly ConcurrentDictionary _channels; private readonly ConcurrentDictionary _privateChannels; //Key = RecipientId - private readonly JsonSerializer _serializer; private Dictionary _regions; private CancellationTokenSource _cancelTokenSource; + internal Logger Logger { get; } + /// Gets the configuration object used to make this client. public DiscordConfig Config { get; } /// Gets the log manager. @@ -48,8 +49,8 @@ namespace Discord public ServiceManager Services { get; } /// Gets the queue used for outgoing messages, if enabled. public MessageQueue MessageQueue { get; } - /// Gets the logger used for this client. - internal Logger Logger { get; } + /// Gets the JSON serializer used by this client. + public JsonSerializer Serializer { get; } /// Gets the current connection state of this client. public ConnectionState State { get; private set; } @@ -101,16 +102,16 @@ namespace Discord _privateChannels = new ConcurrentDictionary(); //Serialization - _serializer = new JsonSerializer(); - _serializer.DateTimeZoneHandling = DateTimeZoneHandling.Utc; + Serializer = new JsonSerializer(); + Serializer.DateTimeZoneHandling = DateTimeZoneHandling.Utc; #if TEST_RESPONSES - _serializer.CheckAdditionalContent = true; - _serializer.MissingMemberHandling = MissingMemberHandling.Error; + Serializer.CheckAdditionalContent = true; + Serializer.MissingMemberHandling = MissingMemberHandling.Error; #else - _serializer.CheckAdditionalContent = false; - _serializer.MissingMemberHandling = MissingMemberHandling.Ignore; + Serializer.CheckAdditionalContent = false; + Serializer.MissingMemberHandling = MissingMemberHandling.Ignore; #endif - _serializer.Error += (s, e) => + Serializer.Error += (s, e) => { e.ErrorContext.Handled = true; Logger.Error("Serialization Failed", e.ErrorContext.Error); @@ -119,7 +120,7 @@ namespace Discord //Networking ClientAPI = new RestClient(Config, DiscordConfig.ClientAPIUrl, Log.CreateLogger("ClientAPI")); StatusAPI = new RestClient(Config, DiscordConfig.StatusAPIUrl, Log.CreateLogger("StatusAPI")); - GatewaySocket = new GatewaySocket(this, _serializer, Log.CreateLogger("Gateway")); + GatewaySocket = new GatewaySocket(this, Log.CreateLogger("Gateway")); GatewaySocket.Connected += (s, e) => { if (State == ConnectionState.Connecting) @@ -486,7 +487,7 @@ namespace Discord Stopwatch stopwatch = null; if (Config.LogLevel >= LogSeverity.Verbose) stopwatch = Stopwatch.StartNew(); - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); GatewaySocket.StartHeartbeat(data.HeartbeatInterval); GatewaySocket.SessionId = data.SessionId; SessionId = data.SessionId; @@ -517,7 +518,7 @@ namespace Discord break; case "RESUMED": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); GatewaySocket.StartHeartbeat(data.HeartbeatInterval); } break; @@ -525,7 +526,7 @@ namespace Discord //Servers case "GUILD_CREATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); if (data.Unavailable != true) { var server = AddServer(data.Id); @@ -547,7 +548,7 @@ namespace Discord break; case "GUILD_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.Id); if (server != null) { @@ -562,7 +563,7 @@ namespace Discord break; case "GUILD_DELETE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); Server server = RemoveServer(data.Id); if (server != null) { @@ -586,7 +587,7 @@ namespace Discord //Channels case "CHANNEL_CREATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); Channel channel = null; if (data.GuildId != null) @@ -610,7 +611,7 @@ namespace Discord break; case "CHANNEL_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var channel = GetChannel(data.Id); if (channel != null) { @@ -625,7 +626,7 @@ namespace Discord break; case "CHANNEL_DELETE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var channel = RemoveChannel(data.Id); if (channel != null) { @@ -641,7 +642,7 @@ namespace Discord //Members case "GUILD_MEMBER_ADD": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId.Value); if (server != null) { @@ -658,7 +659,7 @@ namespace Discord break; case "GUILD_MEMBER_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId.Value); if (server != null) { @@ -679,7 +680,7 @@ namespace Discord break; case "GUILD_MEMBER_REMOVE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId.Value); if (server != null) { @@ -699,7 +700,7 @@ namespace Discord break; case "GUILD_MEMBERS_CHUNK": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId); if (server != null) { @@ -718,7 +719,7 @@ namespace Discord //Roles case "GUILD_ROLE_CREATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId); if (server != null) { @@ -734,7 +735,7 @@ namespace Discord break; case "GUILD_ROLE_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId); if (server != null) { @@ -755,7 +756,7 @@ namespace Discord break; case "GUILD_ROLE_DELETE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId); if (server != null) { @@ -777,7 +778,7 @@ namespace Discord //Bans case "GUILD_BAN_ADD": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId.Value); if (server != null) { @@ -797,7 +798,7 @@ namespace Discord break; case "GUILD_BAN_REMOVE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId.Value); if (server != null) { @@ -815,7 +816,7 @@ namespace Discord //Messages case "MESSAGE_CREATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); Channel channel = GetChannel(data.ChannelId); if (channel != null) @@ -860,7 +861,7 @@ namespace Discord break; case "MESSAGE_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var channel = GetChannel(data.ChannelId); if (channel != null) { @@ -876,7 +877,7 @@ namespace Discord break; case "MESSAGE_DELETE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var channel = GetChannel(data.ChannelId); if (channel != null) { @@ -893,7 +894,7 @@ namespace Discord { if (Config.MessageCacheSize > 0) { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var channel = GetChannel(data.ChannelId); if (channel != null) { @@ -914,7 +915,7 @@ namespace Discord //Statuses case "PRESENCE_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); User user; Server server; if (data.GuildId == null) @@ -951,7 +952,7 @@ namespace Discord break; case "TYPING_START": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var channel = GetChannel(data.ChannelId); if (channel != null) { @@ -980,7 +981,7 @@ namespace Discord //Voice case "VOICE_STATE_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); var server = GetServer(data.GuildId); if (server != null) { @@ -1002,7 +1003,7 @@ namespace Discord //Settings case "USER_UPDATE": { - var data = e.Payload.ToObject(_serializer); + var data = e.Payload.ToObject(Serializer); if (data.Id == CurrentUser.Id) { CurrentUser.Update(data); diff --git a/src/Discord.Net/Net/WebSockets/GatewaySocket.cs b/src/Discord.Net/Net/WebSockets/GatewaySocket.cs index 9c7c6e26f..22ae0af57 100644 --- a/src/Discord.Net/Net/WebSockets/GatewaySocket.cs +++ b/src/Discord.Net/Net/WebSockets/GatewaySocket.cs @@ -22,8 +22,8 @@ namespace Discord.Net.WebSockets private void OnReceivedDispatch(string type, JToken payload) => ReceivedDispatch(this, new WebSocketEventEventArgs(type, payload)); - public GatewaySocket(DiscordClient client, JsonSerializer serializer, Logger logger) - : base(client, serializer, logger) + public GatewaySocket(DiscordClient client, Logger logger) + : base(client, logger) { Disconnected += async (s, e) => { diff --git a/src/Discord.Net/Net/WebSockets/WebSocket.cs b/src/Discord.Net/Net/WebSockets/WebSocket.cs index 0ba1acb49..20b7d6547 100644 --- a/src/Discord.Net/Net/WebSockets/WebSocket.cs +++ b/src/Discord.Net/Net/WebSockets/WebSocket.cs @@ -24,9 +24,7 @@ namespace Discord.Net.WebSockets /// Gets the logger used for this client. protected internal Logger Logger { get; } - public CancellationToken CancelToken { get; private set; } - public CancellationToken? ParentCancelToken { get; set; } public string Host { get; set; } @@ -40,11 +38,11 @@ namespace Discord.Net.WebSockets private void OnDisconnected(bool wasUnexpected, Exception error) => Disconnected(this, new DisconnectedEventArgs(wasUnexpected, error)); - public WebSocket(DiscordClient client, JsonSerializer serializer, Logger logger) + public WebSocket(DiscordClient client, Logger logger) { _client = client; Logger = logger; - _serializer = serializer; + _serializer = client.Serializer; _lock = new AsyncLock(); _taskManager = new TaskManager(Cleanup);