|
|
@@ -12,6 +12,8 @@ |
|
|
|
|
|
|
|
#include "megbrain/opr/search_policy/algo_chooser.h" |
|
|
|
#include <limits> |
|
|
|
#include <unordered_set> |
|
|
|
#include "megbrain/opr/dnn/convolution.h" |
|
|
|
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" |
|
|
|
#include "megbrain/opr/search_policy/algo_chooser_helper.h" |
|
|
|
#include "megbrain/opr/search_policy/profiler.h" |
|
|
@@ -22,6 +24,7 @@ |
|
|
|
|
|
|
|
//! TODO: some megdnn::opr types have to be known here when midout.h is produced; |
|
|
|
//! fix it if there is a more graceful way. |
|
|
|
#include "megdnn/opr_param_defs.h" |
|
|
|
#include "megdnn/oprs.h" |
|
|
|
#include "megdnn/oprs/base.h" |
|
|
|
#include "midout.h" |
|
|
@@ -78,6 +81,58 @@ std::string format_fixlayouts( |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
|
* \brief Check if the sub opr list has circular dependence. |
|
|
|
*/ |
|
|
|
class CircularDepsChecker { |
|
|
|
struct SearchItemStorage { |
|
|
|
std::string data_hold; |
|
|
|
size_t hash = 0; |
|
|
|
|
|
|
|
SearchItemStorage(const Algorithm::SearchItem& item) { |
|
|
|
Algorithm::serialize_write_pod(item.opr_type, data_hold); |
|
|
|
for (auto&& layout : item.layouts) { |
|
|
|
data_hold += layout.serialize(); |
|
|
|
} |
|
|
|
data_hold += item.param; |
|
|
|
} |
|
|
|
|
|
|
|
SearchItemStorage& init_hash() { |
|
|
|
hash = XXHash64CT::hash(data_hold.data(), data_hold.size(), |
|
|
|
20201225); |
|
|
|
return *this; |
|
|
|
} |
|
|
|
|
|
|
|
bool operator==(const SearchItemStorage& rhs) const { |
|
|
|
return data_hold == rhs.data_hold; |
|
|
|
} |
|
|
|
|
|
|
|
struct Hash { |
|
|
|
size_t operator()(const SearchItemStorage& s) const { |
|
|
|
return s.hash; |
|
|
|
} |
|
|
|
}; |
|
|
|
}; |
|
|
|
std::unordered_set<SearchItemStorage, SearchItemStorage::Hash> m_set; |
|
|
|
|
|
|
|
public: |
|
|
|
void put(const megdnn::Algorithm::SearchItem& key) { |
|
|
|
SearchItemStorage key_storage(key); |
|
|
|
key_storage.init_hash(); |
|
|
|
mgb_assert(m_set.find(key_storage) == m_set.end(), |
|
|
|
"Circular dependency during flatten search space"); |
|
|
|
auto ret = m_set.insert(std::move(key_storage)); |
|
|
|
mgb_assert(ret.second); |
|
|
|
} |
|
|
|
void remove(const megdnn::Algorithm::SearchItem& key) { |
|
|
|
SearchItemStorage key_storage(key); |
|
|
|
key_storage.init_hash(); |
|
|
|
auto&& iter = m_set.find(key_storage); |
|
|
|
mgb_assert(iter != m_set.end()); |
|
|
|
m_set.erase(iter); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
///////////////// OprTypeTrait ///////////////////////////// |
|
|
|
template <megdnn::Algorithm::OprType> |
|
|
|
struct OprFromOprTypeTrait; |
|
|
@@ -176,14 +231,26 @@ typename opr::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts( |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace |
|
|
|
|
|
|
|
namespace mgb { |
|
|
|
namespace opr { |
|
|
|
|
|
|
|
/** |
|
|
|
* flatten search space in postorder traversal |
|
|
|
* The subopr search constructs a search tree |
|
|
|
* |
|
|
|
* A |
|
|
|
* / \ |
|
|
|
* B1B2 C |
|
|
|
* / \ |
|
|
|
* D1D2D3 E |
|
|
|
* We traverse the search tree in postorder. |
|
|
|
* D1 -> D2 -> D3 -> E -> B1 -> B2 -> C -> A |
|
|
|
*/ |
|
|
|
template <typename Opr> |
|
|
|
std::vector<megdnn::Algorithm::SearchItem> |
|
|
|
AlgoChooser<Opr>::flatten_search_space(const ExeContext& ctx) { |
|
|
|
std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( |
|
|
|
const typename opr::AlgoChooser<Opr>::ExeContext& ctx, |
|
|
|
CircularDepsChecker& checker) { |
|
|
|
auto&& search_item = megdnn::Algorithm::SearchItem{ |
|
|
|
OprTypeFromOprTrait<Opr>::opr_type, ctx.param(), |
|
|
|
to_layout_array<Opr>(ctx.layouts())}; |
|
|
|
checker.put(search_item); |
|
|
|
std::vector<megdnn::Algorithm::SearchItem> ret; |
|
|
|
for (auto algo_info : ctx.get_all_candidates()) { |
|
|
|
megdnn::Algorithm* algo = ctx.get_algorithm_from_desc(algo_info.desc); |
|
|
@@ -193,23 +260,29 @@ AlgoChooser<Opr>::flatten_search_space(const ExeContext& ctx) { |
|
|
|
ctx.megdnn_opr()); |
|
|
|
|
|
|
|
FOREACH_OPR_TYPE_DISPATCH(sub_items, { |
|
|
|
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node()); |
|
|
|
auto&& megdnn_opr = |
|
|
|
opr::intl::create_megdnn_opr<_Opr>(ctx.comp_node()); |
|
|
|
megdnn_opr->param() = |
|
|
|
Algorithm::deserialize_read_pod<typename _Opr::Param>( |
|
|
|
_item.param); |
|
|
|
typename AlgoChooser<_Opr>::ExeContext sub_ctx( |
|
|
|
typename opr::AlgoChooser<_Opr>::ExeContext sub_ctx( |
|
|
|
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), |
|
|
|
_item.param, ctx.mgb_opr(), ctx.comp_node(), |
|
|
|
ctx.execution_policy(), ctx.allow_weight_preprocess()); |
|
|
|
auto space = AlgoChooser<_Opr>::flatten_search_space(sub_ctx); |
|
|
|
auto space = flatten_search_space<_Opr>(sub_ctx, checker); |
|
|
|
ret.insert(ret.end(), space.begin(), space.end()); |
|
|
|
}); |
|
|
|
} |
|
|
|
ret.push_back({OprTypeFromOprTrait<Opr>::opr_type, ctx.param(), |
|
|
|
to_layout_array<Opr>(ctx.layouts())}); |
|
|
|
ret.push_back(search_item); |
|
|
|
checker.remove(search_item); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace |
|
|
|
|
|
|
|
namespace mgb { |
|
|
|
namespace opr { |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
void AlgoChooser<Opr>::profile(ExeContext& ctx, bool require_reproducible) { |
|
|
|
if (ctx.get_profile_result_from_cache(require_reproducible).valid()) |
|
|
@@ -289,7 +362,9 @@ AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, bool require_reproducible, |
|
|
|
} |
|
|
|
|
|
|
|
if (enable_update) { |
|
|
|
auto&& search_items = flatten_search_space(ctx); |
|
|
|
CircularDepsChecker circular_deps_checker; |
|
|
|
auto&& search_items = |
|
|
|
flatten_search_space<Opr>(ctx, circular_deps_checker); |
|
|
|
FOREACH_OPR_TYPE_DISPATCH(search_items, { |
|
|
|
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node()); |
|
|
|
megdnn_opr->param() = |
|
|
@@ -382,14 +457,12 @@ typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( |
|
|
|
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \ |
|
|
|
template void AlgoChooser<megdnn::Opr>::profile( \ |
|
|
|
ExeContext& ctx, bool require_reproducible); \ |
|
|
|
template std::vector<megdnn::Algorithm::SearchItem> \ |
|
|
|
AlgoChooser<megdnn::Opr>::flatten_search_space(const ExeContext& ctx); \ |
|
|
|
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ |
|
|
|
AlgoChooser<megdnn::Opr>::choose_by_profile( \ |
|
|
|
ExeContext& ctx, bool require_reproducible, bool enable_update); \ |
|
|
|
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \ |
|
|
|
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ |
|
|
|
const MGBOpr* mgb_opr, bool allow_weight_preprocess); |
|
|
|
const MGBOpr* mgb_opr, bool allow_weight_preprocess); \ |
|
|
|
|
|
|
|
MGB_FOREACH_FASTRUN_OPR(INST) |
|
|
|
|
|
|
|