@@ -13,7 +13,7 @@ namespace Discord.Commands | |||||
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services) | public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services) | ||||
{ | { | ||||
var results = new Dictionary<ulong, TypeReaderValue>(); | var results = new Dictionary<ulong, TypeReaderValue>(); | ||||
IReadOnlyCollection<IUser> channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way? | |||||
IAsyncEnumerable<IUser> channelUsers = context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten(); // it's better | |||||
IReadOnlyCollection<IGuildUser> guildUsers = ImmutableArray.Create<IGuildUser>(); | IReadOnlyCollection<IGuildUser> guildUsers = ImmutableArray.Create<IGuildUser>(); | ||||
ulong id; | ulong id; | ||||
@@ -45,7 +45,7 @@ namespace Discord.Commands | |||||
string username = input.Substring(0, index); | string username = input.Substring(0, index); | ||||
if (ushort.TryParse(input.Substring(index + 1), out ushort discriminator)) | if (ushort.TryParse(input.Substring(index + 1), out ushort discriminator)) | ||||
{ | { | ||||
var channelUser = channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator && | |||||
var channelUser = await channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator && | |||||
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)); | string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)); | ||||
AddResult(results, channelUser as T, channelUser?.Username == username ? 0.85f : 0.75f); | AddResult(results, channelUser as T, channelUser?.Username == username ? 0.85f : 0.75f); | ||||
@@ -57,8 +57,9 @@ namespace Discord.Commands | |||||
//By Username (0.5-0.6) | //By Username (0.5-0.6) | ||||
{ | { | ||||
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f); | |||||
await channelUsers | |||||
.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)) | |||||
.ForEachAsync(channelUser => AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f)); | |||||
foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) | foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))) | ||||
AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f); | AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f); | ||||
@@ -66,8 +67,9 @@ namespace Discord.Commands | |||||
//By Nickname (0.5-0.6) | //By Nickname (0.5-0.6) | ||||
{ | { | ||||
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase))) | |||||
AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f); | |||||
await channelUsers | |||||
.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase)) | |||||
.ForEachAsync(channelUser => AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f)); | |||||
foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) | foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase))) | ||||
AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f); | AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f); | ||||
@@ -1,14 +1,64 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Threading; | |||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
namespace Discord | namespace Discord | ||||
{ | { | ||||
public static class AsyncEnumerableExtensions | public static class AsyncEnumerableExtensions | ||||
{ | { | ||||
public static async Task<IEnumerable<T>> Flatten<T>(this IAsyncEnumerable<IReadOnlyCollection<T>> source) | |||||
/// <summary> | |||||
/// Flattens the specified pages into one <see cref="IEnumerable{T}"/> asynchronously | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="source"></param> | |||||
/// <returns></returns> | |||||
public static async Task<IEnumerable<T>> FlattenAsync<T>(this IAsyncEnumerable<IEnumerable<T>> source) | |||||
{ | { | ||||
return (await source.ToArray().ConfigureAwait(false)).SelectMany(x => x); | |||||
return await source.Flatten().ToArray().ConfigureAwait(false); | |||||
} | |||||
public static IAsyncEnumerable<T> Flatten<T>(this IAsyncEnumerable<IEnumerable<T>> source) | |||||
{ | |||||
return new PagedCollectionEnumerator<T>(source); | |||||
} | |||||
internal class PagedCollectionEnumerator<T> : IAsyncEnumerator<T>, IAsyncEnumerable<T> | |||||
{ | |||||
readonly IAsyncEnumerator<IEnumerable<T>> _source; | |||||
IEnumerator<T> _enumerator; | |||||
public IAsyncEnumerator<T> GetEnumerator() => this; | |||||
internal PagedCollectionEnumerator(IAsyncEnumerable<IEnumerable<T>> source) | |||||
{ | |||||
_source = source.GetEnumerator(); | |||||
} | |||||
public T Current => _enumerator.Current; | |||||
public void Dispose() | |||||
{ | |||||
_enumerator?.Dispose(); | |||||
_source.Dispose(); | |||||
} | |||||
public async Task<bool> MoveNext(CancellationToken cancellationToken) | |||||
{ | |||||
cancellationToken.ThrowIfCancellationRequested(); | |||||
if(!_enumerator?.MoveNext() ?? true) | |||||
{ | |||||
if (!await _source.MoveNext(cancellationToken).ConfigureAwait(false)) | |||||
return false; | |||||
_enumerator?.Dispose(); | |||||
_enumerator = _source.Current.GetEnumerator(); | |||||
return _enumerator.MoveNext(); | |||||
} | |||||
return true; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -79,7 +79,7 @@ namespace Discord.Rest | |||||
ulong? fromGuildId, int? limit, RequestOptions options) | ulong? fromGuildId, int? limit, RequestOptions options) | ||||
{ | { | ||||
return new PagedAsyncEnumerable<RestUserGuild>( | return new PagedAsyncEnumerable<RestUserGuild>( | ||||
DiscordConfig.MaxUsersPerBatch, | |||||
DiscordConfig.MaxGuildsPerBatch, | |||||
async (info, ct) => | async (info, ct) => | ||||
{ | { | ||||
var args = new GetGuildSummariesParams | var args = new GetGuildSummariesParams | ||||
@@ -106,7 +106,7 @@ namespace Discord.Rest | |||||
} | } | ||||
public static async Task<IReadOnlyCollection<RestGuild>> GetGuildsAsync(BaseDiscordClient client, RequestOptions options) | public static async Task<IReadOnlyCollection<RestGuild>> GetGuildsAsync(BaseDiscordClient client, RequestOptions options) | ||||
{ | { | ||||
var summaryModels = await GetGuildSummariesAsync(client, null, null, options).Flatten(); | |||||
var summaryModels = await GetGuildSummariesAsync(client, null, null, options).FlattenAsync().ConfigureAwait(false); | |||||
var guilds = ImmutableArray.CreateBuilder<RestGuild>(); | var guilds = ImmutableArray.CreateBuilder<RestGuild>(); | ||||
foreach (var summaryModel in summaryModels) | foreach (var summaryModel in summaryModels) | ||||
{ | { | ||||
@@ -413,7 +413,7 @@ namespace Discord.Rest | |||||
async Task<IReadOnlyCollection<IGuildUser>> IGuild.GetUsersAsync(CacheMode mode, RequestOptions options) | async Task<IReadOnlyCollection<IGuildUser>> IGuild.GetUsersAsync(CacheMode mode, RequestOptions options) | ||||
{ | { | ||||
if (mode == CacheMode.AllowDownload) | if (mode == CacheMode.AllowDownload) | ||||
return (await GetUsersAsync(options).Flatten().ConfigureAwait(false)).ToImmutableArray(); | |||||
return (await GetUsersAsync(options).FlattenAsync().ConfigureAwait(false)).ToImmutableArray(); | |||||
else | else | ||||
return ImmutableArray.Create<IGuildUser>(); | return ImmutableArray.Create<IGuildUser>(); | ||||
} | } | ||||