You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dropout.py 992 B

123456789101112131415161718192021222324252627
  1. # -*- coding: utf-8 -*-
  2. from ..functional import dropout
  3. from .module import Module
  4. class Dropout(Module):
  5. r"""Randomly sets input elements to zeros with the probability :math:`drop\_prob` during training.
  6. Commonly used in large networks to prevent overfitting.
  7. Note that we perform dropout only during training, we also rescale(multiply) the output tensor
  8. by :math:`\frac{1}{1 - drop\_prob}`. During inference :class:`~.Dropout` is equal to :class:`~.module.identity.Identity`.
  9. Args:
  10. drop_prob: The probability to drop (set to zero) each single element
  11. """
  12. def __init__(self, drop_prob=0.0, **kwargs):
  13. super().__init__(**kwargs)
  14. self.drop_prob = drop_prob
  15. def forward(self, inputs):
  16. if self.training:
  17. return dropout(inputs, self.drop_prob, training=True)
  18. else:
  19. return inputs
  20. def _module_info_string(self) -> str:
  21. return "drop_prob={drop_prob}".format(drop_prob=self.drop_prob)