Browse Source

feat(mgb/gopt): add AxisAddRemove opr support for cd4 opt pass

GitOrigin-RevId: 85218ee0c4
release-1.2
Megvii Engine Team 4 years ago
parent
commit
5f171298aa
2 changed files with 9 additions and 2 deletions
  1. +1
    -0
      src/gopt/impl/inference.cpp
  2. +8
    -2
      src/gopt/test/inference.cpp

+ 1
- 0
src/gopt/impl/inference.cpp View File

@@ -1635,6 +1635,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
replace_func[opr::Subtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::Broadcast::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::AxisAddRemove::typeinfo()] = relayout_inp_to_chw;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
replace_warp_perspective_opr;


+ 8
- 2
src/gopt/test/inference.cpp View File

@@ -1171,16 +1171,22 @@ TEST(TestGoptInference, ConvertFormatNHWCD4) {
opr::Elemwise::Param::Mode::RELU);
param.pad_h = param.pad_w = 1;
auto w2 = mkcvar("w2", {4, 4, 3, 3}),
y = opr::Convolution::make(elem, w2, param);
y = opr::Convolution::make(elem, w2, param),
z = opr::AxisAddRemove::make(y, {opr::AxisAddRemove::AxisDesc::make_add(0)});

SymbolVar y_opt;
SymbolVar y_opt, z_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nhwcd4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
unpack_vector(gopt::optimize_for_inference({z}, options), z_opt);

ASSERT_EQ(opr::Convolution::Param::Format::NHWCD4,
find_opr<opr::Convolution>(y_opt).param().format);

ASSERT_EQ(TensorFormat::Type::DEFAULT,
find_opr<opr::AxisAddRemove>(z_opt).input(0)->format().type());
ASSERT_EQ(4, find_opr<opr::AxisAddRemove>(z_opt).input(0)->shape().ndim);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(


Loading…
Cancel
Save