diff --git a/.gitignore b/.gitignore index d6c4cf780..d7bf0ef19 100644 --- a/.gitignore +++ b/.gitignore @@ -200,3 +200,4 @@ project.lock.json /test/Discord.Net.Tests/config.json /docs/_build *.pyc +/.editorconfig diff --git a/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs b/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs new file mode 100644 index 000000000..9d7ec8983 --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/PreconditionAttribute.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + [AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, AllowMultiple = true, Inherited = true)] + public abstract class PreconditionAttribute : Attribute + { + public abstract Task CheckPermissions(IMessage context, Command executingCommand, object moduleInstance); + } +} diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs new file mode 100644 index 000000000..da9391fad --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequireContextAttribute.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + [Flags] + public enum ContextType + { + Guild = 1, // 01 + DM = 2 // 10 + } + + + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] + public class RequireContextAttribute : PreconditionAttribute + { + public ContextType Context { get; set; } + + public RequireContextAttribute(ContextType context) + { + Context = context; + } + + public override Task CheckPermissions(IMessage context, Command executingCommand, object moduleInstance) + { + var validContext = false; + + if (Context.HasFlag(ContextType.Guild)) + validContext = validContext || context.Channel is IGuildChannel; + + if (Context.HasFlag(ContextType.DM)) + validContext = validContext || context.Channel is IDMChannel; + + if (validContext) + return Task.FromResult(PreconditionResult.FromSuccess()); + else + return Task.FromResult(PreconditionResult.FromError($"Invalid context for command; accepted contexts: {Context}")); + } + } +} diff --git a/src/Discord.Net.Commands/Attributes/Preconditions/RequirePermission.cs b/src/Discord.Net.Commands/Attributes/Preconditions/RequirePermission.cs new file mode 100644 index 000000000..a970685f5 --- /dev/null +++ b/src/Discord.Net.Commands/Attributes/Preconditions/RequirePermission.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Discord.Commands.Attributes.Preconditions +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] + public class RequirePermission : PreconditionAttribute + { + public GuildPermission? GuildPermission { get; set; } + public ChannelPermission? ChannelPermission { get; set; } + + public RequirePermission(GuildPermission permission) + { + GuildPermission = permission; + ChannelPermission = null; + } + + public RequirePermission(ChannelPermission permission) + { + ChannelPermission = permission; + GuildPermission = null; + } + + public override Task CheckPermissions(IMessage context, Command executingCommand, object moduleInstance) + { + if (!(context.Channel is IGuildChannel)) + return Task.FromResult(PreconditionResult.FromError("Command must be used in a guild channel")); + + var author = context.Author as IGuildUser; + + if (GuildPermission.HasValue) + { + var guildPerms = author.GuildPermissions.ToList(); + if (!guildPerms.Contains(GuildPermission.Value)) + return Task.FromResult(PreconditionResult.FromError($"User is missing guild permission {GuildPermission.Value}")); + } + + if (ChannelPermission.HasValue) + { + var channel = context.Channel as IGuildChannel; + var channelPerms = author.GetPermissions(channel).ToList(); + + if (!channelPerms.Contains(ChannelPermission.Value)) + return Task.FromResult(PreconditionResult.FromError($"User is missing channel permission {ChannelPermission.Value}")); + } + + return Task.FromResult(PreconditionResult.FromSuccess()); + } + } +} diff --git a/src/Discord.Net.Commands/Command.cs b/src/Discord.Net.Commands/Command.cs index 5729e4c81..2b0f34eb3 100644 --- a/src/Discord.Net.Commands/Command.cs +++ b/src/Discord.Net.Commands/Command.cs @@ -19,7 +19,8 @@ namespace Discord.Commands public string Text { get; } public Module Module { get; } public IReadOnlyList Parameters { get; } - + public IReadOnlyList Preconditions { get; } + internal Command(Module module, object instance, CommandAttribute attribute, MethodInfo methodInfo, string groupPrefix) { Module = module; @@ -37,9 +38,29 @@ namespace Discord.Commands Synopsis = synopsis.Text; Parameters = BuildParameters(methodInfo); + Preconditions = BuildPreconditions(methodInfo); _action = BuildAction(methodInfo); } + public async Task CheckPreconditions(IMessage context) + { + foreach (PreconditionAttribute precondition in Module.Preconditions) + { + var result = await precondition.CheckPermissions(context, this, Module.Instance).ConfigureAwait(false); + if (!result.IsSuccess) + return result; + } + + foreach (PreconditionAttribute precondition in Preconditions) + { + var result = await precondition.CheckPermissions(context, this, Module.Instance).ConfigureAwait(false); + if (!result.IsSuccess) + return result; + } + + return PreconditionResult.FromSuccess(); + } + public async Task Parse(IMessage msg, SearchResult searchResult) { if (!searchResult.IsSuccess) @@ -63,6 +84,11 @@ namespace Discord.Commands } } + private IReadOnlyList BuildPreconditions(MethodInfo methodInfo) + { + return methodInfo.GetCustomAttributes().ToImmutableArray(); + } + private IReadOnlyList BuildParameters(MethodInfo methodInfo) { var parameters = methodInfo.GetParameters(); @@ -115,7 +141,7 @@ namespace Discord.Commands { if (methodInfo.ReturnType != typeof(Task)) throw new InvalidOperationException("Commands must return a non-generic Task."); - + return (msg, args) => { object[] newArgs = new object[args.Count + 1]; diff --git a/src/Discord.Net.Commands/CommandError.cs b/src/Discord.Net.Commands/CommandError.cs index 135930dd9..31a84ea1a 100644 --- a/src/Discord.Net.Commands/CommandError.cs +++ b/src/Discord.Net.Commands/CommandError.cs @@ -16,5 +16,6 @@ //Execute Exception, + UnmetPrecondition } } diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 2ce7c5517..9446d5700 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -209,8 +209,18 @@ namespace Discord.Commands return searchResult; var commands = searchResult.Commands; + for (int i = commands.Count - 1; i >= 0; i--) { + var preconditionResult = await commands[i].CheckPreconditions(message); + if (!preconditionResult.IsSuccess) + { + if (commands.Count == 1) + return preconditionResult; + else + continue; + } + var parseResult = await commands[i].Parse(message, searchResult); if (!parseResult.IsSuccess) { diff --git a/src/Discord.Net.Commands/Module.cs b/src/Discord.Net.Commands/Module.cs index b884832bc..07feaeca2 100644 --- a/src/Discord.Net.Commands/Module.cs +++ b/src/Discord.Net.Commands/Module.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Collections.Immutable; using System.Diagnostics; using System.Reflection; @@ -12,6 +13,8 @@ namespace Discord.Commands public IEnumerable Commands { get; } internal object Instance { get; } + public IReadOnlyList Preconditions { get; } + internal Module(CommandService service, object instance, ModuleAttribute moduleAttr, TypeInfo typeInfo) { Service = service; @@ -21,6 +24,8 @@ namespace Discord.Commands List commands = new List(); SearchClass(instance, commands, typeInfo, moduleAttr.Prefix ?? ""); Commands = commands; + + Preconditions = BuildPreconditions(typeInfo); } private void SearchClass(object instance, List commands, TypeInfo typeInfo, string groupPrefix) @@ -48,6 +53,11 @@ namespace Discord.Commands } } + private IReadOnlyList BuildPreconditions(TypeInfo typeInfo) + { + return typeInfo.GetCustomAttributes().ToImmutableArray(); + } + public override string ToString() => Name; private string DebuggerDisplay => Name; } diff --git a/src/Discord.Net.Commands/Results/ExecuteResult.cs b/src/Discord.Net.Commands/Results/ExecuteResult.cs index a06e8dd99..60d47c7cb 100644 --- a/src/Discord.Net.Commands/Results/ExecuteResult.cs +++ b/src/Discord.Net.Commands/Results/ExecuteResult.cs @@ -28,6 +28,8 @@ namespace Discord.Commands => new ExecuteResult(ex, CommandError.Exception, ex.Message); internal static ExecuteResult FromError(ParseResult result) => new ExecuteResult(null, result.Error, result.ErrorReason); + internal static ExecuteResult FromError(PreconditionResult result) + => new ExecuteResult(null, result.Error, result.ErrorReason); public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; private string DebuggerDisplay => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; diff --git a/src/Discord.Net.Commands/Results/PreconditionResult.cs b/src/Discord.Net.Commands/Results/PreconditionResult.cs new file mode 100644 index 000000000..9d36ba23f --- /dev/null +++ b/src/Discord.Net.Commands/Results/PreconditionResult.cs @@ -0,0 +1,27 @@ +using System.Diagnostics; + +namespace Discord.Commands +{ + [DebuggerDisplay(@"{DebuggerDisplay,nq}")] + public struct PreconditionResult : IResult + { + public CommandError? Error { get; } + public string ErrorReason { get; } + + public bool IsSuccess => !Error.HasValue; + + private PreconditionResult(CommandError? error, string errorReason) + { + Error = error; + ErrorReason = errorReason; + } + + internal static PreconditionResult FromSuccess() + => new PreconditionResult(null, null); + internal static PreconditionResult FromError(string reason) + => new PreconditionResult(CommandError.UnmetPrecondition, reason); + + public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; + private string DebuggerDisplay => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; + } +}