diff --git a/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs b/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs index 59e6a6aca..ec04041e8 100644 --- a/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs +++ b/src/Discord.Net.Commands/Attributes/ModuleAttribute.cs @@ -6,13 +6,16 @@ namespace Discord.Commands public class ModuleAttribute : Attribute { public string Prefix { get; } + public bool AutoLoad { get; set; } public ModuleAttribute() { Prefix = null; + AutoLoad = true; } public ModuleAttribute(string prefix) { Prefix = prefix; + AutoLoad = true; } } } diff --git a/src/Discord.Net.Commands/CommandService.cs b/src/Discord.Net.Commands/CommandService.cs index f762ae366..46c5aaa39 100644 --- a/src/Discord.Net.Commands/CommandService.cs +++ b/src/Discord.Net.Commands/CommandService.cs @@ -164,7 +164,7 @@ namespace Discord.Commands return loadedModule; } - public async Task> LoadAssembly(Assembly assembly) + public async Task> LoadAssembly(Assembly assembly, IDependencyMap dependencyMap = null) { var modules = ImmutableArray.CreateBuilder(); await _moduleLock.WaitAsync().ConfigureAwait(false); @@ -174,9 +174,9 @@ namespace Discord.Commands { var typeInfo = type.GetTypeInfo(); var moduleAttr = typeInfo.GetCustomAttribute(); - if (moduleAttr != null) + if (moduleAttr != null && moduleAttr.AutoLoad) { - var moduleInstance = ReflectionUtils.CreateObject(typeInfo); + var moduleInstance = ReflectionUtils.CreateObject(typeInfo, this, dependencyMap); modules.Add(LoadInternal(moduleInstance, moduleAttr, typeInfo)); } } diff --git a/src/Discord.Net.Commands/Dependencies/DependencyMap.cs b/src/Discord.Net.Commands/Dependencies/DependencyMap.cs new file mode 100644 index 000000000..db4d20984 --- /dev/null +++ b/src/Discord.Net.Commands/Dependencies/DependencyMap.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Reflection; + +namespace Discord.Commands +{ + public class DependencyMap : IDependencyMap + { + private Dictionary map; + + public DependencyMap() + { + map = new Dictionary(); + } + + public object Get(Type t) + { + if (!map.ContainsKey(t)) + throw new KeyNotFoundException($"The dependency map does not contain \"{t.FullName}\""); + return map[t]; + } + + public T Get() where T : class + { + return Get(typeof(T)) as T; + } + + public void Add(T obj) + { + var t = typeof(T); + if (map.ContainsKey(t)) + throw new InvalidOperationException($"The dependency map already contains \"{t.FullName}\""); + map.Add(t, obj); + } + } +} diff --git a/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs b/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs new file mode 100644 index 000000000..859cff09b --- /dev/null +++ b/src/Discord.Net.Commands/Dependencies/IDependencyMap.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +namespace Discord.Commands +{ + public interface IDependencyMap + { + object Get(Type t); + T Get() where T : class; + void Add(T obj); + } +} diff --git a/src/Discord.Net.Commands/Module.cs b/src/Discord.Net.Commands/Module.cs index ea6e29c28..b884832bc 100644 --- a/src/Discord.Net.Commands/Module.cs +++ b/src/Discord.Net.Commands/Module.cs @@ -43,7 +43,7 @@ namespace Discord.Commands nextGroupPrefix = groupPrefix + groupAttrib.Prefix ?? type.Name; else nextGroupPrefix = groupPrefix; - SearchClass(ReflectionUtils.CreateObject(type), commands, type, nextGroupPrefix); + SearchClass(ReflectionUtils.CreateObject(type, Service), commands, type, nextGroupPrefix); } } } diff --git a/src/Discord.Net.Commands/ReflectionUtils.cs b/src/Discord.Net.Commands/ReflectionUtils.cs index 28672a06f..62c77ff64 100644 --- a/src/Discord.Net.Commands/ReflectionUtils.cs +++ b/src/Discord.Net.Commands/ReflectionUtils.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Reflection; @@ -6,18 +7,64 @@ namespace Discord.Commands { internal class ReflectionUtils { - internal static object CreateObject(TypeInfo typeInfo) + internal static object CreateObject(TypeInfo typeInfo, CommandService service, IDependencyMap map = null) { - var constructor = typeInfo.DeclaredConstructors.Where(x => x.GetParameters().Length == 0).FirstOrDefault(); + if (typeInfo.DeclaredConstructors.Count() > 1) + throw new InvalidOperationException($"Found too many constructors for \"{typeInfo.FullName}\""); + + var constructor = typeInfo.DeclaredConstructors.FirstOrDefault(); + if (constructor == null) - throw new InvalidOperationException($"Failed to find a valid constructor for \"{typeInfo.FullName}\""); + throw new InvalidOperationException($"Found no constructor for \"{typeInfo.FullName}\""); + + object[] arguments = null; + + ParameterInfo[] parameters = constructor.GetParameters(); + + // TODO: can this logic be made better/cleaner? + if (parameters.Length == 1) + { + if (parameters[0].ParameterType == typeof(IDependencyMap)) + { + if (map != null) + arguments = new object[] { map }; + else + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (an IDependencyMap is required)"); + } + } + else if (parameters.Length == 2) + { + if (parameters[0].ParameterType == typeof(CommandService) && parameters[1].ParameterType == typeof(IDependencyMap)) + if (map != null) + arguments = new object[] { service, map }; + else + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (an IDependencyMap is required)"); + } + + if (arguments == null) + { + try + { + // TODO: probably change this ternary into something sensible? + arguments = parameters.Select(x => x.ParameterType == typeof(CommandService) ? service : map.Get(x.ParameterType)).ToArray(); + } + catch (KeyNotFoundException ex) // tried to inject an invalid dependency + { + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (could not provide parameter)", ex); + } + catch (NullReferenceException ex) // tried to find a dependency + { + throw new InvalidOperationException($"Could not find a valid constructor for \"{typeInfo.FullName}\" (an IDependencyMap is required)", ex); + } + } + try { - return constructor.Invoke(null); + return constructor.Invoke(arguments); } catch (Exception ex) { - throw new InvalidOperationException($"Failed to create \"{typeInfo.FullName}\"", ex); + throw new InvalidOperationException($"Could not create \"{typeInfo.FullName}\"", ex); } } }