You can not select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

clip_grad.py 2.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. # pylint: disable=redefined-builtin
  10. from typing import Iterable, Union
  11. from ..core._imperative_rt.core2 import pop_scope, push_scope
  12. from ..functional import clip, concat, minimum, norm
  13. from ..tensor import Tensor
  14. __all__ = ["clip_grad_norm", "clip_grad_value"]
  15. def clip_grad_norm(
  16. tensors: Union[Tensor, Iterable[Tensor]], max_norm: float, ord: float = 2.0,
  17. ):
  18. r"""Clips gradient norm of an iterable of parameters.
  19. The norm is computed over all gradients together, as if they were
  20. concatenated into a single vector. Gradients are modified in-place.
  21. :param tensors: an iterable of Tensors or a single Tensor.
  22. :param max_norm: max norm of the gradients.
  23. :param ord: type of the used p-norm. Can be ``'inf'`` for infinity norm.
  24. :return: total norm of the parameters (viewed as a single vector).
  25. """
  26. push_scope("clip_grad_norm")
  27. if isinstance(tensors, Tensor):
  28. tensors = [tensors]
  29. tensors = [t for t in tensors if t.grad is not None]
  30. if len(tensors) == 0:
  31. pop_scope("clip_grad_norm")
  32. return Tensor(0.0)
  33. norm_ = [norm(t.grad.flatten(), ord=ord) for t in tensors]
  34. if len(norm_) > 1:
  35. norm_ = norm(concat(norm_), ord=ord)
  36. else:
  37. norm_ = norm_[0]
  38. scale = max_norm / (norm_ + 1e-6)
  39. scale = minimum(scale, 1)
  40. for tensor in tensors:
  41. tensor.grad._reset(tensor.grad * scale)
  42. pop_scope("clip_grad_norm")
  43. return norm_
  44. def clip_grad_value(
  45. tensors: Union[Tensor, Iterable[Tensor]], lower: float, upper: float
  46. ):
  47. r"""Clips gradient of an iterable of parameters to a specified lower and
  48. upper. Gradients are modified in-place.
  49. The gradients are clipped in the range:
  50. .. math:: \left[\text{lower}, \text{upper}\right]
  51. :param tensors: an iterable of Tensors or a single Tensor.
  52. :param lower: minimum allowed value of the gradients.
  53. :param upper: maximum allowed value of the gradients.
  54. """
  55. push_scope("clip_grad_value")
  56. if isinstance(tensors, Tensor):
  57. tensors = [tensors]
  58. for tensor in tensors:
  59. if tensor.grad is None:
  60. continue
  61. tensor.grad._reset(clip(tensor.grad, lower, upper))
  62. pop_scope("clip_grad_value")

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台