using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Discord.Commands.Builders;
using Discord.Logging;

namespace Discord.Commands
{
    /// <summary>
    /// Holds the registered command modules and type readers, and searches, parses and
    /// executes commands against them.
    /// </summary>
    public class CommandService
    {
        /// <summary> Raised for every message produced by this service's log manager. </summary>
        public event Func<LogMessage, Task> Log { add { _logEvent.Add(value); } remove { _logEvent.Remove(value); } }
        internal readonly AsyncEvent<Func<LogMessage, Task>> _logEvent = new AsyncEvent<Func<LogMessage, Task>>();

        // NOTE(review): nothing in this class invokes _commandExecutedEvent; presumably it is
        // raised by the command execution code elsewhere - confirm.
        public event Func<CommandInfo, ICommandContext, IResult, Task> CommandExecuted { add { _commandExecutedEvent.Add(value); } remove { _commandExecutedEvent.Remove(value); } }
        internal readonly AsyncEvent<Func<CommandInfo, ICommandContext, IResult, Task>> _commandExecutedEvent = new AsyncEvent<Func<CommandInfo, ICommandContext, IResult, Task>>();

        // Serializes every add/remove of modules.
        private readonly SemaphoreSlim _moduleLock;
        // Modules that were built from a CLR type, keyed by that type.
        private readonly ConcurrentDictionary<Type, ModuleInfo> _typedModuleDefs;
        // Custom readers: object type -> (reader type -> reader instance).
        private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders;
        private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders;
        // (entity interface, open generic reader type) pairs used by GetDefaultTypeReader.
        private readonly ImmutableList<Tuple<Type, Type>> _entityTypeReaders; //TODO: Candidate for C#7 Tuple
        // Every loaded module, including submodules.
        private readonly HashSet<ModuleInfo> _moduleDefs;
        private readonly CommandMap _map;

        internal readonly bool _caseSensitive, _throwOnError, _ignoreExtraArgs;
        internal readonly char _separatorChar;
        internal readonly RunMode _defaultRunMode;
        internal readonly Logger _cmdLogger;
        internal readonly LogManager _logManager;

        /// <summary> All currently loaded modules (submodules included). </summary>
        public IEnumerable<ModuleInfo> Modules => _moduleDefs.Select(x => x);

        /// <summary> All commands of all currently loaded modules. </summary>
        public IEnumerable<CommandInfo> Commands => _moduleDefs.SelectMany(x => x.Commands);

        /// <summary> A lookup over every custom type reader that has been added. </summary>
        // NOTE(review): the lookup key is the key of the inner dictionary, which AddTypeReader
        // fills with reader.GetType() - i.e. the reader's own type, not the object type it
        // reads. Confirm that this is what callers expect.
        public ILookup<Type, TypeReader> TypeReaders => _typeReaders.SelectMany(x => x.Value.Select(y => new { y.Key, y.Value })).ToLookup(x => x.Key, x => x.Value);

        /// <summary> Creates a service using a default <see cref="CommandServiceConfig"/>. </summary>
        public CommandService() : this(new CommandServiceConfig()) { }

        /// <summary> Creates a service using the supplied configuration. </summary>
        /// <param name="config">The configuration to copy settings from.</param>
        /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The configured default run mode is <see cref="RunMode.Default"/>.</exception>
        public CommandService(CommandServiceConfig config)
        {
            if (config == null)
                throw new ArgumentNullException(nameof(config));

            _caseSensitive = config.CaseSensitiveCommands;
            _throwOnError = config.ThrowOnError;
            _ignoreExtraArgs = config.IgnoreExtraArgs;
            _separatorChar = config.SeparatorChar;
            _defaultRunMode = config.DefaultRunMode;
            if (_defaultRunMode == RunMode.Default)
                throw new InvalidOperationException("The default run mode cannot be set to Default.");

            // Forward everything the log manager emits to subscribers of the Log event.
            _logManager = new LogManager(config.LogLevel);
            _logManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false);
            _cmdLogger = _logManager.CreateLogger("Command");

            _moduleLock = new SemaphoreSlim(1, 1);
            _typedModuleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
            _moduleDefs = new HashSet<ModuleInfo>();
            _map = new CommandMap(this);
            _typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>();

            // Register a default reader for every supported primitive and for its Nullable<> form.
            _defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>();
            foreach (var type in PrimitiveParsers.SupportedTypes)
            {
                _defaultTypeReaders[type] = PrimitiveTypeReader.Create(type);
                _defaultTypeReaders[typeof(Nullable<>).MakeGenericType(type)] = NullableTypeReader.Create(type, _defaultTypeReaders[type]);
            }

            // Strings are passed through unchanged.
            _defaultTypeReaders[typeof(string)] =
                new PrimitiveTypeReader<string>((string x, out string y) => { y = x; return true; }, 0);

            var entityTypeReaders = ImmutableList.CreateBuilder<Tuple<Type, Type>>();
            entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IMessage), typeof(MessageTypeReader<>)));
            entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IChannel), typeof(ChannelTypeReader<>)));
            entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IRole), typeof(RoleTypeReader<>)));
            entityTypeReaders.Add(new Tuple<Type, Type>(typeof(IUser), typeof(UserTypeReader<>)));
            _entityTypeReaders = entityTypeReaders.ToImmutable();
        }

        //Modules

        /// <summary>
        /// Builds a module with a <see cref="ModuleBuilder"/> and loads it.
        /// </summary>
        /// <param name="primaryAlias">The primary alias handed to the module builder.</param>
        /// <param name="buildFunc">A callback that configures the module builder.</param>
        /// <returns>The built module.</returns>
        public async Task<ModuleInfo> CreateModuleAsync(string primaryAlias, Action<ModuleBuilder> buildFunc)
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                var builder = new ModuleBuilder(this, null, primaryAlias);
                buildFunc(builder);

                var module = builder.Build(this, null);

                return LoadModuleInternal(module);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>
        /// Add a command module from a type
        /// </summary>
        /// <typeparam name="T">The type of module</typeparam>
        /// <param name="services">An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null</param>
        /// <returns>A built module</returns>
        public Task<ModuleInfo> AddModuleAsync<T>(IServiceProvider services) => AddModuleAsync(typeof(T), services);

        /// <summary>
        /// Add a command module from a type
        /// </summary>
        /// <param name="type">The type of module</param>
        /// <param name="services">An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null</param>
        /// <returns>A built module</returns>
        /// <exception cref="ArgumentException">A module for <paramref name="type"/> has already been added.</exception>
        /// <exception cref="InvalidOperationException">No module could be built from <paramref name="type"/>.</exception>
        public async Task<ModuleInfo> AddModuleAsync(Type type, IServiceProvider services)
        {
            services = services ?? EmptyServiceProvider.Instance;

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                var typeInfo = type.GetTypeInfo();

                if (_typedModuleDefs.ContainsKey(type))
                    throw new ArgumentException("This module has already been added.");

                var module = (await ModuleClassBuilder.BuildAsync(this, services, typeInfo).ConfigureAwait(false)).FirstOrDefault();

                // FirstOrDefault on an empty result yields a pair whose Value is null.
                if (module.Value == default(ModuleInfo))
                    throw new InvalidOperationException($"Could not build the module {type.FullName}, did you pass an invalid type?");

                _typedModuleDefs[module.Key] = module.Value;

                return LoadModuleInternal(module.Value);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>
        /// Add command modules from an assembly
        /// </summary>
        /// <param name="assembly">The assembly containing command modules</param>
        /// <param name="services">An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null</param>
        /// <returns>A collection of built modules</returns>
        public async Task<IEnumerable<ModuleInfo>> AddModulesAsync(Assembly assembly, IServiceProvider services)
        {
            services = services ?? EmptyServiceProvider.Instance;

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                var types = await ModuleClassBuilder.SearchAsync(assembly, this).ConfigureAwait(false);
                var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this, services).ConfigureAwait(false);

                foreach (var info in moduleDefs)
                {
                    _typedModuleDefs[info.Key] = info.Value;
                    LoadModuleInternal(info.Value);
                }

                return moduleDefs.Select(x => x.Value).ToImmutableArray();
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Registers the module, its commands and (recursively) its submodules.
        // Callers are expected to hold _moduleLock.
        private ModuleInfo LoadModuleInternal(ModuleInfo module)
        {
            _moduleDefs.Add(module);

            foreach (var command in module.Commands)
                _map.AddCommand(command);

            foreach (var submodule in module.Submodules)
                LoadModuleInternal(submodule);

            return module;
        }

        /// <summary>
        /// Removes a previously loaded module, together with its commands and submodules.
        /// </summary>
        /// <param name="module">The module to remove.</param>
        /// <returns><c>true</c> if the module was loaded and has been removed; otherwise <c>false</c>.</returns>
        public async Task<bool> RemoveModuleAsync(ModuleInfo module)
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                // Also drop the type registration, if there is one. Otherwise a later
                // AddModuleAsync for the same type is rejected as "already added" even
                // though the module is no longer loaded.
                foreach (var entry in _typedModuleDefs)
                {
                    if (Equals(entry.Value, module))
                    {
                        _typedModuleDefs.TryRemove(entry.Key, out _);
                        break;
                    }
                }

                return RemoveModuleInternal(module);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>
        /// Removes the module that was added for <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">The type the module was added from.</typeparam>
        /// <returns><c>true</c> if a module was removed; otherwise <c>false</c>.</returns>
        public Task<bool> RemoveModuleAsync<T>() => RemoveModuleAsync(typeof(T));

        /// <summary>
        /// Removes the module that was added for <paramref name="type"/>.
        /// </summary>
        /// <param name="type">The type the module was added from.</param>
        /// <returns><c>true</c> if a module was removed; otherwise <c>false</c>.</returns>
        public async Task<bool> RemoveModuleAsync(Type type)
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                if (!_typedModuleDefs.TryRemove(type, out var module))
                    return false;

                return RemoveModuleInternal(module);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Unregisters the module, its commands and (recursively) its submodules.
        // Returns false without touching anything when the module is not loaded.
        private bool RemoveModuleInternal(ModuleInfo module)
        {
            if (!_moduleDefs.Remove(module))
                return false;

            foreach (var cmd in module.Commands)
                _map.RemoveCommand(cmd);

            foreach (var submodule in module.Submodules)
            {
                RemoveModuleInternal(submodule);
            }

            return true;
        }

        //Type Readers

        /// <summary>
        /// Adds a custom <see cref="TypeReader"/> to this <see cref="CommandService"/> for the supplied object type.
        /// If <typeparamref name="T"/> is a value type, a nullable reader for it will also be added.
        /// If a default <see cref="TypeReader"/> exists for <typeparamref name="T"/>, a warning will be logged and the default will be replaced.
        /// </summary>
        /// <typeparam name="T">The object type to be read by the <see cref="TypeReader"/>.</typeparam>
        /// <param name="reader">An instance of the <see cref="TypeReader"/> to be added.</param>
        public void AddTypeReader<T>(TypeReader reader)
            => AddTypeReader(typeof(T), reader);

        /// <summary>
        /// Adds a custom <see cref="TypeReader"/> to this <see cref="CommandService"/> for the supplied object type.
        /// If <paramref name="type"/> is a value type, a nullable reader for the value type will also be added.
        /// If a default <see cref="TypeReader"/> exists for <paramref name="type"/>, a warning will be logged and the default will be replaced.
        /// </summary>
        /// <param name="type">A <see cref="Type"/> instance for the type to be read.</param>
        /// <param name="reader">An instance of the <see cref="TypeReader"/> to be added.</param>
        public void AddTypeReader(Type type, TypeReader reader)
        {
            if (_defaultTypeReaders.ContainsKey(type))
            {
                // Fire-and-forget: the warning is informational and must not block registration.
                _ = _cmdLogger.WarningAsync($"The default TypeReader for {type.FullName} was replaced by {reader.GetType().FullName}. " +
                    "To suppress this message, use AddTypeReader<T>(reader, true).");
            }

            AddTypeReader(type, reader, true);
        }

        /// <summary>
        /// Adds a custom <see cref="TypeReader"/> to this <see cref="CommandService"/> for the supplied object type.
        /// If <typeparamref name="T"/> is a value type, a nullable reader for it will also be added.
        /// </summary>
        /// <typeparam name="T">The object type to be read by the <see cref="TypeReader"/>.</typeparam>
        /// <param name="reader">An instance of the <see cref="TypeReader"/> to be added.</param>
        /// <param name="replaceDefault">If <paramref name="reader"/> should replace the default <see cref="TypeReader"/> for <typeparamref name="T"/> if one exists.</param>
        public void AddTypeReader<T>(TypeReader reader, bool replaceDefault)
            => AddTypeReader(typeof(T), reader, replaceDefault);

        /// <summary>
        /// Adds a custom <see cref="TypeReader"/> to this <see cref="CommandService"/> for the supplied object type.
        /// If <paramref name="type"/> is a value type, a nullable reader for the value type will also be added.
        /// </summary>
        /// <param name="type">A <see cref="Type"/> instance for the type to be read.</param>
        /// <param name="reader">An instance of the <see cref="TypeReader"/> to be added.</param>
        /// <param name="replaceDefault">If <paramref name="reader"/> should replace the default <see cref="TypeReader"/> for <paramref name="type"/> if one exists.</param>
        public void AddTypeReader(Type type, TypeReader reader, bool replaceDefault)
        {
            if (replaceDefault && _defaultTypeReaders.ContainsKey(type))
            {
                // Overwrite the default reader, and the default Nullable<> reader for value types.
                _defaultTypeReaders.AddOrUpdate(type, reader, (k, v) => reader);
                if (type.GetTypeInfo().IsValueType)
                {
                    var nullableType = typeof(Nullable<>).MakeGenericType(type);
                    var nullableReader = NullableTypeReader.Create(type, reader);
                    _defaultTypeReaders.AddOrUpdate(nullableType, nullableReader, (k, v) => nullableReader);
                }
            }
            else
            {
                // Register as a custom reader, keyed by the reader's own type.
                var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary<Type, TypeReader>());
                readers[reader.GetType()] = reader;

                if (type.GetTypeInfo().IsValueType)
                    AddNullableTypeReader(type, reader);
            }
        }

        // Registers a custom reader for Nullable<valueType> that wraps valueTypeReader.
        internal void AddNullableTypeReader(Type valueType, TypeReader valueTypeReader)
        {
            var readers = _typeReaders.GetOrAdd(typeof(Nullable<>).MakeGenericType(valueType), x => new ConcurrentDictionary<Type, TypeReader>());
            var nullableReader = NullableTypeReader.Create(valueType, valueTypeReader);
            readers[nullableReader.GetType()] = nullableReader;
        }

        // Returns the custom readers registered for the type, or null when there are none.
        internal IDictionary<Type, TypeReader> GetTypeReaders(Type type)
        {
            if (_typeReaders.TryGetValue(type, out var definedTypeReaders))
                return definedTypeReaders;
            return null;
        }

        // Returns the default reader for the type, creating and caching one for enums and for
        // the known entity interfaces on first use. Returns null when no default applies.
        internal TypeReader GetDefaultTypeReader(Type type)
        {
            if (_defaultTypeReaders.TryGetValue(type, out var reader))
                return reader;
            var typeInfo = type.GetTypeInfo();

            //Is this an enum?
            if (typeInfo.IsEnum)
            {
                reader = EnumTypeReader.GetReader(type);
                _defaultTypeReaders[type] = reader;
                return reader;
            }

            //Is this an entity?
            for (int i = 0; i < _entityTypeReaders.Count; i++)
            {
                if (type == _entityTypeReaders[i].Item1 || typeInfo.ImplementedInterfaces.Contains(_entityTypeReaders[i].Item1))
                {
                    reader = Activator.CreateInstance(_entityTypeReaders[i].Item2.MakeGenericType(type)) as TypeReader;
                    _defaultTypeReaders[type] = reader;
                    return reader;
                }
            }
            return null;
        }

        //Execution

        /// <summary>
        /// Searches for commands matching the context's message content, starting at <paramref name="argPos"/>.
        /// </summary>
        /// <param name="context">The command context whose message is searched.</param>
        /// <param name="argPos">The index in the message content at which the command text starts.</param>
        /// <returns>The search result.</returns>
        public SearchResult Search(ICommandContext context, int argPos)
            => Search(context, context.Message.Content.Substring(argPos));

        /// <summary>
        /// Searches for commands matching <paramref name="input"/>.
        /// </summary>
        /// <param name="context">The command context.</param>
        /// <param name="input">The command text.</param>
        /// <returns>
        /// A successful result holding the matches ordered by descending priority, or an
        /// <see cref="CommandError.UnknownCommand"/> error when nothing matches.
        /// </returns>
        public SearchResult Search(ICommandContext context, string input)
        {
            // The map is queried with lower-cased text unless the service is case sensitive.
            string searchInput = _caseSensitive ? input : input.ToLowerInvariant();
            var matches = _map.GetCommands(searchInput).OrderByDescending(x => x.Command.Priority).ToImmutableArray();

            if (matches.Length > 0)
                return SearchResult.FromSuccess(input, matches);
            else
                return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
        }

        /// <summary>
        /// Executes the command found in the context's message content, starting at <paramref name="argPos"/>.
        /// </summary>
        /// <param name="context">The command context.</param>
        /// <param name="argPos">The index in the message content at which the command text starts.</param>
        /// <param name="services">An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null</param>
        /// <param name="multiMatchHandling">How a parse that produced multiple matches is handled.</param>
        /// <returns>The result of the search, precondition check, parse or execution, whichever ended the attempt.</returns>
        public Task<IResult> ExecuteAsync(ICommandContext context, int argPos, IServiceProvider services, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
            => ExecuteAsync(context, context.Message.Content.Substring(argPos), services, multiMatchHandling);

        /// <summary>
        /// Executes the command described by <paramref name="input"/>.
        /// </summary>
        /// <param name="context">The command context.</param>
        /// <param name="input">The command text.</param>
        /// <param name="services">An IServiceProvider for your dependency injection solution, if using one - otherwise, pass null</param>
        /// <param name="multiMatchHandling">How a parse that produced multiple matches is handled.</param>
        /// <returns>
        /// The failed search result, the failed precondition result of the highest priority
        /// candidate, the best failed parse result, or the result of executing the chosen overload.
        /// </returns>
        public async Task<IResult> ExecuteAsync(ICommandContext context, string input, IServiceProvider services, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
        {
            services = services ?? EmptyServiceProvider.Instance;

            var searchResult = Search(context, input);
            if (!searchResult.IsSuccess)
                return searchResult;

            var commands = searchResult.Commands;
            var preconditionResults = new Dictionary<CommandMatch, PreconditionResult>();

            foreach (var match in commands)
            {
                preconditionResults[match] = await match.Command.CheckPreconditionsAsync(context, services).ConfigureAwait(false);
            }

            var successfulPreconditions = preconditionResults
                .Where(x => x.Value.IsSuccess)
                .ToArray();

            if (successfulPreconditions.Length == 0)
            {
                //All preconditions failed, return the one from the highest priority command
                var bestCandidate = preconditionResults
                    .OrderByDescending(x => x.Key.Command.Priority)
                    .FirstOrDefault(x => !x.Value.IsSuccess);
                return bestCandidate.Value;
            }

            //If we get this far, at least one precondition was successful.

            var parseResultsDict = new Dictionary<CommandMatch, ParseResult>();
            foreach (var pair in successfulPreconditions)
            {
                var parseResult = await pair.Key.ParseAsync(context, searchResult, pair.Value, services).ConfigureAwait(false);

                // With MultiMatchHandling.Best, an ambiguous parse is resolved by taking the
                // highest scoring value of every argument; any other mode keeps the error.
                if (parseResult.Error == CommandError.MultipleMatches && multiMatchHandling == MultiMatchHandling.Best)
                {
                    IReadOnlyList<TypeReaderValue> argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                    IReadOnlyList<TypeReaderValue> paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                    parseResult = ParseResult.FromSuccess(argList, paramList);
                }

                parseResultsDict[pair.Key] = parseResult;
            }

            // Calculates the 'score' of a command given a parse result
            float CalculateScore(CommandMatch match, ParseResult parseResult)
            {
                float argValuesScore = 0, paramValuesScore = 0;

                if (match.Command.Parameters.Count > 0)
                {
                    var argValuesSum = parseResult.ArgValues?.Sum(x => x.Values.OrderByDescending(y => y.Score).FirstOrDefault().Score) ?? 0;
                    var paramValuesSum = parseResult.ParamValues?.Sum(x => x.Values.OrderByDescending(y => y.Score).FirstOrDefault().Score) ?? 0;

                    argValuesScore = argValuesSum / match.Command.Parameters.Count;
                    paramValuesScore = paramValuesSum / match.Command.Parameters.Count;
                }

                // The argument score is scaled by 0.99 before being added to the priority.
                var totalArgsScore = (argValuesScore + paramValuesScore) / 2;
                return match.Command.Priority + totalArgsScore * 0.99f;
            }

            //Order the parse results by their score so that we choose the most likely result to execute
            var parseResults = parseResultsDict
                .OrderByDescending(x => CalculateScore(x.Key, x.Value));

            var successfulParses = parseResults
                .Where(x => x.Value.IsSuccess)
                .ToArray();

            if (successfulParses.Length == 0)
            {
                //All parses failed, return the one from the highest priority command, using score as a tie breaker
                var bestMatch = parseResults
                    .FirstOrDefault(x => !x.Value.IsSuccess);
                return bestMatch.Value;
            }

            //If we get this far, at least one parse was successful. Execute the most likely overload.
            var chosenOverload = successfulParses[0];
            return await chosenOverload.Key.ExecuteAsync(context, chosenOverload.Value, services).ConfigureAwait(false);
        }
    }
}