diff --git a/src/Discord.Net.Commands/Command.cs b/src/Discord.Net.Commands/Command.cs index 9b89b2cc0..433a63073 100644 --- a/src/Discord.Net.Commands/Command.cs +++ b/src/Discord.Net.Commands/Command.cs @@ -75,19 +75,24 @@ namespace Discord.Commands continue; } + var typeInfo = type.GetTypeInfo(); + if (typeInfo.IsEnum) + Module.Service.AddTypeReader(type, new EnumTypeReader(type)); + var reader = Module.Service.GetTypeReader(type); if (reader == null) { - var typeInfo = type.GetTypeInfo(); if (typeInfo.IsEnum) + { type = Enum.GetUnderlyingType(type); - reader = Module.Service.GetTypeReader(type); - } + reader = Module.Service.GetTypeReader(type); + } - if (reader == null) - throw new InvalidOperationException($"{type.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + if (reader == null) + throw new InvalidOperationException($"{type.FullName} is not supported as a command parameter, are you missing a TypeReader?"); + } bool isUnparsed = parameter.GetCustomAttribute() != null; if (isUnparsed) diff --git a/src/Discord.Net.Commands/Readers/EnumTypeReader.cs b/src/Discord.Net.Commands/Readers/EnumTypeReader.cs new file mode 100644 index 000000000..fc75432bf --- /dev/null +++ b/src/Discord.Net.Commands/Readers/EnumTypeReader.cs @@ -0,0 +1,54 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + internal class EnumTypeReader : TypeReader + { + private readonly Dictionary stringValues; + private readonly Dictionary intValues; + private readonly Type enumType; + + public override Task Read(IMessage context, string input) + { + int inputAsInt; + object enumValue; + + if (int.TryParse(input, out inputAsInt)) + { + if (intValues.TryGetValue(inputAsInt, out enumValue)) + return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); + else + return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {enumType.Name}")); + } + else + { + if (stringValues.TryGetValue(input.ToLower(), out enumValue)) + return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); + else + return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {enumType.Name}")); + } + } + + public EnumTypeReader(Type type) + { + enumType = type; + + var stringValuesBuilder = new Dictionary(); + var intValuesBuilder = new Dictionary(); + + var values = Enum.GetValues(enumType); + + foreach (var v in values) + { + stringValuesBuilder.Add(v.ToString().ToLower(), v); + intValuesBuilder.Add((int)v, v); + } + + stringValues = stringValuesBuilder; + intValues = intValuesBuilder; + } + } +}