|
|
@@ -1686,6 +1686,13 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { |
|
|
|
SymbolVar invsqrt_variance = opr::PowC::make(variance |
|
|
|
+ variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); |
|
|
|
auto res = scale * (x - mean) * invsqrt_variance + bias; |
|
|
|
if (x.dtype() != res.dtype()) { |
|
|
|
mgb_throw(MegBrainError, |
|
|
|
"BN's input dtype %s is not compatible with " |
|
|
|
"param dtype %s when fusing BN. You may need to " |
|
|
|
"dump FP32 model.", |
|
|
|
x.dtype().name(), res.dtype().name()); |
|
|
|
} |
|
|
|
rewriter.replace_var( |
|
|
|
opr->output(4), res.node(), |
|
|
|
mgb_cstr_log( |
|
|
|