/** * \file dnn/src/common/algo_chooser.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include #include #include #include #include "utils.h" namespace megdnn { /*! * \brief get user-configured algorithm, or heuristic algorithm */ template typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { typename Opr::AlgorithmDesc ret; auto set = opr->execution_policy().algo; if (set.valid()) { ret = set; } else { ret = opr->get_algorithm_info_heuristic( std::forward(args)..., std::numeric_limits::max(), AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT).desc; } return static_cast( opr->get_algorithm_from_desc(ret)); } /*! * \brief get user-configured algorithm, or heuristic algorithm. used in opencl * whose algo need to be constructed each time. */ template typename Opr::AlgoBase* get_algorithm_or_construct(Opr* opr, Args&&... args) { auto set = opr->execution_policy().algo; if (set.valid()) { return opr->algo_pack().construct_and_get_algo(set); } else { return static_cast( opr->get_algorithm_heuristic(std::forward(args)..., std::numeric_limits::max(), AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT)); } } /*! * \brief get all algorithms from algo_pack() that is available for current size */ template std::vector get_all_algorithms( const typename Opr::AlgoBase::SizeArgs& args) { std::vector ret; ret.reserve(Opr::algo_pack().all_algos.size()); for (auto i : Opr::algo_pack().all_algos) { if (i->is_available(args)) { ret.push_back(i); } } megdnn_assert(!ret.empty(), "no conv algorithm for %s", args.to_string().c_str()); return ret; } /*! * \brief a helper function to get an algorithm match attribute. If require a * algorithm with specified attribute, and the given algorithm match that * attribute, return the given algorithm. Otherwise return nullptr */ template typename Opr::Algorithm* get_algo_match_attribute( typename Opr::AlgoBase* algo, const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { if (algo->contain_attribute_all(positive_attr) && !algo->contain_attribute_any(negative_attr)) { return algo; } return nullptr; } template typename Opr::Algorithm* get_algo_match_attribute( const std::vector& algos, const typename Opr::AlgoBase::SizeArgs& args, size_t workspace_limit_in_bytes, const char* name, const AlgoAttribute& positive_attr = AlgoAttribute::REPRODUCIBLE, const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { size_t min_workspace_limit_in_bytes = std::numeric_limits::max(); bool available_but_limited_by_workspace = false; bool available_but_attribute_mismatch = false; for (auto i : algos) { if (i->is_available_attribute(args, positive_attr, negative_attr, workspace_limit_in_bytes)) { return i; } if (i->is_available_attribute(args, positive_attr, negative_attr)) { if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { available_but_limited_by_workspace = true; min_workspace_limit_in_bytes = std::min(min_workspace_limit_in_bytes, i->get_workspace_in_bytes(args)); } } if (i->is_available(args)) { if (!(i->contain_attribute_all(positive_attr) && !i->contain_attribute_any(negative_attr))) available_but_attribute_mismatch = true; } } MEGDNN_MARK_USED_VAR(name); if (available_but_limited_by_workspace) { megdnn_throw( ssprintf("no %s algorithm without attribute(%s) with " "attribute(%s) : %s workspace limit %zu is " "less than mini workspace limit %zu", name, Algorithm::attribute_str(negative_attr).c_str(), Algorithm::attribute_str(positive_attr).c_str(), args.to_string().c_str(), workspace_limit_in_bytes, min_workspace_limit_in_bytes)); } else if (available_but_attribute_mismatch) { megdnn_throw(ssprintf( "no %s algorithm without attribute(%s) with attribute(%s)", name, Algorithm::attribute_str(negative_attr).c_str(), Algorithm::attribute_str(positive_attr).c_str())); } else { megdnn_throw(ssprintf("no usable %s algorithm", name)); } } } // namespace megdnn // vim: syntax=cpp.doxygen