|
|
@@ -1592,7 +1592,8 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const { |
|
|
|
SymbolVar bias = {rewriter.get_var(bn->input(2))}; |
|
|
|
SymbolVar mean = {rewriter.get_var(bn->input(3))}; |
|
|
|
SymbolVar variance = {rewriter.get_var(bn->input(4))}; |
|
|
|
SymbolVar invsqrt_variance = opr::PowC::make(variance, {-0.5}); |
|
|
|
SymbolVar invsqrt_variance = opr::PowC::make(variance |
|
|
|
+ variance.make_scalar_dt(float(bn->param().epsilon)), {-0.5}); |
|
|
|
auto res = scale * (x - mean) * invsqrt_variance + bias; |
|
|
|
rewriter.replace_var( |
|
|
|
opr->output(4), res.node(), |
|
|
|