@@ -122,6 +122,11 @@ public: | |||||
* these algorithms to speed up fastrun. | * these algorithms to speed up fastrun. | ||||
* */ | * */ | ||||
NAIVE = 1 << 1, | NAIVE = 1 << 1, | ||||
/** | |||||
* \brief whether the algo is usable once shape changed. | |||||
* */ | |||||
USABLE_DEPEND_ON_SHAPE = 1 << 2, | |||||
}; | }; | ||||
/** | /** | ||||
@@ -35,7 +35,8 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -146,7 +147,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -220,7 +222,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -235,7 +238,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_INT8X8X16_MK4_16X12X4"; | return "AARCH64_INT8X8X16_MK4_16X12X4"; | ||||
@@ -253,7 +257,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { | const char* name() const override { | ||||
return "AARCH64_INT8X8X16_MK4_K8X8X8"; | return "AARCH64_INT8X8X16_MK4_K8X8X8"; | ||||
@@ -271,7 +276,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -330,7 +336,8 @@ public: | |||||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -34,7 +34,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -50,7 +51,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -67,7 +69,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -102,7 +105,8 @@ public: | |||||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -35,7 +35,8 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -224,7 +225,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -266,7 +268,8 @@ public: | |||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -18,7 +18,8 @@ using namespace megdnn; | |||||
#define FOREACH_ALGO_ATTRIBUTE(cb) \ | #define FOREACH_ALGO_ATTRIBUTE(cb) \ | ||||
cb(DEFAULT) \ | cb(DEFAULT) \ | ||||
cb(REPRODUCIBLE) \ | cb(REPRODUCIBLE) \ | ||||
cb(NAIVE) | |||||
cb(NAIVE) \ | |||||
cb(USABLE_DEPEND_ON_SHAPE) | |||||
namespace { | namespace { | ||||
inline const char* attr_str(const AlgoAttribute& attr) { | inline const char* attr_str(const AlgoAttribute& attr) { | ||||
@@ -184,7 +184,8 @@ public: | |||||
const char* name() const override { return "CHANNEL_WISE_SMALL"; } | const char* name() const override { return "CHANNEL_WISE_SMALL"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
}; | }; | ||||
@@ -89,7 +89,8 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { | const char* name() const override { | ||||
@@ -108,7 +109,8 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { | const char* name() const override { | ||||
@@ -114,7 +114,8 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
}; | }; | ||||
@@ -231,7 +232,8 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | ||||
@@ -100,7 +100,8 @@ public: | |||||
const char* name() const override { return "BLAS"; } | const char* name() const override { return "BLAS"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | ||||
}; | }; | ||||
@@ -135,7 +135,8 @@ public: | |||||
class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | ||||
public: | public: | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { return "X86_F32MK8_8X8"; } | const char* name() const override { return "X86_F32MK8_8X8"; } | ||||
bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
@@ -276,21 +276,6 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||||
return ret; | return ret; | ||||
} | } | ||||
//! return pair<positive_attr, negative_attr> | |||||
std::pair<AlgoAttribute, AlgoAttribute> | |||||
extract_algo_attribute_from_execution_strategy( | |||||
const ExecutionStrategy& strategy) { | |||||
std::pair<AlgoAttribute, AlgoAttribute> ret = | |||||
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||||
if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
ret.first |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
if (strategy & ExecutionStrategy::OPTIMIZED) { | |||||
ret.second |= AlgoAttribute::NAIVE; | |||||
} | |||||
return ret; | |||||
} | |||||
} // namespace | } // namespace | ||||
namespace mgb { | namespace mgb { | ||||
@@ -303,9 +288,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||||
return; | return; | ||||
AlgoChooserProfileCache::Result prof_rst; | AlgoChooserProfileCache::Result prof_rst; | ||||
auto target_attr = | |||||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||||
auto target_attr = ctx.extract_algo_attribute(selected_strategy); | |||||
std::string layouts_str = | |||||
format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||||
double cur_timeout = 0; | double cur_timeout = 0; | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
@@ -558,16 +543,15 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
if (prof.empty()) | if (prof.empty()) | ||||
return {}; | return {}; | ||||
auto attr_from_strategy = | |||||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
auto target_attr = extract_algo_attribute(selected_strategy); | |||||
for (auto&& i : prof) { | for (auto&& i : prof) { | ||||
auto attr_of_algo = | auto attr_of_algo = | ||||
static_cast<megdnn::Algorithm::Attribute>(i.attribute); | static_cast<megdnn::Algorithm::Attribute>(i.attribute); | ||||
bool contain_attr_all_positive = | bool contain_attr_all_positive = | ||||
(attr_from_strategy.first == | |||||
(attr_of_algo & attr_from_strategy.first)); | |||||
(target_attr.first == | |||||
(attr_of_algo & target_attr.first)); | |||||
bool contain_attr_any_negative = | bool contain_attr_any_negative = | ||||
static_cast<bool>(attr_of_algo & attr_from_strategy.second); | |||||
static_cast<bool>(attr_of_algo & target_attr.second); | |||||
if (contain_attr_all_positive && !contain_attr_any_negative) { | if (contain_attr_all_positive && !contain_attr_any_negative) { | ||||
auto iter = algo_map.find(i.algo); | auto iter = algo_map.find(i.algo); | ||||
mgb_assert(iter != algo_map.end(), | mgb_assert(iter != algo_map.end(), | ||||
@@ -586,8 +570,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||||
mgb_log_error( | mgb_log_error( | ||||
"algos read from cache could not satisfy attribute with %s and " | "algos read from cache could not satisfy attribute with %s and " | ||||
"without %s", | "without %s", | ||||
Algorithm::attribute_str(attr_from_strategy.first).c_str(), | |||||
Algorithm::attribute_str(attr_from_strategy.second).c_str()); | |||||
Algorithm::attribute_str(target_attr.first).c_str(), | |||||
Algorithm::attribute_str(target_attr.second).c_str()); | |||||
mgb_trap(); | mgb_trap(); | ||||
MIDOUT_E | MIDOUT_E | ||||
@@ -606,8 +590,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||||
} | } | ||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
auto attr = | |||||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||||
auto attr = extract_algo_attribute(selected_strategy); | |||||
ImplExecutionPolicy policy; | ImplExecutionPolicy policy; | ||||
policy.algo = | policy.algo = | ||||
APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
@@ -668,9 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
if (retrive_from_cache) { | if (retrive_from_cache) { | ||||
policy.algo = get_profile_result_from_cache(selected_strategy).desc; | policy.algo = get_profile_result_from_cache(selected_strategy).desc; | ||||
if (!policy.algo.valid()) { | if (!policy.algo.valid()) { | ||||
auto target_attr = | |||||
extract_algo_attribute_from_execution_strategy( | |||||
selected_strategy); | |||||
auto target_attr = extract_algo_attribute(selected_strategy); | |||||
std::string layouts_str = | std::string layouts_str = | ||||
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | ||||
std::string msg = ssprintf( | std::string msg = ssprintf( | ||||
@@ -692,8 +673,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | owner_graph(), m_cn, m_execution_policy.workspace_limit); | ||||
auto attr = extract_algo_attribute_from_execution_strategy( | |||||
selected_strategy); | |||||
auto attr = extract_algo_attribute(selected_strategy); | |||||
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | ||||
args..., workspace_limit, attr.first, | args..., workspace_limit, attr.first, | ||||
attr.second), | attr.second), | ||||
@@ -837,6 +817,24 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||||
return result; | return result; | ||||
} | } | ||||
template <typename Opr> | |||||
std::pair<AlgoAttribute, AlgoAttribute> | |||||
AlgoChooser<Opr>::ExeContext::extract_algo_attribute( | |||||
const ExecutionStrategy& strategy) const { | |||||
std::pair<AlgoAttribute, AlgoAttribute> ret = | |||||
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||||
//! from strategy | |||||
if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||||
ret.first |= AlgoAttribute::REPRODUCIBLE; | |||||
} | |||||
if (strategy & ExecutionStrategy::OPTMIZED) { | |||||
ret.second |= AlgoAttribute::NAIVE; | |||||
} | |||||
return ret; | |||||
} | |||||
#define INST(Opr) \ | #define INST(Opr) \ | ||||
template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ | template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ | ||||
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | ||||
@@ -865,7 +863,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||||
AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ | AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ | ||||
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ | const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ | ||||
policy, \ | policy, \ | ||||
double& timeout) const; | |||||
double& timeout) const; \ | |||||
template std::pair<AlgoAttribute, AlgoAttribute> \ | |||||
AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \ | |||||
const ExecutionStrategy& strategy) const; | |||||
MGB_FOREACH_FASTRUN_OPR(INST) | MGB_FOREACH_FASTRUN_OPR(INST) | ||||
@@ -149,6 +149,16 @@ public: | |||||
ImplExecutionPolicy& policy, | ImplExecutionPolicy& policy, | ||||
bool retrive_from_cache = true) const; | bool retrive_from_cache = true) const; | ||||
/** | |||||
* \brief extract algo attribute from execution strategy and graph | |||||
* option. | |||||
* | |||||
* \param strategy select algo which matched this strategy | |||||
* \return pair<positive_attr, negative_attr> | |||||
*/ | |||||
std::pair<AlgoAttribute, AlgoAttribute> extract_algo_attribute( | |||||
const ExecutionStrategy& strategy) const; | |||||
private: | private: | ||||
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | ||||
}; | }; | ||||