diff --git a/src/Discord.Net.Commands/CommandInfo.cs b/src/Discord.Net.Commands/CommandInfo.cs index 75107a80c..770dfdb6b 100644 --- a/src/Discord.Net.Commands/CommandInfo.cs +++ b/src/Discord.Net.Commands/CommandInfo.cs @@ -15,7 +15,7 @@ namespace Discord.Commands private static readonly MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); private static readonly ConcurrentDictionary, object>> _arrayConverters = new ConcurrentDictionary, object>>(); - private readonly Func _action; + private readonly Func _action; public MethodInfo Source { get; } public ModuleInfo Module { get; } @@ -125,7 +125,7 @@ namespace Discord.Commands return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false); } - public Task Execute(CommandContext context, ParseResult parseResult) + public Task Execute(CommandContext context, ParseResult parseResult, IDependencyMap map) { if (!parseResult.IsSuccess) return Task.FromResult(ExecuteResult.FromError(parseResult)); @@ -146,9 +146,9 @@ namespace Discord.Commands paramList[i] = parseResult.ParamValues[i].Values.First().Value; } - return Execute(context, argList, paramList); + return Execute(context, argList, paramList, map); } - public async Task Execute(CommandContext context, IEnumerable argList, IEnumerable paramList) + public async Task Execute(CommandContext context, IEnumerable argList, IEnumerable paramList, IDependencyMap map) { try { @@ -156,13 +156,13 @@ namespace Discord.Commands switch (RunMode) { case RunMode.Sync: //Always sync - await _action(context, args).ConfigureAwait(false); + await _action(context, args, map).ConfigureAwait(false); break; case RunMode.Mixed: //Sync until first await statement - var t1 = _action(context, args); + var t1 = _action(context, args, map); break; case RunMode.Async: //Always async - var t2 = Task.Run(() => _action(context, args)); + var t2 = Task.Run(() => _action(context, args, map)); break; } return ExecuteResult.FromSuccess(); @@ -219,14 +219,14 @@ namespace Discord.Commands } return paramBuilder.ToImmutable(); } - private Func BuildAction(MethodInfo methodInfo) + private Func BuildAction(MethodInfo methodInfo) { if (methodInfo.ReturnType != typeof(Task)) throw new InvalidOperationException("Commands must return a non-generic Task."); - return (context, args) => + return (context, args, map) => { - var instance = Module.CreateInstance(); + var instance = Module.CreateInstance(map); instance.Context = context; try { diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index 6fe5667ea..c55f1541d 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -65,7 +65,7 @@ namespace Discord.Commands } //Modules - public async Task AddModule(IDependencyMap dependencyMap = null) + public async Task AddModule() { await _moduleLock.WaitAsync().ConfigureAwait(false); try @@ -80,14 +80,14 @@ namespace Discord.Commands if (_moduleDefs.ContainsKey(typeof(T))) throw new ArgumentException($"This module has already been added."); - return AddModuleInternal(typeInfo, dependencyMap); + return AddModuleInternal(typeInfo); } finally { _moduleLock.Release(); } } - public async Task> AddModules(Assembly assembly, IDependencyMap dependencyMap = null) + public async Task> AddModules(Assembly assembly) { var moduleDefs = ImmutableArray.CreateBuilder(); await _moduleLock.WaitAsync().ConfigureAwait(false); @@ -102,7 +102,7 @@ namespace Discord.Commands { var dontAutoLoad = typeInfo.GetCustomAttribute(); if (dontAutoLoad == null && !typeInfo.IsAbstract) - moduleDefs.Add(AddModuleInternal(typeInfo, dependencyMap)); + moduleDefs.Add(AddModuleInternal(typeInfo)); } } } @@ -113,9 +113,9 @@ namespace Discord.Commands _moduleLock.Release(); } } - private ModuleInfo AddModuleInternal(TypeInfo typeInfo, IDependencyMap dependencyMap) + private ModuleInfo AddModuleInternal(TypeInfo typeInfo) { - var moduleDef = new ModuleInfo(typeInfo, this, dependencyMap); + var moduleDef = new ModuleInfo(typeInfo, this); _moduleDefs[typeInfo.AsType()] = moduleDef; foreach (var cmd in moduleDef.Commands) @@ -236,7 +236,7 @@ namespace Discord.Commands } } - return await commands[i].Execute(context, parseResult).ConfigureAwait(false); + return await commands[i].Execute(context, parseResult, dependencyMap).ConfigureAwait(false); } return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); diff --git a/src/Discord.Net.Commands/ModuleInfo.cs b/src/Discord.Net.Commands/ModuleInfo.cs index c061e3de4..b7471edb5 100644 --- a/src/Discord.Net.Commands/ModuleInfo.cs +++ b/src/Discord.Net.Commands/ModuleInfo.cs @@ -9,7 +9,7 @@ namespace Discord.Commands [DebuggerDisplay(@"{DebuggerDisplay,nq}")] public class ModuleInfo { - internal readonly Func _builder; + internal readonly Func _builder; public TypeInfo Source { get; } public CommandService Service { get; } @@ -20,12 +20,12 @@ namespace Discord.Commands public IEnumerable Commands { get; } public IReadOnlyList Preconditions { get; } - internal ModuleInfo(TypeInfo source, CommandService service, IDependencyMap dependencyMap) + internal ModuleInfo(TypeInfo source, CommandService service) { Source = source; Service = service; Name = source.Name; - _builder = ReflectionUtils.CreateBuilder(source, Service, dependencyMap); + _builder = ReflectionUtils.CreateBuilder(source, Service); var groupAttr = source.GetCustomAttribute(); if (groupAttr != null) @@ -46,12 +46,12 @@ namespace Discord.Commands Remarks = remarksAttr.Text; List commands = new List(); - SearchClass(source, commands, Prefix, dependencyMap); + SearchClass(source, commands, Prefix); Commands = commands; Preconditions = Source.GetCustomAttributes().ToImmutableArray(); } - private void SearchClass(TypeInfo parentType, List commands, string groupPrefix, IDependencyMap dependencyMap) + private void SearchClass(TypeInfo parentType, List commands, string groupPrefix) { foreach (var method in parentType.DeclaredMethods) { @@ -71,13 +71,13 @@ namespace Discord.Commands else nextGroupPrefix = groupAttrib.Prefix ?? type.Name.ToLowerInvariant(); - SearchClass(type, commands, nextGroupPrefix, dependencyMap); + SearchClass(type, commands, nextGroupPrefix); } } } - internal ModuleBase CreateInstance() - => _builder(); + internal ModuleBase CreateInstance(IDependencyMap map) + => _builder(map); public override string ToString() => Name; private string DebuggerDisplay => Name; diff --git a/src/Discord.Net.Commands/ReflectionUtils.cs b/src/Discord.Net.Commands/ReflectionUtils.cs index e84a037ef..052e5fe98 100644 --- a/src/Discord.Net.Commands/ReflectionUtils.cs +++ b/src/Discord.Net.Commands/ReflectionUtils.cs @@ -7,9 +7,9 @@ namespace Discord.Commands internal class ReflectionUtils { internal static T CreateObject(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) - => CreateBuilder(typeInfo, service, map)(); + => CreateBuilder(typeInfo, service)(map); - internal static Func CreateBuilder(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) + internal static Func CreateBuilder(TypeInfo typeInfo, CommandService service) { var constructors = typeInfo.DeclaredConstructors.Where(x => !x.IsStatic).ToArray(); if (constructors.Length == 0) @@ -19,26 +19,27 @@ namespace Discord.Commands var constructor = constructors[0]; ParameterInfo[] parameters = constructor.GetParameters(); - object[] args = new object[parameters.Length]; - for (int i = 0; i < parameters.Length; i++) + return (map) => { - var parameter = parameters[i]; - object arg; - if (map == null || !map.TryGet(parameter.ParameterType, out arg)) + object[] args = new object[parameters.Length]; + + for (int i = 0; i < parameters.Length; i++) { - if (parameter.ParameterType == typeof(CommandService)) - arg = service; - else if (parameter.ParameterType == typeof(IDependencyMap)) - arg = map; - else - throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\", dependency \"{parameter.ParameterType.Name}\" was not found."); + var parameter = parameters[i]; + object arg; + if (map == null || !map.TryGet(parameter.ParameterType, out arg)) + { + if (parameter.ParameterType == typeof(CommandService)) + arg = service; + else if (parameter.ParameterType == typeof(IDependencyMap)) + arg = map; + else + throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\", dependency \"{parameter.ParameterType.Name}\" was not found."); + } + args[i] = arg; } - args[i] = arg; - } - return () => - { try { return (T)constructor.Invoke(args);