
clip_grad.py 2.3 kB

# -*- coding: utf-8 -*-
# pylint: disable=redefined-builtin
from typing import Iterable, Union

from ..core._imperative_rt.core2 import pop_scope, push_scope
from ..functional import clip, concat, minimum, norm
from ..tensor import Tensor

__all__ = ["clip_grad_norm", "clip_grad_value"]


def clip_grad_norm(
    tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0,
):
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Args:
        tensors: an iterable of Tensors or a single Tensor.
        max_norm: max norm of the gradients.
        ord: type of the used p-norm. Can be ``'inf'`` for infinity norm.

    Returns:
        total norm of the parameters (viewed as a single vector).
    """
    push_scope("clip_grad_norm")
    if isinstance(tensors, Tensor):
        tensors = [tensors]
    tensors = [t for t in tensors if t.grad is not None]
    if len(tensors) == 0:
        pop_scope("clip_grad_norm")
        return Tensor(0.0)
    norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors]
    if len(norm_) > 1:
        norm_ = norm(concat(norm_), ord=ord)
    else:
        norm_ = norm_[0]
    scale = max_norm / (norm_ + 1e-6)
    scale = minimum(scale, 1)
    for tensor in tensors:
        tensor.grad._reset(tensor.grad * scale)
    pop_scope("clip_grad_norm")
    return norm_


def clip_grad_value(
    tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float
):
    r"""Clips gradients of an iterable of parameters to a specified lower
    and upper bound. Gradients are modified in-place.

    The gradients are clipped in the range:

    .. math:: \left[\text{lower}, \text{upper}\right]

    Args:
        tensors: an iterable of Tensors or a single Tensor.
        lower: minimum allowed value of the gradients.
        upper: maximum allowed value of the gradients.
    """
    push_scope("clip_grad_value")
    if isinstance(tensors, Tensor):
        tensors = [tensors]
    for tensor in tensors:
        if tensor.grad is None:
            continue
        tensor.grad._reset(clip(tensor.grad, lower, upper))
    pop_scope("clip_grad_value")
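
Below is a minimal usage sketch, not part of the file above. It assumes a MegEngine 1.x environment in which both helpers are importable from megengine.optimizer (as the __all__ export list suggests); the toy parameter w is purely illustrative, chosen so the numbers can be checked by hand.

import megengine as mge
from megengine.autodiff import GradManager

# Assumption: clip_grad_norm / clip_grad_value are re-exported here.
from megengine.optimizer import clip_grad_norm, clip_grad_value

# A toy parameter whose gradient is easy to compute by hand.
w = mge.Parameter([3.0, 4.0])
gm = GradManager()
gm.attach([w])

with gm:
    loss = (w * w).sum()  # d(loss)/dw = 2 * w = [6, 8]
    gm.backward(loss)

# ||[6, 8]||_2 = 10 > max_norm, so each gradient is scaled by 5 / (10 + 1e-6).
total = clip_grad_norm(w, max_norm=5.0)
print(total.item())    # ~10.0: the norm *before* clipping is returned
print(w.grad.numpy())  # ~[3.0, 4.0]

# Element-wise clamp on top of that: entries outside [-3.5, 3.5] are cut off.
clip_grad_value(w, lower=-3.5, upper=3.5)
print(w.grad.numpy())  # ~[3.0, 3.5]

Note that clip_grad_norm returns the norm measured before rescaling, which is the quantity one would typically log to watch for exploding gradients; the clipped gradients themselves are updated in place via grad._reset.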