Browse Source

perf(functional/dropout): add fastpath for dropout

GitOrigin-RevId: 3bf8546908
tags/v1.6.0-rc1
Megvii Engine Team 3 years ago
parent
commit
77309609fa
1 changed files with 2 additions and 0 deletions
  1. +2
    -0
      imperative/python/megengine/functional/nn.py

+ 2
- 0
imperative/python/megengine/functional/nn.py View File

@@ -1304,6 +1304,8 @@ def dropout(inp: Tensor, drop_prob: float, training: bool = True) -> Tensor:


""" """
assert 0 <= drop_prob < 1 assert 0 <= drop_prob < 1
if drop_prob == 0:
return inp
rv = uniform(size=inp.shape) rv = uniform(size=inp.shape)
mask = rv > drop_prob mask = rv > drop_prob
inp *= mask.astype(inp.dtype) inp *= mask.astype(inp.dtype)


Loading…
Cancel
Save