Browse Source

feat(gopt): add layout transform and fuse in load_and_run

GitOrigin-RevId: 9faa7ef068
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
36f17dec64
4 changed files with 51 additions and 14 deletions
  1. +22
    -3
      sdk/load-and-run/src/mgblar.cpp
  2. +4
    -7
      src/core/impl/graph/cg_impl.cpp
  3. +12
    -3
      src/core/include/megbrain/graph/cg.h
  4. +13
    -1
      src/gopt/impl/framework.cpp

+ 22
- 3
sdk/load-and-run/src/mgblar.cpp View File

@@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) {
continue;
}
#endif
if (!strcmp(argv[i], "--enable-chwn4")) {
mgb_log_warn("enable chwn4 optimization");
graph_opt.graph_opt.enable_chwn4();

#define cb(_layout) \
if (!strcmp(argv[i], "--enable-" #_layout)) { \
mgb_log_warn("enable " #_layout " optimization"); \
graph_opt.graph_opt.enable_##_layout(); \
continue; \
}

cb(chwn4);
cb(nchw44);
cb(nchw88);
cb(nchw32);
cb(nhwcd4);
#undef cb
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
continue;
}
if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) {
mgb_log_warn("enable fuse_conv_bias_with_z optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_with_z();
continue;
}
#if MGB_ENABLE_JSON


+ 4
- 7
src/core/impl/graph/cg_impl.cpp View File

@@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
options().graph_opt.winograd_transform = false;
gopt::transform_vars_inplace_with_winograd(dest_vars);
}
if (options().graph_opt.transform_chwn4()) {
gopt::GraphOptimizer optimizer;
optimizer.apply_optimize_options(options().graph_opt);
options().graph_opt.layout_transform =
cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT;
optimizer.apply_inplace(dest_vars);
}

#if MGB_JIT
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
@@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
optimizer.apply_inplace(dest_vars);
}
#endif
gopt::GraphOptimizer optimizer;
optimizer.apply_optimize_options(options().graph_opt);
options().graph_opt.reset();
optimizer.apply_inplace(dest_vars);

const OprNodeArray* opr_seq = nullptr;
CompSeqExtraInfo extra_info;


+ 12
- 3
src/core/include/megbrain/graph/cg.h View File

@@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions {
bool f16_io_comp = false;
//! whether to enable conv bias nonlinearity fusion
bool fuse_conv_bias_nonlinearity = false;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NHWCD4, ///< compute using NHWCD4 tensor format
@@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions {
///< used for cuda
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false;

//! Reset every option handled by GraphOptimizer::apply_optimize_options
//! back to its default. Called right after the options are consumed
//! (see cg_impl.cpp: apply_optimize_options(...) then graph_opt.reset())
//! so the same passes are not applied a second time on a later compile.
void reset() {
f16_io_f32_comp = false;
f16_io_comp = false;
fuse_conv_bias_nonlinearity = false;
fuse_conv_bias_with_z = false;
//! DEFAULT means "no layout transform requested"
layout_transform = LayoutTransform::DEFAULT;
}

#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
@@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions {
#undef SET
#define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \
mgb_assert(layout_transform == LayoutTransform::DEFAULT); \
layout_transform = LayoutTransform::_trans_capital; \
return *this; \
} \


+ 13
- 1
src/gopt/impl/framework.cpp View File

@@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {

const GraphOptimizer& GraphOptimizer::apply_optimize_options(
const cg::GraphCommonOptimizeOptions& options) {
bool need_param_fuse = false;
if (options.f16_io_comp) {
add_pass(ConvertF32ToF16Pass::make(false));
need_param_fuse = true;
}
if (options.f16_io_f32_comp) {
add_pass(ConvertF32ToF16Pass::make(true));
need_param_fuse = true;
}
if (options.transform_nhwcd4()) {
add_pass(ConvertFormatPass::make_nhwcd4_converter());
add_pass<FuseConvBiasNonlinPass>();
need_param_fuse = true;
}
if (options.transform_nchw88()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
need_param_fuse = true;
}
if (options.transform_nchw44()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
need_param_fuse = true;
}
if (options.transform_nchw32()) {
add_pass<FuseConvBiasNonlinPass>();
@@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options(
add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
need_param_fuse = true;
}
if (options.transform_chwn4()) {
add_pass<FuseConvBiasNonlinPass>();
@@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options(
add_pass(EnableCHWN4Pass::make_chwn4_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
need_param_fuse = true;
}

if (options.fuse_conv_bias_nonlinearity) {
add_pass<FuseConvBiasNonlinPass>();
need_param_fuse = true;
}
if (options.fuse_conv_bias_with_z) {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
need_param_fuse = true;
}
if (need_param_fuse) {
add_pass<ParamFusePass>();
}
add_pass<ParamFusePass>();
return *this;
}



Loading…
Cancel
Save