GitOrigin-RevId: 9faa7ef068
tags/v0.5.0
@@ -972,9 +972,28 @@ Args Args::from_argv(int argc, char **argv) { | |||||
continue; | continue; | ||||
} | } | ||||
#endif | #endif | ||||
if (!strcmp(argv[i], "--enable-chwn4")) { | |||||
mgb_log_warn("enable chwn4 optimization"); | |||||
graph_opt.graph_opt.enable_chwn4(); | |||||
#define cb(_layout) \ | |||||
if (!strcmp(argv[i], "--enable-" #_layout)) { \ | |||||
mgb_log_warn("enable " #_layout " optimization"); \ | |||||
graph_opt.graph_opt.enable_##_layout(); \ | |||||
continue; \ | |||||
} | |||||
cb(chwn4); | |||||
cb(nchw44); | |||||
cb(nchw88); | |||||
cb(nchw32); | |||||
cb(nhwcd4); | |||||
#undef cb | |||||
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { | |||||
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); | |||||
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); | |||||
continue; | |||||
} | |||||
if (!strcmp(argv[i], "--enable-fuse-conv-bias-with-z")) { | |||||
mgb_log_warn("enable fuse_conv_bias_with_z optimization"); | |||||
graph_opt.graph_opt.enable_fuse_conv_bias_with_z(); | |||||
continue; | continue; | ||||
} | } | ||||
#if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
@@ -462,13 +462,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
options().graph_opt.winograd_transform = false; | options().graph_opt.winograd_transform = false; | ||||
gopt::transform_vars_inplace_with_winograd(dest_vars); | gopt::transform_vars_inplace_with_winograd(dest_vars); | ||||
} | } | ||||
if (options().graph_opt.transform_chwn4()) { | |||||
gopt::GraphOptimizer optimizer; | |||||
optimizer.apply_optimize_options(options().graph_opt); | |||||
options().graph_opt.layout_transform = | |||||
cg::GraphCommonOptimizeOptions::LayoutTransform::DEFAULT; | |||||
optimizer.apply_inplace(dest_vars); | |||||
} | |||||
#if MGB_JIT | #if MGB_JIT | ||||
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { | if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { | ||||
@@ -480,6 +473,10 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare( | |||||
optimizer.apply_inplace(dest_vars); | optimizer.apply_inplace(dest_vars); | ||||
} | } | ||||
#endif | #endif | ||||
gopt::GraphOptimizer optimizer; | |||||
optimizer.apply_optimize_options(options().graph_opt); | |||||
options().graph_opt.reset(); | |||||
optimizer.apply_inplace(dest_vars); | |||||
const OprNodeArray* opr_seq = nullptr; | const OprNodeArray* opr_seq = nullptr; | ||||
CompSeqExtraInfo extra_info; | CompSeqExtraInfo extra_info; | ||||
@@ -92,6 +92,9 @@ struct GraphCommonOptimizeOptions { | |||||
bool f16_io_comp = false; | bool f16_io_comp = false; | ||||
//! whether to enable conv bias nonlinearity fusion | //! whether to enable conv bias nonlinearity fusion | ||||
bool fuse_conv_bias_nonlinearity = false; | bool fuse_conv_bias_nonlinearity = false; | ||||
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||||
//! + z -> conv_bias(x, w, b, z) | |||||
bool fuse_conv_bias_with_z = false; | |||||
enum LayoutTransform : uint32_t { | enum LayoutTransform : uint32_t { | ||||
DEFAULT, | DEFAULT, | ||||
NHWCD4, ///< compute using NHWCD4 tensor format | NHWCD4, ///< compute using NHWCD4 tensor format | ||||
@@ -103,9 +106,14 @@ struct GraphCommonOptimizeOptions { | |||||
///< used for cuda | ///< used for cuda | ||||
}; | }; | ||||
LayoutTransform layout_transform = LayoutTransform::DEFAULT; | LayoutTransform layout_transform = LayoutTransform::DEFAULT; | ||||
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) | |||||
//! + z -> conv_bias(x, w, b, z) | |||||
bool fuse_conv_bias_with_z = false; | |||||
void reset() { | |||||
f16_io_f32_comp = false; | |||||
f16_io_comp = false; | |||||
fuse_conv_bias_nonlinearity = false; | |||||
fuse_conv_bias_with_z = false; | |||||
layout_transform = LayoutTransform::DEFAULT; | |||||
} | |||||
#define SET(n) \ | #define SET(n) \ | ||||
GraphCommonOptimizeOptions& enable_##n() { \ | GraphCommonOptimizeOptions& enable_##n() { \ | ||||
@@ -119,6 +127,7 @@ struct GraphCommonOptimizeOptions { | |||||
#undef SET | #undef SET | ||||
#define SET(_trans, _trans_capital) \ | #define SET(_trans, _trans_capital) \ | ||||
GraphCommonOptimizeOptions& enable_##_trans() { \ | GraphCommonOptimizeOptions& enable_##_trans() { \ | ||||
mgb_assert(layout_transform == LayoutTransform::DEFAULT); \ | |||||
layout_transform = LayoutTransform::_trans_capital; \ | layout_transform = LayoutTransform::_trans_capital; \ | ||||
return *this; \ | return *this; \ | ||||
} \ | } \ | ||||
@@ -706,21 +706,27 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { | |||||
const GraphOptimizer& GraphOptimizer::apply_optimize_options( | const GraphOptimizer& GraphOptimizer::apply_optimize_options( | ||||
const cg::GraphCommonOptimizeOptions& options) { | const cg::GraphCommonOptimizeOptions& options) { | ||||
bool need_param_fuse = false; | |||||
if (options.f16_io_comp) { | if (options.f16_io_comp) { | ||||
add_pass(ConvertF32ToF16Pass::make(false)); | add_pass(ConvertF32ToF16Pass::make(false)); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.f16_io_f32_comp) { | if (options.f16_io_f32_comp) { | ||||
add_pass(ConvertF32ToF16Pass::make(true)); | add_pass(ConvertF32ToF16Pass::make(true)); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.transform_nhwcd4()) { | if (options.transform_nhwcd4()) { | ||||
add_pass(ConvertFormatPass::make_nhwcd4_converter()); | add_pass(ConvertFormatPass::make_nhwcd4_converter()); | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.transform_nchw88()) { | if (options.transform_nchw88()) { | ||||
add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.transform_nchw44()) { | if (options.transform_nchw44()) { | ||||
add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.transform_nchw32()) { | if (options.transform_nchw32()) { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
@@ -728,6 +734,7 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||||
add_pass(EnableTensorCorePass::make_tensorcore_converter()); | add_pass(EnableTensorCorePass::make_tensorcore_converter()); | ||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.transform_chwn4()) { | if (options.transform_chwn4()) { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
@@ -735,16 +742,21 @@ const GraphOptimizer& GraphOptimizer::apply_optimize_options( | |||||
add_pass(EnableCHWN4Pass::make_chwn4_converter()); | add_pass(EnableCHWN4Pass::make_chwn4_converter()); | ||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass<RemoveRedundantTypeCvtPass>(); | add_pass<RemoveRedundantTypeCvtPass>(); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.fuse_conv_bias_nonlinearity) { | if (options.fuse_conv_bias_nonlinearity) { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
need_param_fuse = true; | |||||
} | } | ||||
if (options.fuse_conv_bias_with_z) { | if (options.fuse_conv_bias_with_z) { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
need_param_fuse = true; | |||||
} | |||||
if (need_param_fuse) { | |||||
add_pass<ParamFusePass>(); | |||||
} | } | ||||
add_pass<ParamFusePass>(); | |||||
return *this; | return *this; | ||||
} | } | ||||