Preconditions for commandspull/193/merge
@@ -200,3 +200,4 @@ project.lock.json | |||
/test/Discord.Net.Tests/config.json | |||
/docs/_build | |||
*.pyc | |||
/.editorconfig |
@@ -0,0 +1,13 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Threading.Tasks; | |||
namespace Discord.Commands | |||
{ | |||
[AttributeUsage(AttributeTargets.Method | AttributeTargets.Class, AllowMultiple = true, Inherited = true)] | |||
public abstract class PreconditionAttribute : Attribute | |||
{ | |||
public abstract Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance); | |||
} | |||
} |
@@ -0,0 +1,42 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Threading.Tasks; | |||
namespace Discord.Commands | |||
{ | |||
[Flags] | |||
public enum ContextType | |||
{ | |||
Guild = 1, // 01 | |||
DM = 2 // 10 | |||
} | |||
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] | |||
public class RequireContextAttribute : PreconditionAttribute | |||
{ | |||
public ContextType Context { get; set; } | |||
public RequireContextAttribute(ContextType context) | |||
{ | |||
Context = context; | |||
} | |||
public override Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance) | |||
{ | |||
var validContext = false; | |||
if (Context.HasFlag(ContextType.Guild)) | |||
validContext = validContext || context.Channel is IGuildChannel; | |||
if (Context.HasFlag(ContextType.DM)) | |||
validContext = validContext || context.Channel is IDMChannel; | |||
if (validContext) | |||
return Task.FromResult(PreconditionResult.FromSuccess()); | |||
else | |||
return Task.FromResult(PreconditionResult.FromError($"Invalid context for command; accepted contexts: {Context}")); | |||
} | |||
} | |||
} |
@@ -0,0 +1,52 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Threading.Tasks; | |||
namespace Discord.Commands.Attributes.Preconditions | |||
{ | |||
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] | |||
public class RequirePermission : PreconditionAttribute | |||
{ | |||
public GuildPermission? GuildPermission { get; set; } | |||
public ChannelPermission? ChannelPermission { get; set; } | |||
public RequirePermission(GuildPermission permission) | |||
{ | |||
GuildPermission = permission; | |||
ChannelPermission = null; | |||
} | |||
public RequirePermission(ChannelPermission permission) | |||
{ | |||
ChannelPermission = permission; | |||
GuildPermission = null; | |||
} | |||
public override Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance) | |||
{ | |||
if (!(context.Channel is IGuildChannel)) | |||
return Task.FromResult(PreconditionResult.FromError("Command must be used in a guild channel")); | |||
var author = context.Author as IGuildUser; | |||
if (GuildPermission.HasValue) | |||
{ | |||
var guildPerms = author.GuildPermissions.ToList(); | |||
if (!guildPerms.Contains(GuildPermission.Value)) | |||
return Task.FromResult(PreconditionResult.FromError($"User is missing guild permission {GuildPermission.Value}")); | |||
} | |||
if (ChannelPermission.HasValue) | |||
{ | |||
var channel = context.Channel as IGuildChannel; | |||
var channelPerms = author.GetPermissions(channel).ToList(); | |||
if (!channelPerms.Contains(ChannelPermission.Value)) | |||
return Task.FromResult(PreconditionResult.FromError($"User is missing channel permission {ChannelPermission.Value}")); | |||
} | |||
return Task.FromResult(PreconditionResult.FromSuccess()); | |||
} | |||
} | |||
} |
@@ -19,7 +19,8 @@ namespace Discord.Commands | |||
public string Text { get; } | |||
public Module Module { get; } | |||
public IReadOnlyList<CommandParameter> Parameters { get; } | |||
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | |||
internal Command(Module module, object instance, CommandAttribute attribute, MethodInfo methodInfo, string groupPrefix) | |||
{ | |||
Module = module; | |||
@@ -37,9 +38,29 @@ namespace Discord.Commands | |||
Synopsis = synopsis.Text; | |||
Parameters = BuildParameters(methodInfo); | |||
Preconditions = BuildPreconditions(methodInfo); | |||
_action = BuildAction(methodInfo); | |||
} | |||
public async Task<PreconditionResult> CheckPreconditions(IMessage context) | |||
{ | |||
foreach (PreconditionAttribute precondition in Module.Preconditions) | |||
{ | |||
var result = await precondition.CheckPermissions(context, this, Module.Instance).ConfigureAwait(false); | |||
if (!result.IsSuccess) | |||
return result; | |||
} | |||
foreach (PreconditionAttribute precondition in Preconditions) | |||
{ | |||
var result = await precondition.CheckPermissions(context, this, Module.Instance).ConfigureAwait(false); | |||
if (!result.IsSuccess) | |||
return result; | |||
} | |||
return PreconditionResult.FromSuccess(); | |||
} | |||
public async Task<ParseResult> Parse(IMessage msg, SearchResult searchResult) | |||
{ | |||
if (!searchResult.IsSuccess) | |||
@@ -63,6 +84,11 @@ namespace Discord.Commands | |||
} | |||
} | |||
private IReadOnlyList<PreconditionAttribute> BuildPreconditions(MethodInfo methodInfo) | |||
{ | |||
return methodInfo.GetCustomAttributes<PreconditionAttribute>().ToImmutableArray(); | |||
} | |||
private IReadOnlyList<CommandParameter> BuildParameters(MethodInfo methodInfo) | |||
{ | |||
var parameters = methodInfo.GetParameters(); | |||
@@ -115,7 +141,7 @@ namespace Discord.Commands | |||
{ | |||
if (methodInfo.ReturnType != typeof(Task)) | |||
throw new InvalidOperationException("Commands must return a non-generic Task."); | |||
return (msg, args) => | |||
{ | |||
object[] newArgs = new object[args.Count + 1]; | |||
@@ -16,5 +16,6 @@ | |||
//Execute | |||
Exception, | |||
UnmetPrecondition | |||
} | |||
} |
@@ -209,8 +209,18 @@ namespace Discord.Commands | |||
return searchResult; | |||
var commands = searchResult.Commands; | |||
for (int i = commands.Count - 1; i >= 0; i--) | |||
{ | |||
var preconditionResult = await commands[i].CheckPreconditions(message); | |||
if (!preconditionResult.IsSuccess) | |||
{ | |||
if (commands.Count == 1) | |||
return preconditionResult; | |||
else | |||
continue; | |||
} | |||
var parseResult = await commands[i].Parse(message, searchResult); | |||
if (!parseResult.IsSuccess) | |||
{ | |||
@@ -1,4 +1,5 @@ | |||
using System.Collections.Generic; | |||
using System.Collections.Immutable; | |||
using System.Diagnostics; | |||
using System.Reflection; | |||
@@ -12,6 +13,8 @@ namespace Discord.Commands | |||
public IEnumerable<Command> Commands { get; } | |||
internal object Instance { get; } | |||
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | |||
internal Module(CommandService service, object instance, ModuleAttribute moduleAttr, TypeInfo typeInfo) | |||
{ | |||
Service = service; | |||
@@ -21,6 +24,8 @@ namespace Discord.Commands | |||
List<Command> commands = new List<Command>(); | |||
SearchClass(instance, commands, typeInfo, moduleAttr.Prefix ?? ""); | |||
Commands = commands; | |||
Preconditions = BuildPreconditions(typeInfo); | |||
} | |||
private void SearchClass(object instance, List<Command> commands, TypeInfo typeInfo, string groupPrefix) | |||
@@ -48,6 +53,11 @@ namespace Discord.Commands | |||
} | |||
} | |||
private IReadOnlyList<PreconditionAttribute> BuildPreconditions(TypeInfo typeInfo) | |||
{ | |||
return typeInfo.GetCustomAttributes<PreconditionAttribute>().ToImmutableArray(); | |||
} | |||
public override string ToString() => Name; | |||
private string DebuggerDisplay => Name; | |||
} | |||
@@ -28,6 +28,8 @@ namespace Discord.Commands | |||
=> new ExecuteResult(ex, CommandError.Exception, ex.Message); | |||
internal static ExecuteResult FromError(ParseResult result) | |||
=> new ExecuteResult(null, result.Error, result.ErrorReason); | |||
internal static ExecuteResult FromError(PreconditionResult result) | |||
=> new ExecuteResult(null, result.Error, result.ErrorReason); | |||
public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | |||
private string DebuggerDisplay => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | |||
@@ -0,0 +1,27 @@ | |||
using System.Diagnostics; | |||
namespace Discord.Commands | |||
{ | |||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | |||
public struct PreconditionResult : IResult | |||
{ | |||
public CommandError? Error { get; } | |||
public string ErrorReason { get; } | |||
public bool IsSuccess => !Error.HasValue; | |||
private PreconditionResult(CommandError? error, string errorReason) | |||
{ | |||
Error = error; | |||
ErrorReason = errorReason; | |||
} | |||
internal static PreconditionResult FromSuccess() | |||
=> new PreconditionResult(null, null); | |||
internal static PreconditionResult FromError(string reason) | |||
=> new PreconditionResult(CommandError.UnmetPrecondition, reason); | |||
public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | |||
private string DebuggerDisplay => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; | |||
} | |||
} |