From 36f17dec64a8fcd04483324aba05e9f9407c03e0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 11 May 2020 13:13:00 +0800 Subject: [PATCH] feat(gopt): add layout transform and fuse in load_and_run GitOrigin-RevId: 9faa7ef06846fcf851e5b65b94cf5daca1de408b --- sdk/load-and-run/src/mgblar.cpp | 25 ++++++++++++++++++++++--- src/core/impl/graph/cg_impl.cpp | 11 ++++------- src/core/include/megbrain/graph/cg.h | 15 ++++++++++++--- src/gopt/impl/framework.cpp | 14 +++++++++++++- 4 files changed, 51 insertions(+), 14 deletions(-) diff --git a/sdk/load-and-run/src/mgblar.cpp b/sdk/load-and-run/src/mgblar.cpp index 7488a3bd..6789e7e7 100644 --- a/sdk/load-and-run/src/mgblar.cpp +++ b/sdk/load-and-run/src/mgblar.cpp @@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) { continue; } #endif - if (!strcmp(argv[i], "--enable-chwn4")) { - mgb_log_warn("enable chwn4 optimization"); - graph_opt.graph_opt.enable_chwn4(); + +#define cb(_layout) \ + if (!strcmp(argv[i], "--enable-" #_layout)) { \ + mgb_log_warn("enable " #_layout " optimization"); \ + graph_opt.graph_opt.enable_##_layout(); \ + continue; \ + } + + cb(chwn4); + cb(nchw44); + cb(nchw88); + cb(nchw32); + cb(nhwcd4); +#undef cb + if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { + mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); + graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); + continue; + } + if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) { + mgb_log_warn("enable fuse_conv_bias_with_z optimization"); + graph_opt.graph_opt.enable_fuse_conv_bias_with_z(); continue; } #if MGB_ENABLE_JSON diff --git a/src/core/impl/graph/cg_impl.cpp b/src/core/impl/graph/cg_impl.cpp index 53f31e6d..db738e1e 100644 --- a/src/core/impl/graph/cg_impl.cpp +++ b/src/core/impl/graph/cg_impl.cpp @@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( options().graph_opt.winograd_transform = false; gopt::transform_vars_inplace_with_winograd(dest_vars); } - if (options().graph_opt.transform_chwn4()) { - gopt::GraphOptimizer optimizer; - optimizer.apply_optimize_options(options().graph_opt); - options().graph_opt.layout_transform = - cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT; - optimizer.apply_inplace(dest_vars); - } #if MGB_JIT if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { @@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( optimizer.apply_inplace(dest_vars); } #endif + gopt::GraphOptimizer optimizer; + optimizer.apply_optimize_options(options().graph_opt); + options().graph_opt.reset(); + optimizer.apply_inplace(dest_vars); const OprNodeArray* opr_seq = nullptr; CompSeqExtraInfo extra_info; diff --git a/src/core/include/megbrain/graph/cg.h b/src/core/include/megbrain/graph/cg.h index d84cdc6c..229d101f 100644 --- a/src/core/include/megbrain/graph/cg.h +++ b/src/core/include/megbrain/graph/cg.h @@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions { bool f16_io_comp = false; //! whether to enable conv bias nonlinearity fusion bool fuse_conv_bias_nonlinearity = false; + //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) + //! + z -> conv_bias(x, w, b, z) + bool fuse_conv_bias_with_z = false; enum LayoutTransform : uint32_t { DEFAULT, NHWCD4, ///< compute using NHWCD4 tensor format @@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions { ///< used for cuda }; LayoutTransform layout_transform = LayoutTransform::DEFAULT; - //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) - //! + z -> conv_bias(x, w, b, z) - bool fuse_conv_bias_with_z = false; + + void reset() { + f16_io_f32_comp = false; + f16_io_comp = false; + fuse_conv_bias_nonlinearity = false; + fuse_conv_bias_with_z = false; + layout_transform = LayoutTransform::DEFAULT; + } #define SET(n) \ GraphCommonOptimizeOptions& enable_##n() { \ @@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions { #undef SET #define SET(_trans, _trans_capital) \ GraphCommonOptimizeOptions& enable_##_trans() { \ + mgb_assert(layout_transform == LayoutTransform::DEFAULT); \ layout_transform = LayoutTransform::_trans_capital; \ return *this; \ } \ diff --git a/src/gopt/impl/framework.cpp b/src/gopt/impl/framework.cpp index eb1e64a0..ce22d076 100644 --- a/src/gopt/impl/framework.cpp +++ b/src/gopt/impl/framework.cpp @@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { const GraphOptimizer& GraphOptimizer::apply_optimize_options( const cg::GraphCommonOptimizeOptions& options) { + bool need_param_fuse = false; if (options.f16_io_comp) { add_pass(ConvertF32ToF16Pass::make(false)); + need_param_fuse = true; } if (options.f16_io_f32_comp) { add_pass(ConvertF32ToF16Pass::make(true)); + need_param_fuse = true; } if (options.transform_nhwcd4()) { add_pass(ConvertFormatPass::make_nhwcd4_converter()); add_pass(); + need_param_fuse = true; } if (options.transform_nchw88()) { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); + need_param_fuse = true; } if (options.transform_nchw44()) { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); + need_param_fuse = true; } if (options.transform_nchw32()) { add_pass(); @@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( add_pass(EnableTensorCorePass::make_tensorcore_converter()); add_pass(); add_pass(); + need_param_fuse = true; } if (options.transform_chwn4()) { add_pass(); @@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( add_pass(EnableCHWN4Pass::make_chwn4_converter()); add_pass(); add_pass(); + need_param_fuse = true; } if (options.fuse_conv_bias_nonlinearity) { add_pass(); + need_param_fuse = true; } if (options.fuse_conv_bias_with_z) { add_pass(); add_pass(); + need_param_fuse = true; + } + if (need_param_fuse) { + add_pass(); } - add_pass(); return *this; }