Browse Source

fix(mgb/gopt): remove redundant reshape in nchw->nchw4 pass

GitOrigin-RevId: 0f5c7c3e48
release-0.5
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
90d5895ec3
2 changed files with 8 additions and 7 deletions
  1. +6
    -7
      src/gopt/impl/tensor_reformat.cpp
  2. +2
    -0
      src/gopt/test/inference.cpp

+ 6
- 7
src/gopt/impl/tensor_reformat.cpp View File

@@ -435,13 +435,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
return y1.node();
};
reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
@@ -455,7 +452,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] = [](VarNode* inp) -> VarNode* {
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
@@ -471,7 +469,8 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] = [](VarNode* inp) -> VarNode* {
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };


+ 2
- 0
src/gopt/test/inference.cpp View File

@@ -2450,6 +2450,8 @@ TEST(TestGoptInference, ConvertFormatNCHW4GPU) {

ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
find_opr<opr::ConvBias>(y_opt).param().format);
auto nr_reshape = find_opr_num<mgb::opr::Reshape>(y_opt);
ASSERT_EQ(2u, nr_reshape);

graph->compile({{y_opt, {}}})
->to_json()


Loading…
Cancel
Save