Browse Source

refactor(gopt): refactor interface of add passes for common optimizations

GitOrigin-RevId: d0f3819c3a
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
6d6b42bb77
4 changed files with 58 additions and 50 deletions
  1. +1
    -2
      src/core/impl/graph/cg_impl.cpp
  2. +15
    -13
      src/core/include/megbrain/graph/cg.h
  3. +33
    -33
      src/gopt/impl/framework.cpp
  4. +9
    -2
      src/gopt/include/megbrain/gopt/framework.h

+ 1
- 2
src/core/impl/graph/cg_impl.cpp View File

@@ -474,8 +474,7 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
}
#endif
gopt::GraphOptimizer optimizer;
optimizer.apply_optimize_options(options().graph_opt);
options().graph_opt.reset();
optimizer.add_passes_for_optimize_options(options().graph_opt, true);
optimizer.apply_inplace(dest_vars);

const OprNodeArray* opr_seq = nullptr;


+ 15
- 13
src/core/include/megbrain/graph/cg.h View File

@@ -107,19 +107,17 @@ struct GraphCommonOptimizeOptions {
};
LayoutTransform layout_transform = LayoutTransform::DEFAULT;

void reset() {
f16_io_f32_comp = false;
f16_io_comp = false;
fuse_conv_bias_nonlinearity = false;
fuse_conv_bias_with_z = false;
layout_transform = LayoutTransform::DEFAULT;
}
#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
n = true; \
return *this; \
} \
GraphCommonOptimizeOptions& disable_##n() { \
n = false; \
return *this; \
} \
bool has_set_##n() { return n == true; }

#define SET(n) \
GraphCommonOptimizeOptions& enable_##n() { \
n = true; \
return *this; \
}
SET(f16_io_f32_comp);
SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity);
@@ -131,7 +129,11 @@ struct GraphCommonOptimizeOptions {
layout_transform = LayoutTransform::_trans_capital; \
return *this; \
} \
bool transform_##_trans() const { \
GraphCommonOptimizeOptions& disable_##_trans() { \
layout_transform = LayoutTransform::DEFAULT; \
return *this; \
} \
bool has_set_##_trans() const { \
return layout_transform == LayoutTransform::_trans_capital; \
}



+ 33
- 33
src/gopt/impl/framework.cpp View File

@@ -675,7 +675,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(

if (inference_opt) {
add_pass<ParamFusePass>();
apply_optimize_options(*inference_opt);
add_passes_for_optimize_options(*inference_opt);
}


@@ -704,56 +704,56 @@ VarNode* GraphOptimizer::var_replace_lookup(VarNode *var) {
}
}

const GraphOptimizer& GraphOptimizer::apply_optimize_options(

const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options) {
return add_passes_for_optimize_options(
const_cast<cg::GraphCommonOptimizeOptions&>(options));
}

const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
cg::GraphCommonOptimizeOptions& options, bool reset) {
bool need_param_fuse = false;
if (options.f16_io_comp) {
add_pass(ConvertF32ToF16Pass::make(false));
need_param_fuse = true;
}
if (options.f16_io_f32_comp) {
add_pass(ConvertF32ToF16Pass::make(true));
need_param_fuse = true;
#define cb(_option, _passes) \
if (options.has_set_##_option()) { \
_passes need_param_fuse = true; \
if (reset) { \
options.disable_##_option(); \
} \
}
if (options.transform_nhwcd4()) {
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); });
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); });

cb(nhwcd4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass(ConvertFormatPass::make_nhwcd4_converter());
need_param_fuse = true;
}
if (options.transform_nchw88()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(8));
need_param_fuse = true;
}
if (options.transform_nchw44()) {
add_pass(EnableNchwxxPass::make_nchwxx_converter(4));
need_param_fuse = true;
}
if (options.transform_nchw32()) {
});
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); });
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); });
cb(nchw32, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableTensorCorePass::make_tensorcore_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
need_param_fuse = true;
}
if (options.transform_chwn4()) {
});
cb(chwn4, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
add_pass(EnableCHWN4Pass::make_chwn4_converter());
add_pass<ShuffleShuffleRemovePass>();
add_pass<RemoveRedundantTypeCvtPass>();
need_param_fuse = true;
}
});

if (options.fuse_conv_bias_nonlinearity) {
add_pass<FuseConvBiasNonlinPass>();
need_param_fuse = true;
}
if (options.fuse_conv_bias_with_z) {
cb(fuse_conv_bias_nonlinearity, { add_pass<FuseConvBiasNonlinPass>(); });
cb(fuse_conv_bias_with_z, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<FuseConvBiasZPass>();
need_param_fuse = true;
}
});

#undef cb

if (need_param_fuse) {
add_pass<ParamFusePass>();
}


+ 9
- 2
src/gopt/include/megbrain/gopt/framework.h View File

@@ -468,9 +468,16 @@ namespace gopt {
static VarNode* var_replace_lookup(VarNode *var);

/**
* \brief apply optimize options
* \brief add pass indicated by optimize options.
*
* \param options common options
* \param reset if set true, it will reset options when add passes.
*/
const GraphOptimizer& apply_optimize_options(
const GraphOptimizer& add_passes_for_optimize_options(
cg::GraphCommonOptimizeOptions& options,
bool reset = false);

const GraphOptimizer& add_passes_for_optimize_options(
const cg::GraphCommonOptimizeOptions& options);
};



Loading…
Cancel
Save