Browse Source

fix(mge): fix F.nn.dropout train and inference bugs

GitOrigin-RevId: 9d9f246d7b
release-1.6
Megvii Engine Team 3 years ago
parent
commit
76ce81e828
2 changed files with 31 additions and 13 deletions
  1. +23
    -9
      imperative/python/megengine/functional/nn.py
  2. +8
    -4
      imperative/python/test/unit/functional/test_functional.py

+ 23
- 9
imperative/python/megengine/functional/nn.py View File

@@ -13,7 +13,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
from ..core._imperative_rt.core2 import apply, dtype_promotion
from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder
from ..core.ops import builtin
from ..core.ops.builtin import BatchNorm, Elemwise, GetVarShape, Reduce, TypeCvt
from ..core.ops.builtin import (
BatchNorm,
Elemwise,
GetVarShape,
Identity,
Reduce,
TypeCvt,
)
from ..core.ops.special import Const
from ..core.tensor import amp, megbrain_graph
from ..core.tensor.array_method import _elwise_apply
@@ -1403,9 +1410,14 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:
from megengine import tensor
import megengine.functional as F


x = tensor(np.ones(10, dtype=np.float32))
out = F.dropout(x, 1./3.)
print(out.numpy())
# test training mode
data = tensor(np.ones(10000000, dtype=np.float32))
out = F.nn.dropout(data, 1.0 / 3.0, training=True)
assert not out.numpy().all()

# test eval mode
out = F.nn.dropout(data, 1.0 / 3.0, training=False)
assert out.numpy().all()


Outputs:


@@ -1416,14 +1428,16 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:


"""
assert 0 <= drop_prob < 1
if drop_prob == 0:
if not training or drop_prob == 0:
return inp

# model in training mode, e.g. model.train()
rv = uniform(size=inp.shape)
mask = rv > drop_prob
inp *= mask.astype(inp.dtype)
if training:
inp *= 1 / (1 - drop_prob)
return inp
ret = inp * mask.astype(inp.dtype)
ret *= 1 / (1 - drop_prob)
return ret




def one_hot(inp: Tensor, num_classes: int) -> Tensor:


+ 8
- 4
imperative/python/test/unit/functional/test_functional.py View File

@@ -57,10 +57,14 @@ def test_where():




def test_dropout():
data = tensor(np.ones(10, dtype=np.float32))
out = F.dropout(data, 1.0 / 3.0, training=False)

assert out.numpy().sum() >= 0.0
# test training mode
data = tensor(np.ones(10000000, dtype=np.float32))
out = F.nn.dropout(data, 1.0 / 3.0, training=True)
assert not out.numpy().all()

# test eval mode
out = F.nn.dropout(data, 1.0 / 3.0, training=False)
assert out.numpy().all()




def test_matinv():


Loading…
Cancel
Save