/** * \file dnn/src/common/algo_chooser.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include #include #include #include #include "utils.h" namespace megdnn { /*! * \brief get user-configured algorithm, or heuristic algorithm */ template typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { typename Opr::AlgorithmInfo ret; auto set = opr->execution_policy().algo; if (set.valid()) { ret = set; } else { ret = opr->get_algorithm_info_heuristic( std::forward(args)..., std::numeric_limits::max(), false); } return opr->get_algo_from_desc(ret.desc); } /*! * \brief get user-configured algorithm, or heuristic algorithm. used in opencl * whose algo need to be constructed each time. */ template typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { typename Opr::AlgorithmInfo ret; auto set = opr->execution_policy().algo; if (set.valid()) { return opr->algo_pack().construct_and_get_algo(set.desc); } else { ret = opr->get_algorithm_info_heuristic( std::forward(args)..., std::numeric_limits::max(), false); return opr->get_algo_from_desc(ret.desc); } } /*! * \brief get all algorithms from algo_pack() that is available for current size */ template std::vector get_all_algorithms( const typename Opr::AlgoBase::SizeArgs& args) { std::vector ret; ret.reserve(Opr::algo_pack().all_algos.size()); for (auto i : Opr::algo_pack().all_algos) { if (i->is_available(args)) { ret.push_back(i); } } megdnn_assert(!ret.empty(), "no conv algorithm for %s", args.to_string().c_str()); return ret; } /*! * \brief a helper function to get a reproducible algorithm. If require a * reproducible algorithm, and the given algorithm is reproducible, return the * given algorithm. Otherwise return nullptr */ template typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo, bool reproducible) { if (reproducible) { if (algo->is_reproducible()) { return algo; } } else { return algo; } return nullptr; } template typename Opr::Algorithm* get_reproducible_algo( const std::vector& algos, const typename Opr::AlgoBase::SizeArgs& args, size_t workspace_limit_in_bytes, const char* name) { size_t min_workspace_limit_in_bytes = std::numeric_limits::max(); bool available_but_limited_by_workspace = false; bool available_but_not_reproducible = false; for (auto i : algos) { if (i->is_available_reproducible(args, true, workspace_limit_in_bytes)) { return i; } if (i->is_available_reproducible(args)) { if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { available_but_limited_by_workspace = true; min_workspace_limit_in_bytes = std::min(min_workspace_limit_in_bytes, i->get_workspace_in_bytes(args)); } } if (i->is_available(args)) { if (!i->is_reproducible()) available_but_not_reproducible = true; } } MEGDNN_MARK_USED_VAR(name); if (available_but_limited_by_workspace) { megdnn_throw(megdnn_mangle(ssprintf( "no reproducible %s algorithm: %s workspace limit %zu is " "less than mini workspace limit %zu", name, args.to_string().c_str(), workspace_limit_in_bytes, min_workspace_limit_in_bytes))); } else if (available_but_not_reproducible) { megdnn_throw( megdnn_mangle(ssprintf("no reproducible %s algorithm", name))); } else { megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name))); } } template typename Opr::Algorithm* get_usable_algo( const std::vector& algos, const typename Opr::AlgoBase::SizeArgs& args, size_t workspace_limit_in_bytes, const char* name) { size_t min_workspace_limit_in_bytes = std::numeric_limits::max(); bool available_but_limited_by_workspace = false; for (auto i : algos) { if (i->is_available_wk(args, workspace_limit_in_bytes)) { return i; } if (i->is_available(args)) { available_but_limited_by_workspace = true; min_workspace_limit_in_bytes = std::min(min_workspace_limit_in_bytes, i->get_workspace_in_bytes(args)); } } MEGDNN_MARK_USED_VAR(name); if (available_but_limited_by_workspace) { megdnn_throw(megdnn_mangle(ssprintf( "no usable %s algorithm: %s workspace limit %zu is " "less than mini workspace limit %zu", name, args.to_string().c_str(), workspace_limit_in_bytes, min_workspace_limit_in_bytes))); } else { megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name))); } } } // namespace megdnn // vim: syntax=cpp.doxygen