Browse Source

feat(gopt): merge consecutive dimshuffle and relayout to one relayout to optimize CD4 performace

GitOrigin-RevId: 16f22baa80
tags/v1.9.0
Megvii Engine Team 3 years ago
parent
commit
1fead9b6b0
2 changed files with 87 additions and 0 deletions
  1. +40
    -0
      src/gopt/impl/inference.cpp
  2. +47
    -0
      src/gopt/test/inference.cpp

+ 40
- 0
src/gopt/impl/inference.cpp View File

@@ -1000,6 +1000,46 @@ void ConvertFormatPass::apply(OptState& state) const {
};
state.graph().iter(on_opr);
rewriter.apply_inplace();

//! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) +
//! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4)
auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) {
auto opr_is_relayout = [](OperatorNodeBase* opr) {
return opr->try_cast_final<opr::RelayoutFormat>();
};
auto opr_is_dimshuffle = [](OperatorNodeBase* opr) {
return opr->try_cast_final<opr::Dimshuffle>();
};
auto match_pattern = [](const opr::Dimshuffle::Param& param,
const std::vector<int>&& patten) {
if (param.pattern_len == patten.size() && param.pattern[0] == patten[0] &&
param.pattern[1] == patten[1] && param.pattern[2] == patten[2] &&
param.pattern[3] == patten[3]) {
return true;
}
return false;
};
auto this_opr_is_relayout = opr_is_relayout(opr);
auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr);
if (this_opr_is_relayout) {
prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr());
}
if (this_opr_is_relayout && prev_opr_is_dimshuffle) {
if (this_opr_is_relayout->param().mode ==
megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I &&
match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) {
auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0));
auto new_param = megdnn::param::RelayoutFormat();
new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I;
auto new_opr = opr::RelayoutFormat::make(inp, new_param);
rewriter.replace_var(opr->output(0), new_opr.node(), nullptr);
}
} else {
rewriter.auto_replace_outputs(opr);
}
};
state.graph().iter(on_opr_merge);
rewriter.apply_inplace();
MIDOUT_E
}



+ 47
- 0
src/gopt/test/inference.cpp View File

@@ -1318,6 +1318,53 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise0) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;

HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};

auto host_x = gen({8, 8, 8, 8}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
auto d0 = opr::Dimshuffle::make(x, {0, 3, 1, 2});

auto a = mkvar("a", {1});
auto b = mkvar("b", {1});
auto y = d0 * a + b;

SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

ASSERT_EQ(
megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I,
find_opr<opr::RelayoutFormat>(y_opt).param().mode);

ASSERT_EQ(0, find_opr_num<opr::Dimshuffle>(y_opt));

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.MergeDimShuffleAndRelayoutFormat.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);

*host_x = *gen({8, 8, 16, 16}, cn);
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) {
// hwcd4 is only supported in naive handle
NaiveMegDNNHandleScope naive_megdnn_handle;


Loading…
Cancel
Save