|
|
@@ -675,7 +675,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes( |
|
|
|
|
|
|
|
if (inference_opt) { |
|
|
|
add_pass<ParamFusePass>(); |
|
|
|
apply_optimize_options(*inference_opt); |
|
|
|
add_passes_for_optimize_options(*inference_opt); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@@ -704,56 +704,56 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
const GraphOptimizer& GraphOptimizer::apply_optimize_options( |
|
|
|
|
|
|
|
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( |
|
|
|
const cg::GraphCommonOptimizeOptions& options) { |
|
|
|
return add_passes_for_optimize_options( |
|
|
|
const_cast<cg::GraphCommonOptimizeOptions&>(options)); |
|
|
|
} |
|
|
|
|
|
|
|
const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( |
|
|
|
cg::GraphCommonOptimizeOptions& options, bool reset) { |
|
|
|
bool need_param_fuse = false; |
|
|
|
if (options.f16_io_comp) { |
|
|
|
add_pass(ConvertF32ToF16Pass::make(false)); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
if (options.f16_io_f32_comp) { |
|
|
|
add_pass(ConvertF32ToF16Pass::make(true)); |
|
|
|
need_param_fuse = true; |
|
|
|
|
|
|
|
#define cb(_option, _passes) \ |
|
|
|
if (options.has_set_##_option()) { \ |
|
|
|
_passes need_param_fuse = true; \ |
|
|
|
if (reset) { \ |
|
|
|
options.disable_##_option(); \ |
|
|
|
} \ |
|
|
|
} |
|
|
|
if (options.transform_nhwcd4()) { |
|
|
|
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); |
|
|
|
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); |
|
|
|
|
|
|
|
cb(nhwcd4, { |
|
|
|
add_pass<FuseConvBiasNonlinPass>(); |
|
|
|
add_pass(ConvertFormatPass::make_nhwcd4_converter()); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
if (options.transform_nchw88()) { |
|
|
|
add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
if (options.transform_nchw44()) { |
|
|
|
add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
if (options.transform_nchw32()) { |
|
|
|
}); |
|
|
|
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); |
|
|
|
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); }); |
|
|
|
cb(nchw32, { |
|
|
|
add_pass<FuseConvBiasNonlinPass>(); |
|
|
|
add_pass<FuseConvBiasZPass>(); |
|
|
|
add_pass(EnableTensorCorePass::make_tensorcore_converter()); |
|
|
|
add_pass<ShuffleShuffleRemovePass>(); |
|
|
|
add_pass<RemoveRedundantTypeCvtPass>(); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
if (options.transform_chwn4()) { |
|
|
|
}); |
|
|
|
cb(chwn4, { |
|
|
|
add_pass<FuseConvBiasNonlinPass>(); |
|
|
|
add_pass<FuseConvBiasZPass>(); |
|
|
|
add_pass(EnableCHWN4Pass::make_chwn4_converter()); |
|
|
|
add_pass<ShuffleShuffleRemovePass>(); |
|
|
|
add_pass<RemoveRedundantTypeCvtPass>(); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
|
if (options.fuse_conv_bias_nonlinearity) { |
|
|
|
add_pass<FuseConvBiasNonlinPass>(); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
if (options.fuse_conv_bias_with_z) { |
|
|
|
cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); }); |
|
|
|
cb(fuse_conv_bias_with_z, { |
|
|
|
add_pass<FuseConvBiasNonlinPass>(); |
|
|
|
add_pass<FuseConvBiasZPass>(); |
|
|
|
need_param_fuse = true; |
|
|
|
} |
|
|
|
}); |
|
|
|
|
|
|
|
#undef cb |
|
|
|
|
|
|
|
if (need_param_fuse) { |
|
|
|
add_pass<ParamFusePass>(); |
|
|
|
} |
|
|
|