Browse Source

fix(mgb/gopt): add error message when input dtype is not equal to param dtype in BN2Elemwise pass

GitOrigin-RevId: 3d09a2a12e
release-1.6
Megvii Engine Team 3 years ago
parent
commit
0ad377c7cf
1 changed files with 7 additions and 0 deletions
  1. +7
    -0
      src/gopt/impl/inference.cpp

+ 7
- 0
src/gopt/impl/inference.cpp View File

@@ -1686,6 +1686,13 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
SymbolVar invsqrt_variance = opr::PowC::make(variance
+ variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5});
auto res = scale * (x - mean) * invsqrt_variance + bias;
if (x.dtype() != res.dtype()) {
mgb_throw(MegBrainError,
"BN's input dtype %s is not compatible with "
"param dtype %s when fusing BN. You may need to "
"dump FP32 model.",
x.dtype().name(), res.dtype().name());
}
rewriter.replace_var(
opr->output(4), res.node(),
mgb_cstr_log(


Loading…
Cancel
Save