
feat(opr): add bn backward for inference mode

GitOrigin-RevId: bb643cb62f
release-1.5
Megvii Engine Team (huangxinda)
parent commit e6caa9ff89
3 changed files with 57 additions and 35 deletions
  1. imperative/python/megengine/module/batchnorm.py (+0, -10)
  2. imperative/python/test/integration/test_bn.py (+29, -17)
  3. src/opr/impl/dnn/batch_norm.cpp (+28, -8)

imperative/python/megengine/module/batchnorm.py (+0, -10)

@@ -100,16 +100,6 @@ class _BatchNorm(Module):
             if _bias is not None:
                 _bias = _bias.detach()
 
-            # Need to expand to elementwise operations here
-            # see MGB_IMPL_OPR_GRAD(BatchNormForward) in src/opr/impl/dnn/batch_norm.cpp
-            scale = (self.running_var + self.eps) ** (-0.5)
-            if _weight is not None:
-                scale *= _weight
-            bias = -self.running_mean * scale
-            if _bias is not None:
-                bias += _bias
-            return inp * scale + bias
-
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
         else:
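
The block removed above was a Python-side fast path that expanded frozen BN into elementwise operations, because the graph-level gradient rule only supported training mode (see the removed assert in batch_norm.cpp below). With the inference-mode backward added in that file, the module can call the regular batch-norm operator directly. As a minimal sketch of what the removed expansion computed (plain NumPy with standalone names, an illustration rather than the module code):

import numpy as np

def frozen_bn_elementwise(inp, running_mean, running_var, weight=None, bias=None, eps=1e-5):
    # Same arithmetic as the removed fast path: fold the frozen statistics
    # (and optional affine parameters) into a per-channel scale and shift.
    scale = (running_var + eps) ** (-0.5)
    if weight is not None:
        scale = scale * weight
    shift = -running_mean * scale
    if bias is not None:
        shift = shift + bias
    return inp * scale + shift

# Equivalent to normalizing directly with the running statistics.
x = np.random.rand(2, 3, 4, 4).astype("float32")
mean = np.zeros((1, 3, 1, 1), "float32")
var = np.full((1, 3, 1, 1), 4.0, "float32")
np.testing.assert_allclose(
    frozen_bn_elementwise(x, mean, var), (x - mean) / np.sqrt(var + 1e-5), rtol=1e-5
)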


imperative/python/test/integration/test_bn.py (+29, -17)

@@ -19,9 +19,13 @@ from megengine.jit import trace
 from megengine.module import BatchNorm2d, Conv2d, Module, Sequential, SyncBatchNorm
 
 
-def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
+def run_frozen_bn(BNModule, is_training, use_trace, use_symbolic):
     nchannel = 3
     m = BNModule(nchannel, freeze=True)
+    if is_training:
+        m.train()
+    else:
+        m.eval()
     var = 4.0
     bias = 1.0
     shape = (1, nchannel, 1, 1)
@@ -51,30 +55,33 @@ def run_frozen_bn(BNModule, use_trace=False, use_symbolic=False):
         train_fn = trace(train_fn, symbolic=use_symbolic)
 
     for _ in range(3):
-        loss = train_fn(megengine.Tensor(data))
-        np.testing.assert_equal(m.running_var.numpy(), saved_var)
-        np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+        loss = train_fn(megengine.tensor(data))
+        if not is_training:
+            np.testing.assert_equal(m.running_var.numpy(), saved_var)
+            np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
+            np.testing.assert_almost_equal(
+                loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
+            )
         np.testing.assert_equal(m.weight.numpy(), saved_wt)
         np.testing.assert_equal(m.bias.numpy(), saved_bias)
-        np.testing.assert_almost_equal(
-            loss.numpy(), ((data - bias) / np.sqrt(var)).mean(), 5
-        )
 
 
-def test_frozen_bn():
-    run_frozen_bn(BatchNorm2d)
-    run_frozen_bn(BatchNorm2d, True, False)
-    run_frozen_bn(BatchNorm2d, True, True)
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_bn(is_training, use_trace, use_symbolic):
+    run_frozen_bn(BatchNorm2d, is_training, use_trace, use_symbolic)
 
 
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
-def test_frozen_synced_bn():
+@pytest.mark.parametrize("is_training", [False, True])
+@pytest.mark.parametrize("use_trace", [False, True])
+@pytest.mark.parametrize("use_symbolic", [False, True])
+def test_frozen_synced_bn(is_training, use_trace, use_symbolic):
     @dist.launcher(n_gpus=2)
     def worker():
-        run_frozen_bn(SyncBatchNorm)
-        run_frozen_bn(SyncBatchNorm, True, False)
-        run_frozen_bn(SyncBatchNorm, True, True)
+        run_frozen_bn(SyncBatchNorm, is_training, use_trace, use_symbolic)
 
     worker()
 
@@ -190,8 +197,13 @@ def test_trace_several_syncbn(trace_mode):
 
 
 # https://github.com/MegEngine/MegEngine/issues/145
-def test_frozen_bn_no_affine():
+@pytest.mark.parametrize("is_training", [False, True])
+def test_frozen_bn_no_affine(is_training):
     nchannel = 3
     m = BatchNorm2d(nchannel, freeze=True, affine=False)
-    data = tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
+    if is_training:
+        m.train()
+    else:
+        m.eval()
+    data = megengine.tensor(np.random.random((6, nchannel, 2, 2)).astype("float32"))
     m(data).numpy()
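
The loss-value assertion in run_frozen_bn is now applied only in inference mode, where frozen BN reduces to normalization with the saved statistics. A small NumPy sketch of why the expected value is ((data - bias) / np.sqrt(var)).mean(); the data generation and the assumption that the loss is the mean of the BN output are stand-ins for the test's actual setup, not a copy of it:

import numpy as np

nchannel, var, bias, eps = 3, 4.0, 1.0, 1e-5
# Hypothetical stand-in for the test data: samples with mean `bias` and variance `var`.
data = np.random.normal(bias, np.sqrt(var), (1, nchannel, 4, 4)).astype("float32")
# Frozen BN in inference mode with running_mean == bias and running_var == var:
out = (data - bias) / np.sqrt(var + eps)
loss = out.mean()
np.testing.assert_almost_equal(loss, ((data - bias) / np.sqrt(var)).mean(), 5)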

src/opr/impl/dnn/batch_norm.cpp (+28, -8)

@@ -12,6 +12,8 @@
 #include "megbrain/opr/dnn/batch_norm.h"
 #include "megbrain/opr/io.h"
 #include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/basic_arith.h"
+#include "megbrain/opr/tensor_manip.h"
 
 #include "../internal/megdnn_opr_wrapper.inl"

@@ -243,16 +245,34 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
 
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
-    mgb_assert(opr.param().fwd_mode == BatchNorm::Param::FwdMode::TRAINING,
-            "batch norm could only take grad in training mode");
     mgb_assert(wrt_idx < 5, "wrt_idx %zu is out of range", wrt_idx);
     VarNodeArray ret(opr.input().size(), nullptr);
-    SymbolVarArray grad = BatchNormBackward::make(
-            opr.input(0), out_grad[4],
-            opr.output(2), opr.output(3),
-            opr.input(1), opr.param());
-    for (size_t i = 0; i < 3; ++ i) {
-        ret[i] = grad[(i + 2) % 3].node();
+    SymbolVarArray grad;
+    switch (opr.param().fwd_mode) {
+        case BatchNorm::Param::FwdMode::TRAINING:
+            grad = BatchNormBackward::make(
+                    opr.input(0), out_grad[4],
+                    opr.output(2), opr.output(3),
+                    opr.input(1), opr.param());
+            for (size_t i = 0; i < 3; ++ i) {
+                ret[i] = grad[(i + 2) % 3].node();
+            }
+            return ret;
+        case BatchNorm::Param::FwdMode::INFERENCE:
+            auto sqrt_var = PowC::make((SymbolVar{opr.input(4)}
+                    + static_cast<dt_float32>(opr.param().epsilon)), 0.5, opr.config());
+            auto d_bn_scale_unreduced = SymbolVar{out_grad[4]} *
+                    (SymbolVar{opr.input(0)} - SymbolVar{opr.input(3)}) / sqrt_var;
+            auto d_bn_scale = Reduce::make(d_bn_scale_unreduced,
+                    Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(1)));
+            auto d_bn_bias = Reduce::make(out_grad[4],
+                    Reduce::Param::Mode::SUM, GetVarShape::make(opr.input(2)));
+            auto dx = SymbolVar{out_grad[4]} * SymbolVar{opr.input(1)} / sqrt_var;
+
+            ret[0] = dx.node();
+            ret[1] = d_bn_scale.node();
+            ret[2] = d_bn_bias.node();
+            return ret;
     }
     return ret;
 }
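
The new INFERENCE branch treats the saved mean and variance as constants, so only the input, scale and bias receive gradients (ret[3] and ret[4] stay nullptr), and it is assembled from elementwise and reduce graph operators, which is presumably why basic_arith.h and tensor_manip.h are now included. A NumPy sketch of the same formulas, assuming NCHW layout with (1, C, 1, 1) parameters; this illustrates the math, it is not the graph code:

import numpy as np

def bn_inference_grad(dy, x, scale, mean, variance, eps):
    # Gradients of y = scale * (x - mean) / sqrt(variance + eps) + bias,
    # with mean/variance frozen (inference mode).
    sqrt_var = np.sqrt(variance + eps)
    dx = dy * scale / sqrt_var                                                  # ret[0]
    d_scale = (dy * (x - mean) / sqrt_var).sum(axis=(0, 2, 3), keepdims=True)   # ret[1]
    d_bias = dy.sum(axis=(0, 2, 3), keepdims=True)                              # ret[2]
    return dx, d_scale, d_bias

N, C, H, W = 2, 3, 4, 4
x = np.random.rand(N, C, H, W).astype("float32")
dy = np.random.rand(N, C, H, W).astype("float32")
mean = np.random.rand(1, C, 1, 1).astype("float32")
var = np.random.rand(1, C, 1, 1).astype("float32") + 1.0
scale = np.random.rand(1, C, 1, 1).astype("float32")
dx, d_scale, d_bias = bn_inference_grad(dy, x, scale, mean, var, 1e-5)
# Consistency check: summing dx * (x - mean) per channel reproduces scale * d_scale.
np.testing.assert_allclose(
    (dx * (x - mean)).sum(axis=(0, 2, 3), keepdims=True),
    scale * d_scale, rtol=1e-4, atol=1e-4,
)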

