Browse Source

fix(mgb): add usable-depend-on-shape attr

GitOrigin-RevId: 3a14fa6b6f
release-1.4
Megvii Engine Team 4 years ago
parent
commit
1cadf9d8d7
12 changed files with 93 additions and 55 deletions
  1. +5
    -0
      dnn/include/megdnn/oprs/base.h
  2. +14
    -7
      dnn/src/aarch64/matrix_mul/algos.h
  3. +8
    -4
      dnn/src/arm_common/matrix_mul/algos.h
  4. +6
    -3
      dnn/src/armv7/matrix_mul/algos.h
  5. +2
    -1
      dnn/src/common/algo_base.cpp
  6. +2
    -1
      dnn/src/cuda/convolution/backward_data/algo.h
  7. +4
    -2
      dnn/src/cuda/local_share/forward/algo.h
  8. +4
    -2
      dnn/src/cuda/matrix_mul/algos.h
  9. +2
    -1
      dnn/src/rocm/matrix_mul/algos.h
  10. +2
    -1
      dnn/src/x86/matrix_mul/algos.h
  11. +34
    -33
      src/opr/impl/search_policy/algo_chooser.cpp
  12. +10
    -0
      src/opr/include/megbrain/opr/search_policy/algo_chooser.h

+ 5
- 0
dnn/include/megdnn/oprs/base.h View File

@@ -122,6 +122,11 @@ public:
* these algorithms to speed up fastrun.
* */
NAIVE = 1 << 1,

/**
* \brief whether the algo is usable once shape changed.
* */
USABLE_DEPEND_ON_SHAPE = 1 << 2,
};

/**


+ 14
- 7
dnn/src/aarch64/matrix_mul/algos.h View File

@@ -35,7 +35,8 @@ public:
class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
bool usable(const KernSizeParam&) const override;
@@ -146,7 +147,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
bool usable(const KernSizeParam&) const override;
@@ -220,7 +222,8 @@ public:
class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
bool usable(const KernSizeParam&) const override;
@@ -235,7 +238,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_16X12X4";
@@ -253,7 +257,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override {
return "AARCH64_INT8X8X16_MK4_K8X8X8";
@@ -271,7 +276,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
bool usable(const KernSizeParam&) const override;
@@ -330,7 +336,8 @@ public:
class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
bool usable(const KernSizeParam&) const override;


+ 8
- 4
dnn/src/arm_common/matrix_mul/algos.h View File

@@ -34,7 +34,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
bool usable(const KernSizeParam&) const override;
@@ -50,7 +51,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;
@@ -67,7 +69,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
bool usable(const KernSizeParam&) const override;
@@ -102,7 +105,8 @@ public:
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
bool usable(const KernSizeParam&) const override;


+ 6
- 3
dnn/src/armv7/matrix_mul/algos.h View File

@@ -35,7 +35,8 @@ public:
class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
bool usable(const KernSizeParam&) const override;
@@ -224,7 +225,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; }
bool usable(const KernSizeParam&) const override;
@@ -266,7 +268,8 @@ public:
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; }
bool usable(const KernSizeParam&) const override;


+ 2
- 1
dnn/src/common/algo_base.cpp View File

@@ -18,7 +18,8 @@ using namespace megdnn;
#define FOREACH_ALGO_ATTRIBUTE(cb) \
cb(DEFAULT) \
cb(REPRODUCIBLE) \
cb(NAIVE)
cb(NAIVE) \
cb(USABLE_DEPEND_ON_SHAPE)

namespace {
inline const char* attr_str(const AlgoAttribute& attr) {


+ 2
- 1
dnn/src/cuda/convolution/backward_data/algo.h View File

@@ -184,7 +184,8 @@ public:
const char* name() const override { return "CHANNEL_WISE_SMALL"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
};



+ 4
- 2
dnn/src/cuda/local_share/forward/algo.h View File

@@ -89,7 +89,8 @@ public:
void exec(const ExecArgs& args) const override;

AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}

const char* name() const override {
@@ -108,7 +109,8 @@ public:
void exec(const ExecArgs& args) const override;

AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}

const char* name() const override {


+ 4
- 2
dnn/src/cuda/matrix_mul/algos.h View File

@@ -114,7 +114,8 @@ public:
void exec(const ExecArgs& args) const override;
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS)
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
};

@@ -231,7 +232,8 @@ public:
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K)



+ 2
- 1
dnn/src/rocm/matrix_mul/algos.h View File

@@ -100,7 +100,8 @@ public:
const char* name() const override { return "BLAS"; }
void exec(const ExecArgs& args) const override;
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS)
};


+ 2
- 1
dnn/src/x86/matrix_mul/algos.h View File

@@ -135,7 +135,8 @@ public:
class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase {
public:
AlgoAttribute attribute() const override {
return AlgoAttribute::REPRODUCIBLE;
return AlgoAttribute::REPRODUCIBLE |
AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
}
const char* name() const override { return "X86_F32MK8_8X8"; }
bool usable(const KernSizeParam&) const override;


+ 34
- 33
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -276,21 +276,6 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space(
return ret;
}

//! return pair<positive_attr, negative_attr>
std::pair<AlgoAttribute, AlgoAttribute>
extract_algo_attribute_from_execution_strategy(
const ExecutionStrategy& strategy) {
std::pair<AlgoAttribute, AlgoAttribute> ret =
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);
if (strategy & ExecutionStrategy::REPRODUCIBLE) {
ret.first |= AlgoAttribute::REPRODUCIBLE;
}
if (strategy & ExecutionStrategy::OPTIMIZED) {
ret.second |= AlgoAttribute::NAIVE;
}
return ret;
}

} // namespace

namespace mgb {
@@ -303,9 +288,9 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx,
return;
AlgoChooserProfileCache::Result prof_rst;

auto target_attr =
extract_algo_attribute_from_execution_strategy(selected_strategy);
std::string layouts_str = format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out);
auto target_attr = ctx.extract_algo_attribute(selected_strategy);
std::string layouts_str =
format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out);
double cur_timeout = 0;

auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
@@ -558,16 +543,15 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
if (prof.empty())
return {};

auto attr_from_strategy =
extract_algo_attribute_from_execution_strategy(selected_strategy);
auto target_attr = extract_algo_attribute(selected_strategy);
for (auto&& i : prof) {
auto attr_of_algo =
static_cast<megdnn::Algorithm::Attribute>(i.attribute);
bool contain_attr_all_positive =
(attr_from_strategy.first ==
(attr_of_algo & attr_from_strategy.first));
(target_attr.first ==
(attr_of_algo & target_attr.first));
bool contain_attr_any_negative =
static_cast<bool>(attr_of_algo & attr_from_strategy.second);
static_cast<bool>(attr_of_algo & target_attr.second);
if (contain_attr_all_positive && !contain_attr_any_negative) {
auto iter = algo_map.find(i.algo);
mgb_assert(iter != algo_map.end(),
@@ -586,8 +570,8 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
mgb_log_error(
"algos read from cache could not satisfy attribute with %s and "
"without %s",
Algorithm::attribute_str(attr_from_strategy.first).c_str(),
Algorithm::attribute_str(attr_from_strategy.second).c_str());
Algorithm::attribute_str(target_attr.first).c_str(),
Algorithm::attribute_str(target_attr.second).c_str());

mgb_trap();
MIDOUT_E
@@ -606,8 +590,7 @@ AlgoChooser<Opr>::ExeContext::choose_by_heuristic(
}
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
auto attr =
extract_algo_attribute_from_execution_strategy(selected_strategy);
auto attr = extract_algo_attribute(selected_strategy);
ImplExecutionPolicy policy;
policy.algo =
APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
@@ -668,9 +651,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
if (retrive_from_cache) {
policy.algo = get_profile_result_from_cache(selected_strategy).desc;
if (!policy.algo.valid()) {
auto target_attr =
extract_algo_attribute_from_execution_strategy(
selected_strategy);
auto target_attr = extract_algo_attribute(selected_strategy);
std::string layouts_str =
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out);
std::string msg = ssprintf(
@@ -692,8 +673,7 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy(
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);

auto attr = extract_algo_attribute_from_execution_strategy(
selected_strategy);
auto attr = extract_algo_attribute(selected_strategy);
policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic(
args..., workspace_limit, attr.first,
attr.second),
@@ -837,6 +817,24 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
return result;
}

template <typename Opr>
std::pair<AlgoAttribute, AlgoAttribute>
AlgoChooser<Opr>::ExeContext::extract_algo_attribute(
const ExecutionStrategy& strategy) const {
std::pair<AlgoAttribute, AlgoAttribute> ret =
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT);

//! from strategy
if (strategy & ExecutionStrategy::REPRODUCIBLE) {
ret.first |= AlgoAttribute::REPRODUCIBLE;
}
if (strategy & ExecutionStrategy::OPTMIZED) {
ret.second |= AlgoAttribute::NAIVE;
}

return ret;
}

#define INST(Opr) \
template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \
@@ -865,7 +863,10 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const {
AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \
policy, \
double& timeout) const;
double& timeout) const; \
template std::pair<AlgoAttribute, AlgoAttribute> \
AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \
const ExecutionStrategy& strategy) const;

MGB_FOREACH_FASTRUN_OPR(INST)



+ 10
- 0
src/opr/include/megbrain/opr/search_policy/algo_chooser.h View File

@@ -149,6 +149,16 @@ public:
ImplExecutionPolicy& policy,
bool retrive_from_cache = true) const;

/**
* \brief extract algo attribute from execution strategy and graph
* option.
*
* \param strategy select algo which matched this strategy
* \return pair<positive_attr, negative_attr>
*/
std::pair<AlgoAttribute, AlgoAttribute> extract_algo_attribute(
const ExecutionStrategy& strategy) const;

private:
Maybe<PreprocessFilter<Opr>> construct_fake_preprocess_filter() const;
};


Loading…
Cancel
Save