
feat(opr): let batchnorm support empty IO

GitOrigin-RevId: 219411c80c
Branch: release-1.6
Author: Megvii Engine Team
Commit: 74cbc10d82
2 changed files with 49 additions and 3 deletions:
  1. imperative/python/test/unit/module/test_batchnorm.py (+27, -1)
  2. src/opr/impl/dnn/batch_norm.cpp (+22, -2)
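In short, this change lets BatchNorm accept zero-size inputs and produce zero-size outputs instead of failing a shape assertion. A minimal eager-mode sketch of the behavior being enabled, distilled from the new test below (assumes a MegEngine build that includes this commit):

```python
import numpy as np
import megengine as mge
from megengine.module import BatchNorm2d

m = BatchNorm2d(4, affine=True)
m.train()

# A zero-size NCHW tensor: batch and spatial dims are empty, channel dim is 4.
inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
out = m(inp)

# With no elements to normalize, the output is an empty tensor of the
# same shape as the input.
np.testing.assert_equal(out.numpy(), inp.numpy())
```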

imperative/python/test/unit/module/test_batchnorm.py (+27, -1)

@@ -14,7 +14,7 @@ import pytest

 import megengine as mge
 import megengine.distributed as dist
-from megengine import Tensor
+from megengine import Tensor, jit
 from megengine.autodiff.grad_manager import GradManager
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
@@ -368,3 +368,29 @@ def test_syncbn2d_grad():

     _assert_allclose(oup.numpy(), oup_expect.numpy())
     _assert_allclose(grad.numpy(), grad_expect.numpy())
+
+
+@pytest.mark.parametrize("dim", [1, 2])
+@pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_batchnorm_empty_tensor(dim, is_symbolic):
+    if dim == 1:
+        m = BatchNorm1d(4, affine=True)
+        inp = mge.tensor(np.random.randn(0, 4, 0).astype("float32"))
+    elif dim == 2:
+        m = BatchNorm2d(4, affine=True)
+        inp = mge.tensor(np.random.randn(0, 4, 0, 0).astype("float32"))
+    else:
+        raise NotImplementedError
+
+    m.train()
+
+    def fn(inp):
+        return m(inp)
+
+    if is_symbolic is not None:
+        fn = jit.trace(symbolic=is_symbolic)(fn)
+    for _ in range(3):
+        out = fn(inp)
+        np.testing.assert_equal(out.numpy(), inp)
+        if is_symbolic is None:
+            break
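Note that `np.testing.assert_equal(out.numpy(), inp)` is effectively a shape check here: two empty arrays of the same shape compare equal vacuously, because there is no element left to differ. A quick NumPy illustration:

```python
import numpy as np

a = np.random.randn(0, 4, 0, 0).astype("float32")
b = np.empty_like(a)  # uninitialized storage, but zero elements

# Passes: equality over an empty index set holds trivially, so only
# the shape actually constrains the comparison.
np.testing.assert_equal(a, b)
```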

src/opr/impl/dnn/batch_norm.cpp (+22, -2)

@@ -62,6 +62,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
     }

     init_megdnn_opr(*this, param);
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);

     add_input({x, scale, bias, mean, variance});


@@ -91,6 +92,7 @@ BatchNormForward::BatchNormForward(VarNode *x,
           {x, scale, bias}}
 {
     init_megdnn_opr(*this, param);
+    output(4)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);

     add_input({x, scale, bias});
     auto mark_empty_var = [&](VarNode *var) {
@@ -139,6 +141,8 @@ SymbolVarArray BatchNormForward::make(SymbolVar x,
 cg::OperatorNodeBase::NodeProp*
 BatchNormForward::do_make_node_prop() const {
     auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
     if (need_stats() && m_force_inplace) {
         ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
     }
@@ -148,8 +152,6 @@ BatchNormForward::do_make_node_prop() const {
 void BatchNormForward::scn_do_execute() {
     auto &&x = input(0)->dev_tensor();
     auto &&y = output(4)->dev_tensor();
-    mgb_assert(x.layout().is_contiguous() &&
-               y.layout().is_contiguous());
     if (need_stats()) {
         auto &&o0 = output(0)->dev_tensor(),
              &&o1 = output(1)->dev_tensor(),
@@ -172,6 +174,12 @@ void BatchNormForward::scn_do_execute() {
                    && o1.raw_ptr() == i1.raw_ptr());
         }
     }
+    mgb_assert(x.layout().eq_layout(y.layout()));
+    if (x.layout().is_empty()) {
+        return;
+    }
+    mgb_assert(x.layout().is_contiguous() &&
+               y.layout().is_contiguous());
     auto scale = input(1)->dev_tensor().as_megdnn();
     auto bias = input(2)->dev_tensor().as_megdnn();
     megdnn::TensorND mean, variance;
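The inserted lines assert that `y` mirrors `x`'s layout and return early when the input is empty, so no megdnn kernel ever sees a zero-size buffer. An illustrative NumPy analogue of that guard (a hypothetical helper, not MegEngine code):

```python
import numpy as np

def batchnorm_forward(x, scale, bias, eps=1e-5):
    # Empty input short-circuits the computation: reducing over zero
    # elements would otherwise produce NaN means and variances.
    if x.size == 0:
        return np.empty_like(x)
    # Normal path: normalize per channel over the N, H, W axes.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + bias

x = np.random.randn(0, 4, 0, 0).astype("float32")
scale = np.ones((1, 4, 1, 1), "float32")
bias = np.zeros((1, 4, 1, 1), "float32")
assert batchnorm_forward(x, scale, bias).shape == x.shape
```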
@@ -196,6 +204,18 @@ void BatchNormForward::add_input_layout_constraint() {
 void BatchNormForward::get_output_var_shape(
         const TensorShapeArray &inp_shape,
         TensorShapeArray &out_shape) const {
+    mgb_assert(inp_shape[0].ndim == 4 && inp_shape[1].ndim == 4 && inp_shape[2].ndim == 4,
+               "expect input, scale and bias to be 4 dim tensor, but "
+               "got input dim: %zu, scale dim: %zu, bias dim: %zu",
+               inp_shape[0].ndim, inp_shape[1].ndim, inp_shape[2].ndim);
+
+    size_t inp_c = inp_shape[0][1],
+           scale_c = inp_shape[1][1],
+           bias_c = inp_shape[2][1];
+    mgb_assert(inp_c == scale_c && inp_c == bias_c,
+               "inconsistent channel size, input channel: %zu, scale channel: %zu, bias channel: %zu",
+               inp_c, scale_c, bias_c);
+
     out_shape[4] = inp_shape[0];
     for (size_t i = 0; i < 4; ++ i) {
         out_shape[i] = inp_shape[1];
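The new assertions reject malformed inputs before shape inference runs; note that even an empty NCHW input still has a well-defined channel size to validate. A plain-Python mirror of the same checks (hypothetical, for illustration only):

```python
def check_bn_input_shapes(inp, scale, bias):
    # input, scale and bias must all be 4-dimensional (NCHW).
    assert len(inp) == 4 and len(scale) == 4 and len(bias) == 4, (
        f"expect input, scale and bias to be 4 dim tensor, but got input "
        f"dim: {len(inp)}, scale dim: {len(scale)}, bias dim: {len(bias)}"
    )
    # Channel sizes (axis 1) must agree across all three.
    assert inp[1] == scale[1] == bias[1], (
        f"inconsistent channel size, input channel: {inp[1]}, "
        f"scale channel: {scale[1]}, bias channel: {bias[1]}"
    )

# Example: an empty input with 4 channels passes the checks.
check_bn_input_shapes((0, 4, 0, 0), (1, 4, 1, 1), (1, 4, 1, 1))
```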

