GitOrigin-RevId: 9faa7ef068
tags/v0.5.0
@@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) { | |||
continue; | |||
} | |||
#endif | |||
if (!strcmp(argv[i], "--enable-chwn4")) { | |||
mgb_log_warn("enable chwn4 optimization"); | |||
graph_opt.graph_opt.enable_chwn4(); | |||
// Helper macro for the "--enable-<layout>" command line flags. For the given
// layout name it compares argv[i] against "--enable-<layout>", and on a match
// logs a warning, calls graph_opt.graph_opt.enable_<layout>() and `continue`s.
// It is instantiated below for chwn4, nchw44, nchw88, nchw32 and nhwcd4, then
// undefined again so the short name `cb` does not leak past this point.
#define cb(_layout) \ | |||
if (!strcmp(argv[i], "--enable-" #_layout)) { \ | |||
mgb_log_warn("enable " #_layout " optimization"); \ | |||
graph_opt.graph_opt.enable_##_layout(); \ | |||
continue; \ | |||
} | |||
cb(chwn4); | |||
cb(nchw44); | |||
cb(nchw88); | |||
cb(nchw32); | |||
cb(nhwcd4); | |||
#undef cb | |||
// "--enable-fuse-conv-bias-nonlinearity": log a warning, turn the option on
// via enable_fuse_conv_bias_nonlinearity(), then `continue`.
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { | |||
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); | |||
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); | |||
continue; | |||
} | |||
// "--enable-fuse-conv-bias-with-z": same handling, using
// enable_fuse_conv_bias_with_z().
if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) { | |||
mgb_log_warn("enable fuse_conv_bias_with_z optimization"); | |||
graph_opt.graph_opt.enable_fuse_conv_bias_with_z(); | |||
continue; | |||
} | |||
#if MGB_ENABLE_JSON | |||
@@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||
options().graph_opt.winograd_transform = false; | |||
gopt::transform_vars_inplace_with_winograd(dest_vars); | |||
} | |||
if (options().graph_opt.transform_chwn4()) { | |||
gopt::GraphOptimizer optimizer; | |||
optimizer.apply_optimize_options(options().graph_opt); | |||
options().graph_opt.layout_transform = | |||
cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT; | |||
optimizer.apply_inplace(dest_vars); | |||
} | |||
#if MGB_JIT | |||
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { | |||
@@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||
optimizer.apply_inplace(dest_vars); | |||
} | |||
#endif | |||
gopt::GraphOptimizer optimizer; | |||
optimizer.apply_optimize_options(options().graph_opt); | |||
options().graph_opt.reset(); | |||
optimizer.apply_inplace(dest_vars); | |||
const OprNodeArray* opr_seq = nullptr; | |||
CompSeqExtraInfo extra_info; | |||
@@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions { | |||
bool f16_io_comp = false; | |||
//! whether to enable conv bias nonlinearity fusion | |||
bool fuse_conv_bias_nonlinearity = false; | |||
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||
//! + z -> conv_bias(x, w, b, z) | |||
bool fuse_conv_bias_with_z = false; | |||
enum LayoutTransform : uint32_t { | |||
DEFAULT, | |||
NHWCD4, ///< compute using NHWCD4 tensor format | |||
@@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions { | |||
///< used for cuda | |||
}; | |||
LayoutTransform layout_transform = LayoutTransform::DEFAULT; | |||
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||
//! + z -> conv_bias(x, w, b, z) | |||
bool fuse_conv_bias_with_z = false; | |||
//! restore the options assigned below to their disabled state: the two f16
//! flags and the two conv-bias fusion flags become false, and
//! layout_transform goes back to LayoutTransform::DEFAULT
void reset() { | |||
f16_io_f32_comp = false; | |||
f16_io_comp = false; | |||
fuse_conv_bias_nonlinearity = false; | |||
fuse_conv_bias_with_z = false; | |||
layout_transform = LayoutTransform::DEFAULT; | |||
} | |||
#define SET(n) \ | |||
GraphCommonOptimizeOptions& enable_##n() { \ | |||
@@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions { | |||
#undef SET | |||
#define SET(_trans, _trans_capital) \ | |||
GraphCommonOptimizeOptions& enable_##_trans() { \ | |||
mgb_assert(layout_transform == LayoutTransform::DEFAULT); \ | |||
layout_transform = LayoutTransform::_trans_capital; \ | |||
return *this; \ | |||
} \ | |||
@@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { | |||
const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||
const cg::GraphCommonOptimizeOptions& options) { | |||
bool need_param_fuse = false; | |||
if (options.f16_io_comp) { | |||
add_pass(ConvertF32ToF16Pass::make(false)); | |||
need_param_fuse = true; | |||
} | |||
if (options.f16_io_f32_comp) { | |||
add_pass(ConvertF32ToF16Pass::make(true)); | |||
need_param_fuse = true; | |||
} | |||
if (options.transform_nhwcd4()) { | |||
add_pass(ConvertFormatPass::make_nhwcd4_converter()); | |||
add_pass<FuseConvBiasNonlinPass>(); | |||
need_param_fuse = true; | |||
} | |||
if (options.transform_nchw88()) { | |||
add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | |||
need_param_fuse = true; | |||
} | |||
if (options.transform_nchw44()) { | |||
add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | |||
need_param_fuse = true; | |||
} | |||
if (options.transform_nchw32()) { | |||
add_pass<FuseConvBiasNonlinPass>(); | |||
@@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||
add_pass(EnableTensorCorePass::make_tensorcore_converter()); | |||
add_pass<ShuffleShuffleRemovePass>(); | |||
add_pass<RemoveRedundantTypeCvtPass>(); | |||
need_param_fuse = true; | |||
} | |||
if (options.transform_chwn4()) { | |||
add_pass<FuseConvBiasNonlinPass>(); | |||
@@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||
add_pass(EnableCHWN4Pass::make_chwn4_converter()); | |||
add_pass<ShuffleShuffleRemovePass>(); | |||
add_pass<RemoveRedundantTypeCvtPass>(); | |||
need_param_fuse = true; | |||
} | |||
if (options.fuse_conv_bias_nonlinearity) { | |||
add_pass<FuseConvBiasNonlinPass>(); | |||
need_param_fuse = true; | |||
} | |||
if (options.fuse_conv_bias_with_z) { | |||
add_pass<FuseConvBiasNonlinPass>(); | |||
add_pass<FuseConvBiasZPass>(); | |||
need_param_fuse = true; | |||
} | |||
if (need_param_fuse) { | |||
add_pass<ParamFusePass>(); | |||
} | |||
add_pass<ParamFusePass>(); | |||
return *this; | |||
} | |||