
grad.py 4.7 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

import megengine as mge

from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)

class Grad:
    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        _grad_manager_dict[self._name] = self

    @property
    def _priority(self):
        return self._impl.priority

    @_priority.setter
    def _priority(self, priority):
        self._impl.priority = priority

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        self._impl.backward(ys, dys)
        self._refkeeper = None

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        self._refkeeper = None
        del self._impl

class Function(ops.PyOpBase):
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in the ``forward`` method, with gradient
    computation defined in the ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

        .. code-block::

            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
                    self.y = y
                    return y

                def backward(self, dy):
                    y = self.y
                    return dy * y * (1 - y)
    """

    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overridden
        by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor representing
              the output of the function.
            * Positional arguments should all be Tensor.
        """
        raise NotImplementedError

    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be overridden by all
        subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * If some output tensors are not related to the loss function, the
              corresponding values in ``output_grads`` would be ``None``.
            * This method should return a tuple containing the gradients of all inputs,
              in the same order as the ``inputs`` argument of :meth:`forward`. A
              ``Tensor`` could be returned instead if there is only one input. If users
              want to stop the propagation of some gradients, the corresponding returned
              values should be set to ``None``.
        """
        raise NotImplementedError

    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        ret = core2.apply(self, *args)
        if self.__single_output:
            (ret,) = ret
        return ret

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
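
For context, here is a minimal usage sketch (not part of grad.py) of how the ``Function`` and ``Grad`` classes above might be combined. ``Sigmoid`` follows the docstring example; the import path ``megengine.core.autodiff.grad``, the use of ``megengine.functional`` as ``F``, and the assumption that the ``wrt`` callback receives the computed gradient are all assumptions about the surrounding package, not guarantees from this file.

    # Usage sketch; assumes a working MegEngine install and the import path below.
    import megengine as mge
    import megengine.functional as F

    from megengine.core.autodiff.grad import Function, Grad  # assumed module path


    class Sigmoid(Function):
        """Custom op: forward computes sigmoid, backward applies its derivative."""

        def forward(self, x):
            y = 1 / (1 + F.exp(-x))
            self.y = y
            return y

        def backward(self, dy):
            y = self.y
            return dy * y * (1 - y)


    x = mge.tensor([0.0, 1.0, 2.0])

    # Attach x with a callback (assumed to receive the computed gradient of x),
    # run the custom op, then trigger backward with a seed gradient of ones.
    with Grad() as grad:
        grad.wrt(x, callback=lambda dx: print("dx:", dx))
        y = Sigmoid()(x)
        grad(y, F.ones_like(y))

In this sketch ``grad.wrt(x, ...)`` attaches ``x`` before the forward computation is recorded, and ``grad(y, dy)`` maps onto ``Grad.__call__``, which wraps single tensors in lists before calling ``self._impl.backward``.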