Browse Source

fix(mgb/gopt): fix NormalizeArithChainPass to process sub-div chains

GitOrigin-RevId: d71debbfde
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
0f739c1118
2 changed files with 22 additions and 0 deletions
  1. +9
    -0
      src/gopt/impl/basic_arith/chain.cpp
  2. +13
    -0
      src/gopt/test/basic_arith.cpp

+ 9
- 0
src/gopt/impl/basic_arith/chain.cpp View File

@@ -322,6 +322,15 @@ NormalizeArithChainPass::Impl::AddTrait::extract_coeff(
return AbstractOpr::make_coeff(
i1.node(), i0v->get_cast<dt_max_float>());
}
if (mode == Mode::TRUE_DIV) {
SymbolVar i0 = opr->input(0), i1 = opr->input(1),
i1r = opr::powf(i1, -1);
auto i1rv = i1r.as_immutable_scalar_require_shape();
if (!i1rv.valid())
return None;
return AbstractOpr::make_coeff(
i0.node(), i1rv->get_cast<dt_max_float>());
}
return None;
}



+ 13
- 0
src/gopt/test/basic_arith.cpp View File

@@ -548,6 +548,19 @@ TEST(TestNormalizeArithChainPass, PowcCExpand2) {
graph->compile({make_callback_copy(grad, host_g)}));
}

TEST_PASS(NormalizeArithChainPass, SubDiv) {
auto x = mkvar("x"), y = mkvar("y"), z = mkvar("z"),
a0_ = x - y / 2.f,
a1 = x + (-0.5f) * y,
b0_ = x - ((y - (z / 5.f)) / 2.f),
b1 = x + (-0.5f) * y + 0.1f * z;

SymbolVar a0, b0;
unpack_vector(run_opt({a0_, b0_}), a0, b0);
EXPECT_EQ(a1, a0);
EXPECT_EQ(b1, b0);
}

TEST_PASS(ReorderArithChainPass, 0) {
auto chk = [this](SymbolVar inp, SymbolVar expect) {
check(expect, inp, gopt::ConstVarType::IMMUTABLE_AND_PARAM);


Loading…
Cancel
Save