diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index a7d590b42..90d694329 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -1,5 +1,7 @@ +using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Threading.Tasks; using Discord.API; using Discord.Rest; @@ -46,7 +48,56 @@ namespace Discord.WebSocket public abstract Task SetStatusAsync(UserStatus status); public abstract Task SetGameAsync(string name, string streamUrl = null, ActivityType type = ActivityType.Playing); public abstract Task SetActivityAsync(IActivity activity); - public abstract Task DownloadUsersAsync(IEnumerable guilds); + public abstract Task DownloadUsersAsync(IEnumerable guilds); + + public Task GetChannelAsync(ulong id, CacheMode mode, RequestOptions options) + { + Func> restAction = async () => + await ClientHelper.GetChannelAsync(this, id, options).ConfigureAwait(false) as IChannel; + Func cacheAction = () => GetChannel(id); + + return CacheHelper(mode, restAction, cacheAction); + } + public Task> GetPrivateChannelsAsync(CacheMode mode, RequestOptions options) + { + Func>> restAction = async () => + { + var col = (await ClientHelper.GetPrivateChannelsAsync(this, options).ConfigureAwait(false)).OfType(); + return col.ToReadOnlyCollection(col.Count); + }; + Func> cacheAction = () => PrivateChannels; + + return CacheHelper(mode, restAction, cacheAction); + } + + public Task GetGuildAsync(ulong id, CacheMode mode, RequestOptions options) + { + Func> restAction = async () => + await ClientHelper.GetGuildAsync(this, id, options).ConfigureAwait(false) as IGuild; + Func cacheAction = () => GetGuild(id); + + return CacheHelper(mode, restAction, cacheAction); + } + public Task> GetGuildsAsync(CacheMode mode, RequestOptions options) + { + Func>> restAction = async () => + { + var col = (await ClientHelper.GetGuildsAsync(this, options)).OfType(); + return col.ToReadOnlyCollection(col.Count); + }; + Func> cacheAction = () => Guilds; + + return CacheHelper(mode, restAction, cacheAction); + } + + public Task GetUserAsync(ulong id, CacheMode mode, RequestOptions options) + { + Func> restAction = async () => + await ClientHelper.GetUserAsync(this, id, options).ConfigureAwait(false) as IUser; + Func cacheAction = () => GetUser(id); + + return CacheHelper(mode, restAction, cacheAction); + } /// public Task CreateGuildAsync(string name, IVoiceRegion region, Stream jpegIcon = null, RequestOptions options = null) @@ -90,5 +141,21 @@ namespace Discord.WebSocket => Task.FromResult(GetVoiceRegion(id)); Task> IDiscordClient.GetVoiceRegionsAsync(RequestOptions options) => Task.FromResult>(VoiceRegions); + + public async Task CacheHelper(CacheMode mode, Func> restAction, Func cacheAction) + where T : class + { + switch (mode) + { + case CacheMode.CacheOnly: + return cacheAction(); + case CacheMode.AllowDownload: + return cacheAction() ?? await restAction().ConfigureAwait(false); + case CacheMode.ForceDownload: + return await restAction().ConfigureAwait(false); + default: + throw new InvalidOperationException("Unhandled CacheMode"); + } + } } } diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 9efc7d3fa..01a9ef1eb 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -1806,6 +1806,7 @@ namespace Discord.WebSocket internal int GetAudioId() => _nextAudioId++; //IDiscordClient + /* async Task IDiscordClient.GetApplicationInfoAsync(RequestOptions options) => await GetApplicationInfoAsync().ConfigureAwait(false); @@ -1839,7 +1840,7 @@ namespace Discord.WebSocket Task> IDiscordClient.GetVoiceRegionsAsync(RequestOptions options) => Task.FromResult>(VoiceRegions); Task IDiscordClient.GetVoiceRegionAsync(string id, RequestOptions options) - => Task.FromResult(GetVoiceRegion(id)); + => Task.FromResult(GetVoiceRegion(id));*/ async Task IDiscordClient.StartAsync() => await StartAsync().ConfigureAwait(false);