Browse Source

fix(mgb/gopt): fix convert batchnorm to elemwise pass issue

GitOrigin-RevId: eda7f1ab95
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
fade97d4ef
2 changed files with 3 additions and 2 deletions
  1. +2
    -1
      src/gopt/impl/inference.cpp
  2. +1
    -1
      src/gopt/test/inference.cpp

+ 2
- 1
src/gopt/impl/inference.cpp View File

@@ -1592,7 +1592,8 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
SymbolVar bias = {rewriter.get_var(bn->input(2))};
SymbolVar mean = {rewriter.get_var(bn->input(3))};
SymbolVar variance = {rewriter.get_var(bn->input(4))};
SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5});
SymbolVar invsqrt_variance = opr::PowC::make(variance
+ variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5});
auto res = scale * (x - mean) * invsqrt_variance + bias;
rewriter.replace_var(
opr->output(4), res.node(),


+ 1
- 1
src/gopt/test/inference.cpp View File

@@ -1404,7 +1404,7 @@ TEST(TestGoptInference, ConvertBatchNormPass) {
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-2);
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}

TEST(TestGoptInference, ConvBiasNonlinearityFusePass) {


Loading…
Cancel
Save