@@ -122,6 +122,11 @@ public: | |||
* these algorithms to speed up fastrun. | |||
* */ | |||
NAIVE = 1 << 1, | |||
/** | |||
* \brief whether the usability of the algo depends on the shape, i.e. whether it may become unusable once the shape changes.
* */ | |||
USABLE_DEPEND_ON_SHAPE = 1 << 2, | |||
}; | |||
/** | |||
@@ -35,7 +35,8 @@ public: | |||
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -146,7 +147,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -220,7 +222,8 @@ public: | |||
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -235,7 +238,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { | |||
return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||
@@ -253,7 +257,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { | |||
return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||
@@ -271,7 +276,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -330,7 +336,8 @@ public: | |||
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -34,7 +34,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -50,7 +51,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -67,7 +69,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -102,7 +105,8 @@ public: | |||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -35,7 +35,8 @@ public: | |||
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -224,7 +225,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -266,7 +268,8 @@ public: | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -18,7 +18,8 @@ using namespace megdnn; | |||
//! Applies \p cb once to each AlgoAttribute flag name listed below.
//! Every line except the last must end with a continuation backslash so the
//! whole list stays inside the macro definition.
#define FOREACH_ALGO_ATTRIBUTE(cb) \
    cb(DEFAULT)                    \
    cb(REPRODUCIBLE)               \
    cb(NAIVE)                      \
    cb(USABLE_DEPEND_ON_SHAPE)
namespace { | |||
inline const char* attr_str(const AlgoAttribute& attr) { | |||
@@ -184,7 +184,8 @@ public: | |||
const char* name() const override { return "CHANNEL_WISE_SMALL"; } | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
}; | |||
@@ -89,7 +89,8 @@ public: | |||
void exec(const ExecArgs& args) const override; | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { | |||
@@ -108,7 +109,8 @@ public: | |||
void exec(const ExecArgs& args) const override; | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { | |||
@@ -114,7 +114,8 @@ public: | |||
void exec(const ExecArgs& args) const override; | |||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
}; | |||
@@ -231,7 +232,8 @@ public: | |||
const char* name() const override { return m_name.c_str(); } | |||
void exec(const ExecArgs& args) const override; | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | |||
@@ -100,7 +100,8 @@ public: | |||
const char* name() const override { return "BLAS"; } | |||
void exec(const ExecArgs& args) const override; | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | |||
}; | |||
@@ -135,7 +135,8 @@ public: | |||
class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { | |||
public: | |||
//! Attribute flags reported for this algorithm: REPRODUCIBLE combined with
//! USABLE_DEPEND_ON_SHAPE.
AlgoAttribute attribute() const override {
    return AlgoAttribute::REPRODUCIBLE |
           AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "X86_F32MK8_8X8"; } | |||
bool usable(const KernSizeParam&) const override; | |||
@@ -276,21 +276,6 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( | |||
return ret; | |||
} | |||
//! return pair<positive_attr, negative_attr> | |||
std::pair<AlgoAttribute, AlgoAttribute> | |||
extract_algo_attribute_from_execution_strategy( | |||
const ExecutionStrategy& strategy) { | |||
std::pair<AlgoAttribute, AlgoAttribute> ret = | |||
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||
if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
ret.first |= AlgoAttribute::REPRODUCIBLE; | |||
} | |||
if (strategy & ExecutionStrategy::OPTIMIZED) { | |||
ret.second |= AlgoAttribute::NAIVE; | |||
} | |||
return ret; | |||
} | |||
} // namespace | |||
namespace mgb { | |||
@@ -303,9 +288,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, | |||
return; | |||
AlgoChooserProfileCache::Result prof_rst; | |||
auto target_attr = | |||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||
std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||
auto target_attr = ctx.extract_algo_attribute(selected_strategy); | |||
std::string layouts_str = | |||
format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); | |||
double cur_timeout = 0; | |||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
@@ -558,16 +543,15 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
if (prof.empty()) | |||
return {}; | |||
auto attr_from_strategy = | |||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||
auto target_attr = extract_algo_attribute(selected_strategy); | |||
for (auto&& i : prof) { | |||
auto attr_of_algo = | |||
static_cast<megdnn::Algorithm::Attribute>(i.attribute); | |||
bool contain_attr_all_positive = | |||
(attr_from_strategy.first == | |||
(attr_of_algo & attr_from_strategy.first)); | |||
(target_attr.first == | |||
(attr_of_algo & target_attr.first)); | |||
bool contain_attr_any_negative = | |||
static_cast<bool>(attr_of_algo & attr_from_strategy.second); | |||
static_cast<bool>(attr_of_algo & target_attr.second); | |||
if (contain_attr_all_positive && !contain_attr_any_negative) { | |||
auto iter = algo_map.find(i.algo); | |||
mgb_assert(iter != algo_map.end(), | |||
@@ -586,8 +570,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( | |||
mgb_log_error( | |||
"algos read from cache could not satisfy attribute with %s and " | |||
"without %s", | |||
Algorithm::attribute_str(attr_from_strategy.first).c_str(), | |||
Algorithm::attribute_str(attr_from_strategy.second).c_str()); | |||
Algorithm::attribute_str(target_attr.first).c_str(), | |||
Algorithm::attribute_str(target_attr.second).c_str()); | |||
mgb_trap(); | |||
MIDOUT_E | |||
@@ -606,8 +590,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic( | |||
} | |||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
auto attr = | |||
extract_algo_attribute_from_execution_strategy(selected_strategy); | |||
auto attr = extract_algo_attribute(selected_strategy); | |||
ImplExecutionPolicy policy; | |||
policy.algo = | |||
APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||
@@ -668,9 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
if (retrive_from_cache) { | |||
policy.algo = get_profile_result_from_cache(selected_strategy).desc; | |||
if (!policy.algo.valid()) { | |||
auto target_attr = | |||
extract_algo_attribute_from_execution_strategy( | |||
selected_strategy); | |||
auto target_attr = extract_algo_attribute(selected_strategy); | |||
std::string layouts_str = | |||
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); | |||
std::string msg = ssprintf( | |||
@@ -692,8 +673,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( | |||
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
owner_graph(), m_cn, m_execution_policy.workspace_limit); | |||
auto attr = extract_algo_attribute_from_execution_strategy( | |||
selected_strategy); | |||
auto attr = extract_algo_attribute(selected_strategy); | |||
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( | |||
args..., workspace_limit, attr.first, | |||
attr.second), | |||
@@ -837,6 +817,24 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||
return result; | |||
} | |||
template <typename Opr> | |||
std::pair<AlgoAttribute, AlgoAttribute> | |||
AlgoChooser<Opr>::ExeContext::extract_algo_attribute( | |||
const ExecutionStrategy& strategy) const { | |||
std::pair<AlgoAttribute, AlgoAttribute> ret = | |||
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); | |||
//! from strategy | |||
if (strategy & ExecutionStrategy::REPRODUCIBLE) { | |||
ret.first |= AlgoAttribute::REPRODUCIBLE; | |||
} | |||
if (strategy & ExecutionStrategy::OPTMIZED) { | |||
ret.second |= AlgoAttribute::NAIVE; | |||
} | |||
return ret; | |||
} | |||
#define INST(Opr) \ | |||
template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ | |||
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ | |||
@@ -865,7 +863,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { | |||
AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ | |||
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ | |||
policy, \ | |||
double& timeout) const; | |||
double& timeout) const; \ | |||
template std::pair<AlgoAttribute, AlgoAttribute> \ | |||
AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \ | |||
const ExecutionStrategy& strategy) const; | |||
MGB_FOREACH_FASTRUN_OPR(INST) | |||
@@ -149,6 +149,16 @@ public: | |||
ImplExecutionPolicy& policy, | |||
bool retrive_from_cache = true) const; | |||
/** | |||
* \brief extract algo attribute from execution strategy and graph | |||
* option. | |||
* | |||
* \param strategy select algo which matched this strategy | |||
* \return pair<positive_attr, negative_attr> | |||
*/ | |||
std::pair<AlgoAttribute, AlgoAttribute> extract_algo_attribute( | |||
const ExecutionStrategy& strategy) const; | |||
private: | |||
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const; | |||
}; | |||