From 74cbc10d82a86078ee54f2261bbc5472d43ec40f Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 9 Sep 2021 21:42:48 +0800
Subject: [PATCH] feat(opr): let batchnorm support empty IO

GitOrigin-RevId: 219411c80cd3031b10f9ca59af4988c9ebe0c35c
---
 .../python/test/unit/module/test_batchnorm.py | 28 +++++++++++++++++++++-
 src/opr/impl/dnn/batch_norm.cpp               | 24 +++++++++++++++++--
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/imperative/python/test/unit/module/test_batchnorm.py b/imperative/python/test/unit/module/test_batchnorm.py
index de659cc0..901b770d 100644
--- a/imperative/python/test/unit/module/test_batchnorm.py
+++ b/imperative/python/test/unit/module/test_batchnorm.py
@@ -14,7 +14,7 @@ import pytest
 
 import megengine as mge
 import megengine.distributed as dist
-from megengine import Tensor
+from megengine import Tensor, jit
 from megengine.autodiff.grad_manager import GradManager
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
@@ -368,3 +368,29 @@ def test_syncbn2d_grad():
 
     _assert_allclose(oup.numpy(), oup_expect.numpy())
     _assert_allclose(grad.numpy(), grad_expect.numpy())
+
+
+@pytest.mark.parametrize("dim", [1, 2])
+@pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_batchnorm_empty_tensor(dim, is_symbolic):
+    if dim == 1:
+        m = BatchNorm1d(4, affine=True)
+        inp = mge.tensor(np.random.randn(0, 4, 0).astype("float32"))
+    elif dim == 2:
+        m = BatchNorm2d(4, affine=True)
+        inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
+    else:
+        raise NotImplementedError
+
+    m.train()
+
+    def fn(inp):
+        return m(inp)
+
+    if is_symbolic is not None:
+        fn = jit.trace(symbolic=is_symbolic)(fn)
+    for _ in range(3):
+        out = fn(inp)
+        np.testing.assert_equal(out.numpy(), inp)
+        if is_symbolic is None:
+            break
diff --git a/src/opr/impl/dnn/batch_norm.cpp b/src/opr/impl/dnn/batch_norm.cpp
index b3e24199..56363f0f 100644
--- a/src/opr/impl/dnn/batch_norm.cpp
+++ b/src/opr/impl/dnn/batch_norm.cpp
@@ -62,6 +62,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
     }
 
     init_megdnn_opr(*this, param);
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 
     add_input({x, scale, bias, mean, variance});
 
@@ -91,6 +92,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
         {x, scale, bias}}
 {
     init_megdnn_opr(*this, param);
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     add_input({x, scale, bias});
 
     auto mark_empty_var = [&](VarNode *var) {
@@ -139,6 +141,8 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
 cg::OperatorNodeBase::NodeProp*
 BatchNormForward::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
     if (need_stats() && m_force_inplace) {
         ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
     }
@@ -148,8 +152,6 @@ BatchNormForward::do_make_node_prop() const {
 void BatchNormForward::scn_do_execute() {
     auto &&x = input(0)->dev_tensor();
     auto &&y = output(4)->dev_tensor();
-    mgb_assert(x.layout().is_contiguous() &&
-               y.layout().is_contiguous());
     if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
@@ -172,6 +174,12 @@ void BatchNormForward::scn_do_execute() {
                 && o1.raw_ptr() == i1.raw_ptr());
         }
     }
+    mgb_assert(x.layout().eq_layout(y.layout()));
+    if (x.layout().is_empty()) {
+        return;
+    }
+    mgb_assert(x.layout().is_contiguous() &&
+               y.layout().is_contiguous());
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
     megdnn::TensorND mean, variance;
@@ -196,6 +204,18 @@ void BatchNormForward::add_input_layout_constraint() {
 
 void BatchNormForward::get_output_var_shape(
         const TensorShapeArray &inp_shape, TensorShapeArray &out_shape) const {
+    mgb_assert(inp_shape[0].ndim == 4 && inp_shape[1].ndim == 4 && inp_shape[2].ndim == 4,
+            "expect input, scale and bias to be 4 dim tensor, but "
+            "got input dim: %zu, scale dim: %zu, bias dim: %zu",
+            inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim);
+
+    size_t inp_c = inp_shape[0][1],
+           scale_c = inp_shape[1][1],
+           bias_c = inp_shape[2][1];
+    mgb_assert(inp_c == scale_c && inp_c == bias_c,
+            "inconsistent channel size, input channel: %zu, scale channel: %zu, bias channel: %zu",
+            inp_c, scale_c, bias_c);
+
     out_shape[4] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];