Move DI stuff around to support scoped DIpull/322/head
@@ -15,7 +15,7 @@ namespace Discord.Commands | |||||
private static readonly MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); | private static readonly MethodInfo _convertParamsMethod = typeof(CommandInfo).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList)); | ||||
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>(); | private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>(); | ||||
private readonly Func<CommandContext, object[], Task> _action; | |||||
private readonly Func<CommandContext, object[], IDependencyMap, Task> _action; | |||||
public MethodInfo Source { get; } | public MethodInfo Source { get; } | ||||
public ModuleInfo Module { get; } | public ModuleInfo Module { get; } | ||||
@@ -125,7 +125,7 @@ namespace Discord.Commands | |||||
return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false); | return await CommandParser.ParseArgs(this, context, input, 0).ConfigureAwait(false); | ||||
} | } | ||||
public Task<ExecuteResult> Execute(CommandContext context, ParseResult parseResult) | |||||
public Task<ExecuteResult> Execute(CommandContext context, ParseResult parseResult, IDependencyMap map) | |||||
{ | { | ||||
if (!parseResult.IsSuccess) | if (!parseResult.IsSuccess) | ||||
return Task.FromResult(ExecuteResult.FromError(parseResult)); | return Task.FromResult(ExecuteResult.FromError(parseResult)); | ||||
@@ -146,9 +146,9 @@ namespace Discord.Commands | |||||
paramList[i] = parseResult.ParamValues[i].Values.First().Value; | paramList[i] = parseResult.ParamValues[i].Values.First().Value; | ||||
} | } | ||||
return Execute(context, argList, paramList); | |||||
return Execute(context, argList, paramList, map); | |||||
} | } | ||||
public async Task<ExecuteResult> Execute(CommandContext context, IEnumerable<object> argList, IEnumerable<object> paramList) | |||||
public async Task<ExecuteResult> Execute(CommandContext context, IEnumerable<object> argList, IEnumerable<object> paramList, IDependencyMap map) | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
@@ -156,13 +156,13 @@ namespace Discord.Commands | |||||
switch (RunMode) | switch (RunMode) | ||||
{ | { | ||||
case RunMode.Sync: //Always sync | case RunMode.Sync: //Always sync | ||||
await _action(context, args).ConfigureAwait(false); | |||||
await _action(context, args, map).ConfigureAwait(false); | |||||
break; | break; | ||||
case RunMode.Mixed: //Sync until first await statement | case RunMode.Mixed: //Sync until first await statement | ||||
var t1 = _action(context, args); | |||||
var t1 = _action(context, args, map); | |||||
break; | break; | ||||
case RunMode.Async: //Always async | case RunMode.Async: //Always async | ||||
var t2 = Task.Run(() => _action(context, args)); | |||||
var t2 = Task.Run(() => _action(context, args, map)); | |||||
break; | break; | ||||
} | } | ||||
return ExecuteResult.FromSuccess(); | return ExecuteResult.FromSuccess(); | ||||
@@ -219,14 +219,14 @@ namespace Discord.Commands | |||||
} | } | ||||
return paramBuilder.ToImmutable(); | return paramBuilder.ToImmutable(); | ||||
} | } | ||||
private Func<CommandContext, object[], Task> BuildAction(MethodInfo methodInfo) | |||||
private Func<CommandContext, object[], IDependencyMap, Task> BuildAction(MethodInfo methodInfo) | |||||
{ | { | ||||
if (methodInfo.ReturnType != typeof(Task)) | if (methodInfo.ReturnType != typeof(Task)) | ||||
throw new InvalidOperationException("Commands must return a non-generic Task."); | throw new InvalidOperationException("Commands must return a non-generic Task."); | ||||
return (context, args) => | |||||
return (context, args, map) => | |||||
{ | { | ||||
var instance = Module.CreateInstance(); | |||||
var instance = Module.CreateInstance(map); | |||||
instance.Context = context; | instance.Context = context; | ||||
try | try | ||||
{ | { | ||||
@@ -65,7 +65,7 @@ namespace Discord.Commands | |||||
} | } | ||||
//Modules | //Modules | ||||
public async Task<ModuleInfo> AddModule<T>(IDependencyMap dependencyMap = null) | |||||
public async Task<ModuleInfo> AddModule<T>() | |||||
{ | { | ||||
await _moduleLock.WaitAsync().ConfigureAwait(false); | await _moduleLock.WaitAsync().ConfigureAwait(false); | ||||
try | try | ||||
@@ -80,14 +80,14 @@ namespace Discord.Commands | |||||
if (_moduleDefs.ContainsKey(typeof(T))) | if (_moduleDefs.ContainsKey(typeof(T))) | ||||
throw new ArgumentException($"This module has already been added."); | throw new ArgumentException($"This module has already been added."); | ||||
return AddModuleInternal(typeInfo, dependencyMap); | |||||
return AddModuleInternal(typeInfo); | |||||
} | } | ||||
finally | finally | ||||
{ | { | ||||
_moduleLock.Release(); | _moduleLock.Release(); | ||||
} | } | ||||
} | } | ||||
public async Task<IEnumerable<ModuleInfo>> AddModules(Assembly assembly, IDependencyMap dependencyMap = null) | |||||
public async Task<IEnumerable<ModuleInfo>> AddModules(Assembly assembly) | |||||
{ | { | ||||
var moduleDefs = ImmutableArray.CreateBuilder<ModuleInfo>(); | var moduleDefs = ImmutableArray.CreateBuilder<ModuleInfo>(); | ||||
await _moduleLock.WaitAsync().ConfigureAwait(false); | await _moduleLock.WaitAsync().ConfigureAwait(false); | ||||
@@ -102,7 +102,7 @@ namespace Discord.Commands | |||||
{ | { | ||||
var dontAutoLoad = typeInfo.GetCustomAttribute<DontAutoLoadAttribute>(); | var dontAutoLoad = typeInfo.GetCustomAttribute<DontAutoLoadAttribute>(); | ||||
if (dontAutoLoad == null && !typeInfo.IsAbstract) | if (dontAutoLoad == null && !typeInfo.IsAbstract) | ||||
moduleDefs.Add(AddModuleInternal(typeInfo, dependencyMap)); | |||||
moduleDefs.Add(AddModuleInternal(typeInfo)); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -113,9 +113,9 @@ namespace Discord.Commands | |||||
_moduleLock.Release(); | _moduleLock.Release(); | ||||
} | } | ||||
} | } | ||||
private ModuleInfo AddModuleInternal(TypeInfo typeInfo, IDependencyMap dependencyMap) | |||||
private ModuleInfo AddModuleInternal(TypeInfo typeInfo) | |||||
{ | { | ||||
var moduleDef = new ModuleInfo(typeInfo, this, dependencyMap); | |||||
var moduleDef = new ModuleInfo(typeInfo, this); | |||||
_moduleDefs[typeInfo.AsType()] = moduleDef; | _moduleDefs[typeInfo.AsType()] = moduleDef; | ||||
foreach (var cmd in moduleDef.Commands) | foreach (var cmd in moduleDef.Commands) | ||||
@@ -236,7 +236,7 @@ namespace Discord.Commands | |||||
} | } | ||||
} | } | ||||
return await commands[i].Execute(context, parseResult).ConfigureAwait(false); | |||||
return await commands[i].Execute(context, parseResult, dependencyMap).ConfigureAwait(false); | |||||
} | } | ||||
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); | return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); | ||||
@@ -9,7 +9,7 @@ namespace Discord.Commands | |||||
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] | [DebuggerDisplay(@"{DebuggerDisplay,nq}")] | ||||
public class ModuleInfo | public class ModuleInfo | ||||
{ | { | ||||
internal readonly Func<ModuleBase> _builder; | |||||
internal readonly Func<IDependencyMap, ModuleBase> _builder; | |||||
public TypeInfo Source { get; } | public TypeInfo Source { get; } | ||||
public CommandService Service { get; } | public CommandService Service { get; } | ||||
@@ -20,12 +20,12 @@ namespace Discord.Commands | |||||
public IEnumerable<CommandInfo> Commands { get; } | public IEnumerable<CommandInfo> Commands { get; } | ||||
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | public IReadOnlyList<PreconditionAttribute> Preconditions { get; } | ||||
internal ModuleInfo(TypeInfo source, CommandService service, IDependencyMap dependencyMap) | |||||
internal ModuleInfo(TypeInfo source, CommandService service) | |||||
{ | { | ||||
Source = source; | Source = source; | ||||
Service = service; | Service = service; | ||||
Name = source.Name; | Name = source.Name; | ||||
_builder = ReflectionUtils.CreateBuilder<ModuleBase>(source, Service, dependencyMap); | |||||
_builder = ReflectionUtils.CreateBuilder<ModuleBase>(source, Service); | |||||
var groupAttr = source.GetCustomAttribute<GroupAttribute>(); | var groupAttr = source.GetCustomAttribute<GroupAttribute>(); | ||||
if (groupAttr != null) | if (groupAttr != null) | ||||
@@ -46,12 +46,12 @@ namespace Discord.Commands | |||||
Remarks = remarksAttr.Text; | Remarks = remarksAttr.Text; | ||||
List<CommandInfo> commands = new List<CommandInfo>(); | List<CommandInfo> commands = new List<CommandInfo>(); | ||||
SearchClass(source, commands, Prefix, dependencyMap); | |||||
SearchClass(source, commands, Prefix); | |||||
Commands = commands; | Commands = commands; | ||||
Preconditions = Source.GetCustomAttributes<PreconditionAttribute>().ToImmutableArray(); | Preconditions = Source.GetCustomAttributes<PreconditionAttribute>().ToImmutableArray(); | ||||
} | } | ||||
private void SearchClass(TypeInfo parentType, List<CommandInfo> commands, string groupPrefix, IDependencyMap dependencyMap) | |||||
private void SearchClass(TypeInfo parentType, List<CommandInfo> commands, string groupPrefix) | |||||
{ | { | ||||
foreach (var method in parentType.DeclaredMethods) | foreach (var method in parentType.DeclaredMethods) | ||||
{ | { | ||||
@@ -71,13 +71,13 @@ namespace Discord.Commands | |||||
else | else | ||||
nextGroupPrefix = groupAttrib.Prefix ?? type.Name.ToLowerInvariant(); | nextGroupPrefix = groupAttrib.Prefix ?? type.Name.ToLowerInvariant(); | ||||
SearchClass(type, commands, nextGroupPrefix, dependencyMap); | |||||
SearchClass(type, commands, nextGroupPrefix); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
internal ModuleBase CreateInstance() | |||||
=> _builder(); | |||||
internal ModuleBase CreateInstance(IDependencyMap map) | |||||
=> _builder(map); | |||||
public override string ToString() => Name; | public override string ToString() => Name; | ||||
private string DebuggerDisplay => Name; | private string DebuggerDisplay => Name; | ||||
@@ -7,9 +7,9 @@ namespace Discord.Commands | |||||
internal class ReflectionUtils | internal class ReflectionUtils | ||||
{ | { | ||||
internal static T CreateObject<T>(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) | internal static T CreateObject<T>(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) | ||||
=> CreateBuilder<T>(typeInfo, service, map)(); | |||||
=> CreateBuilder<T>(typeInfo, service)(map); | |||||
internal static Func<T> CreateBuilder<T>(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) | |||||
internal static Func<IDependencyMap, T> CreateBuilder<T>(TypeInfo typeInfo, CommandService service) | |||||
{ | { | ||||
var constructors = typeInfo.DeclaredConstructors.Where(x => !x.IsStatic).ToArray(); | var constructors = typeInfo.DeclaredConstructors.Where(x => !x.IsStatic).ToArray(); | ||||
if (constructors.Length == 0) | if (constructors.Length == 0) | ||||
@@ -19,26 +19,27 @@ namespace Discord.Commands | |||||
var constructor = constructors[0]; | var constructor = constructors[0]; | ||||
ParameterInfo[] parameters = constructor.GetParameters(); | ParameterInfo[] parameters = constructor.GetParameters(); | ||||
object[] args = new object[parameters.Length]; | |||||
for (int i = 0; i < parameters.Length; i++) | |||||
return (map) => | |||||
{ | { | ||||
var parameter = parameters[i]; | |||||
object arg; | |||||
if (map == null || !map.TryGet(parameter.ParameterType, out arg)) | |||||
object[] args = new object[parameters.Length]; | |||||
for (int i = 0; i < parameters.Length; i++) | |||||
{ | { | ||||
if (parameter.ParameterType == typeof(CommandService)) | |||||
arg = service; | |||||
else if (parameter.ParameterType == typeof(IDependencyMap)) | |||||
arg = map; | |||||
else | |||||
throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\", dependency \"{parameter.ParameterType.Name}\" was not found."); | |||||
var parameter = parameters[i]; | |||||
object arg; | |||||
if (map == null || !map.TryGet(parameter.ParameterType, out arg)) | |||||
{ | |||||
if (parameter.ParameterType == typeof(CommandService)) | |||||
arg = service; | |||||
else if (parameter.ParameterType == typeof(IDependencyMap)) | |||||
arg = map; | |||||
else | |||||
throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\", dependency \"{parameter.ParameterType.Name}\" was not found."); | |||||
} | |||||
args[i] = arg; | |||||
} | } | ||||
args[i] = arg; | |||||
} | |||||
return () => | |||||
{ | |||||
try | try | ||||
{ | { | ||||
return (T)constructor.Invoke(args); | return (T)constructor.Invoke(args); | ||||