GitOrigin-RevId: 16f22baa80
tags/v1.9.0
@@ -1000,6 +1000,46 @@ void ConvertFormatPass::apply(OptState& state) const { | |||
}; | |||
state.graph().iter(on_opr); | |||
rewriter.apply_inplace(); | |||
//! start a second pass that merge consecutive dimshuffle(NHWC->NCHW) + | |||
//! relayout_format(NCHW->NHWCD4) to only one relayout_format(NHWC->NHWCD4) | |||
auto on_opr_merge = [&rewriter](OperatorNodeBase* opr) { | |||
auto opr_is_relayout = [](OperatorNodeBase* opr) { | |||
return opr->try_cast_final<opr::RelayoutFormat>(); | |||
}; | |||
auto opr_is_dimshuffle = [](OperatorNodeBase* opr) { | |||
return opr->try_cast_final<opr::Dimshuffle>(); | |||
}; | |||
auto match_pattern = [](const opr::Dimshuffle::Param& param, | |||
const std::vector<int>&& patten) { | |||
if (param.pattern_len == patten.size() && param.pattern[0] == patten[0] && | |||
param.pattern[1] == patten[1] && param.pattern[2] == patten[2] && | |||
param.pattern[3] == patten[3]) { | |||
return true; | |||
} | |||
return false; | |||
}; | |||
auto this_opr_is_relayout = opr_is_relayout(opr); | |||
auto prev_opr_is_dimshuffle = static_cast<opr::Dimshuffle*>(nullptr); | |||
if (this_opr_is_relayout) { | |||
prev_opr_is_dimshuffle = opr_is_dimshuffle(opr->input(0)->owner_opr()); | |||
} | |||
if (this_opr_is_relayout && prev_opr_is_dimshuffle) { | |||
if (this_opr_is_relayout->param().mode == | |||
megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I && | |||
match_pattern(prev_opr_is_dimshuffle->param(), {0, 3, 1, 2})) { | |||
auto inp = rewriter.get_var(prev_opr_is_dimshuffle->input(0)); | |||
auto new_param = megdnn::param::RelayoutFormat(); | |||
new_param.mode = megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I; | |||
auto new_opr = opr::RelayoutFormat::make(inp, new_param); | |||
rewriter.replace_var(opr->output(0), new_opr.node(), nullptr); | |||
} | |||
} else { | |||
rewriter.auto_replace_outputs(opr); | |||
} | |||
}; | |||
state.graph().iter(on_opr_merge); | |||
rewriter.apply_inplace(); | |||
MIDOUT_E | |||
} | |||
@@ -1318,6 +1318,53 @@ TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise0) { | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
} | |||
TEST(TestGoptInference, MergeDimShuffleAndRelayoutFormat) {
    // hwcd4 is only supported in naive handle
    NaiveMegDNNHandleScope naive_megdnn_handle;
    HostTensorGenerator<> gen;
    auto cn = CompNode::load("cpu0");
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto make_input = [&](const char* name, const TensorShape& shape) {
        return opr::Host2DeviceCopy::make(*graph, gen(shape, cn)).rename(name);
    };

    // NHWC input followed by a NHWC->NCHW dimshuffle and an elemwise chain
    auto host_x = gen({8, 8, 8, 8}, cn);
    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
    auto shuffled = opr::Dimshuffle::make(x, {0, 3, 1, 2});
    auto scale = make_input("a", {1});
    auto bias = make_input("b", {1});
    auto y = shuffled * scale + bias;

    SymbolVar y_opt;
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nhwcd4();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

    // the dimshuffle must have been folded into a single NHWC->NHWCD4I
    // relayout; no standalone Dimshuffle may remain
    ASSERT_EQ(
            megdnn::param::RelayoutFormat::Mode::NHWC_NHWCD4I,
            find_opr<opr::RelayoutFormat>(y_opt).param().mode);
    ASSERT_EQ(0, find_opr_num<opr::Dimshuffle>(y_opt));

    graph->compile({{y_opt, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.MergeDimShuffleAndRelayoutFormat.json"));

    // optimized graph must produce the same numbers as the original
    HostTensorND host_y_opt, host_y;
    auto func = graph->compile(
            {make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);

    // execute again with a different input shape
    *host_x = *gen({8, 8, 16, 16}, cn);
    func->execute();
    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNHWCD4Elemwise) { | |||
// hwcd4 is only supported in naive handle | |||
NaiveMegDNNHandleScope naive_megdnn_handle; | |||