Move DI stuff around to support scoped DIpull/322/head
@@ -15,7 +15,7 @@ namespace Discord.Commands | |||
private static readonly MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); | |||
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>(); | |||
private readonly Func<CommandContext, object[], Task> _action; | |||
private readonly Func<CommandContext, object[], IDependencyMap, Task> _action; | |||
public MethodInfo Source { get; } | |||
public ModuleInfo Module { get; } | |||
@@ -125,7 +125,7 @@ namespace Discord.Commands | |||
return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false); | |||
} | |||
public Task<ExecuteResult> Execute(CommandContext context, ParseResult parseResult) | |||
public Task<ExecuteResult> Execute(CommandContext context, ParseResult parseResult, IDependencyMap map) | |||
{ | |||
if (!parseResult.IsSuccess) | |||
return Task.FromResult(ExecuteResult.FromError(parseResult)); | |||
@@ -146,9 +146,9 @@ namespace Discord.Commands | |||
paramList[i] = parseResult.ParamValues[i].Values.First().Value; | |||
} | |||
return Execute(context, argList, paramList); | |||
return Execute(context, argList, paramList, map); | |||
} | |||
public async Task<ExecuteResult> Execute(CommandContext context, IEnumerable<object> argList, IEnumerable<object> paramList) | |||
public async Task<ExecuteResult> Execute(CommandContext context, IEnumerable<object> argList, IEnumerable<object> paramList, IDependencyMap map) | |||
{ | |||
try | |||
{ | |||
@@ -156,13 +156,13 @@ namespace Discord.Commands | |||
switch (RunMode) | |||
{ | |||
case RunMode.Sync: //Always sync | |||
await _action(context, args).ConfigureAwait(false); | |||
await _action(context, args, map).ConfigureAwait(false); | |||
break; | |||
case RunMode.Mixed: //Sync until first await statement | |||
var t1 = _action(context, args); | |||
var t1 = _action(context, args, map); | |||
break; | |||
case RunMode.Async: //Always async | |||
var t2 = Task.Run(() => _action(context, args)); | |||
var t2 = Task.Run(() => _action(context, args, map)); | |||
break; | |||
} | |||
return ExecuteResult.FromSuccess(); | |||
@@ -219,14 +219,14 @@ namespace Discord.Commands | |||
} | |||
return paramBuilder.ToImmutable(); | |||
} | |||
private Func<CommandContext, object[], Task> BuildAction(MethodInfo methodInfo) | |||
private Func<CommandContext, object[], IDependencyMap, Task> BuildAction(MethodInfo methodInfo) | |||
{ | |||
if (methodInfo.ReturnType != typeof(Task)) | |||
throw new InvalidOperationException("Commands must return a non-generic Task."); | |||
return (context, args) => | |||
return (context, args, map) => | |||
{ | |||
var instance = Module.CreateInstance(); | |||
var instance = Module.CreateInstance(map); | |||
instance.Context = context; | |||
try | |||
{ | |||
@@ -65,7 +65,7 @@ namespace Discord.Commands | |||
} | |||
//Modules | |||
public async Task<ModuleInfo> AddModule<T>(IDependencyMap dependencyMap = null) | |||
public async Task<ModuleInfo> AddModule<T>() | |||
{ | |||
await _moduleLock.WaitAsync().ConfigureAwait(false); | |||
try | |||
@@ -80,14 +80,14 @@ namespace Discord.Commands | |||
if (_moduleDefs.ContainsKey(typeof(T))) | |||
throw new ArgumentException($"This module has already been added."); | |||
return AddModuleInternal(typeInfo, dependencyMap); | |||
return AddModuleInternal(typeInfo); | |||
} | |||
finally | |||
{ | |||
_moduleLock.Release(); | |||
} | |||
} | |||
public async Task<IEnumerable<ModuleInfo>> AddModules(Assembly assembly, IDependencyMap dependencyMap = null) | |||
public async Task<IEnumerable<ModuleInfo>> AddModules(Assembly assembly) | |||
{ | |||
var moduleDefs = ImmutableArray.CreateBuilder<ModuleInfo>(); | |||
await _moduleLock.WaitAsync().ConfigureAwait(false); | |||
@@ -102,7 +102,7 @@ namespace Discord.Commands | |||
{ | |||
var dontAutoLoad = typeInfo.GetCustomAttribute<DontAutoLoadAttribute>(); | |||
if (dontAutoLoad == null && !typeInfo.IsAbstract) | |||
moduleDefs.Add(AddModuleInternal(typeInfo, dependencyMap)); | |||
moduleDefs.Add(AddModuleInternal(typeInfo)); | |||
} | |||
} | |||
} | |||
@@ -113,9 +113,9 @@ namespace Discord.Commands | |||
_moduleLock.Release(); | |||
} | |||
} | |||
private ModuleInfo AddModuleInternal(TypeInfo typeInfo, IDependencyMap dependencyMap) | |||
private ModuleInfo AddModuleInternal(TypeInfo typeInfo) | |||
{ | |||
var moduleDef = new ModuleInfo(typeInfo, this, dependencyMap); | |||
var moduleDef = new ModuleInfo(typeInfo, this); | |||
_moduleDefs[typeInfo.AsType()] = moduleDef; | |||
foreach (var cmd in moduleDef.Commands) | |||
@@ -236,7 +236,7 @@ namespace Discord.Commands | |||
} | |||
} | |||
return await commands[i].Execute(context, parseResult).ConfigureAwait(false); | |||
return await commands[i].Execute(context, parseResult, dependencyMap).ConfigureAwait(false); | |||
} | |||
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); | |||
@@ -9,7 +9,7 @@ namespace Discord.Commands | |||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | |||
public class ModuleInfo | |||
{ | |||
internal readonly Func<ModuleBase> _builder; | |||
internal readonly Func<IDependencyMap, ModuleBase> _builder; | |||
public TypeInfo Source { get; } | |||
public CommandService Service { get; } | |||
@@ -20,12 +20,12 @@ namespace Discord.Commands | |||
public IEnumerable<CommandInfo> Commands { get; } | |||
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | |||
internal ModuleInfo(TypeInfo source, CommandService service, IDependencyMap dependencyMap) | |||
internal ModuleInfo(TypeInfo source, CommandService service) | |||
{ | |||
Source = source; | |||
Service = service; | |||
Name = source.Name; | |||
_builder = ReflectionUtils.CreateBuilder<ModuleBase>(source, Service, dependencyMap); | |||
_builder = ReflectionUtils.CreateBuilder<ModuleBase>(source, Service); | |||
var groupAttr = source.GetCustomAttribute<GroupAttribute>(); | |||
if (groupAttr != null) | |||
@@ -46,12 +46,12 @@ namespace Discord.Commands | |||
Remarks = remarksAttr.Text; | |||
List<CommandInfo> commands = new List<CommandInfo>(); | |||
SearchClass(source, commands, Prefix, dependencyMap); | |||
SearchClass(source, commands, Prefix); | |||
Commands = commands; | |||
Preconditions = Source.GetCustomAttributes<PreconditionAttribute>().ToImmutableArray(); | |||
} | |||
private void SearchClass(TypeInfo parentType, List<CommandInfo> commands, string groupPrefix, IDependencyMap dependencyMap) | |||
private void SearchClass(TypeInfo parentType, List<CommandInfo> commands, string groupPrefix) | |||
{ | |||
foreach (var method in parentType.DeclaredMethods) | |||
{ | |||
@@ -71,13 +71,13 @@ namespace Discord.Commands | |||
else | |||
nextGroupPrefix = groupAttrib.Prefix ?? type.Name.ToLowerInvariant(); | |||
SearchClass(type, commands, nextGroupPrefix, dependencyMap); | |||
SearchClass(type, commands, nextGroupPrefix); | |||
} | |||
} | |||
} | |||
internal ModuleBase CreateInstance() | |||
=> _builder(); | |||
internal ModuleBase CreateInstance(IDependencyMap map) | |||
=> _builder(map); | |||
public override string ToString() => Name; | |||
private string DebuggerDisplay => Name; | |||
@@ -7,9 +7,9 @@ namespace Discord.Commands | |||
internal class ReflectionUtils | |||
{ | |||
internal static T CreateObject<T>(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) | |||
=> CreateBuilder<T>(typeInfo, service, map)(); | |||
=> CreateBuilder<T>(typeInfo, service)(map); | |||
internal static Func<T> CreateBuilder<T>(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) | |||
internal static Func<IDependencyMap, T> CreateBuilder<T>(TypeInfo typeInfo, CommandService service) | |||
{ | |||
var constructors = typeInfo.DeclaredConstructors.Where(x => !x.IsStatic).ToArray(); | |||
if (constructors.Length == 0) | |||
@@ -19,26 +19,27 @@ namespace Discord.Commands | |||
var constructor = constructors[0]; | |||
ParameterInfo[] parameters = constructor.GetParameters(); | |||
object[] args = new object[parameters.Length]; | |||
for (int i = 0; i < parameters.Length; i++) | |||
return (map) => | |||
{ | |||
var parameter = parameters[i]; | |||
object arg; | |||
if (map == null || !map.TryGet(parameter.ParameterType, out arg)) | |||
object[] args = new object[parameters.Length]; | |||
for (int i = 0; i < parameters.Length; i++) | |||
{ | |||
if (parameter.ParameterType == typeof(CommandService)) | |||
arg = service; | |||
else if (parameter.ParameterType == typeof(IDependencyMap)) | |||
arg = map; | |||
else | |||
throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\", dependency \"{parameter.ParameterType.Name}\" was not found."); | |||
var parameter = parameters[i]; | |||
object arg; | |||
if (map == null || !map.TryGet(parameter.ParameterType, out arg)) | |||
{ | |||
if (parameter.ParameterType == typeof(CommandService)) | |||
arg = service; | |||
else if (parameter.ParameterType == typeof(IDependencyMap)) | |||
arg = map; | |||
else | |||
throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\", dependency \"{parameter.ParameterType.Name}\" was not found."); | |||
} | |||
args[i] = arg; | |||
} | |||
args[i] = arg; | |||
} | |||
return () => | |||
{ | |||
try | |||
{ | |||
return (T)constructor.Invoke(args); | |||