
grad.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import weakref

from .._imperative_rt import core2

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)


class Grad:
    stack = []
    grouping = False
    key2grad = weakref.WeakValueDictionary()

    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        Grad.key2grad[self._impl] = self
        _grad_manager_dict[self._name] = self
        self._group = [weakref.ref(self)]

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]

        group = [ref() for ref in self._group]
        for grad in group:
            if grad is self:
                continue
            grad.suppress()
        self._impl.backward(ys, dys)
        for grad in group:
            if grad is self:
                continue
            grad.resume()

        self._refkeeper = None
        return None

    def __enter__(self):
        ref = weakref.ref(self)
        self._impl.enter()
        if Grad.grouping:
            group = Grad.stack[-1]
            self._group = group
            group.append(ref)
        else:
            Grad.stack.append(self._group)
        return self

    def __exit__(self, _1, _2, _3):
        self._impl.exit()
        self._refkeeper = None
        del Grad.key2grad[self._impl]
        self._impl = None
        self._group.remove(weakref.ref(self))
        if len(self._group) == 0:
            Grad.stack.remove(self._group)

    @staticmethod
    def begin_group():
        assert not Grad.grouping
        Grad.grouping = True

    @staticmethod
    def end_group():
        group = Grad.stack[-1]
        assert len(group) > 0
        assert Grad.grouping
        Grad.grouping = False

    def suppress(self):
        if self._impl is not None:
            self._impl.suppress()

    def resume(self):
        if self._impl is not None:
            self._impl.resume()
class Function:
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in the ``forward`` method, with the gradient
    computation defined in the ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

        .. code-block::

            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
                    self.y = y
                    return y

                def backward(self, dy):
                    y = self.y
                    return dy * y * (1 - y)
    """
    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor representing the output
              of the function.
            * Positional arguments should all be Tensor.
        """
        raise NotImplementedError
    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be overridden by all subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * If some output tensors are not related to the loss function, the corresponding
              values in ``output_grads`` will be ``None``.
            * This method should return a tuple containing the gradients of all inputs, in the same
              order as the ``inputs`` argument of :meth:`forward`. A ``Tensor`` may be returned
              instead if there is only one input. To stop the propagation of some gradients,
              set the corresponding returned values to ``None``.
        """
        raise NotImplementedError
    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward
    def __call__(self, *args):
        for arg in args:
            if not isinstance(arg, core2.Tensor):
                raise TypeError(
                    "op Function expects type Tensor as inputs, got {}".format(type(arg))
                )
        grad_key = core2.get_grad_key(args)
        if grad_key is None:
            return self._default_rule(*args)

        grad = Grad.key2grad[grad_key]
        group = [ref() for ref in grad._group]

        for grad in group:
            grad.suppress()
        outputs, backward = self._grad_rule(*args)
        for grad in reversed(group):
            grad.resume()

        def normalized_backward(*output_grads):
            input_grads = backward(*output_grads)
            if isinstance(input_grads, core2.Tensor) or input_grads is None:
                input_grads = (input_grads,)
            return input_grads

        if self.__single_output:
            outputs = (outputs,)
        for grad in reversed(group):
            if grad._impl is None:
                continue
            outputs = core2.set_grad(grad._impl, normalized_backward, args, outputs)
        if self.__single_output:
            (outputs,) = outputs
        return outputs

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
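
For orientation, below is a minimal usage sketch of the Grad and Function classes defined above. It is not part of grad.py: it assumes the module is importable as megengine.core.autodiff.grad, that megengine exposes Tensor and functional (imported as F), and that the callback passed to Grad.wrt receives the computed gradient as its single argument. In ordinary user code this machinery is driven indirectly through megengine.autodiff.GradManager rather than used directly.

# Illustrative sketch only; import paths and callback signature are assumptions.
import megengine as mge
import megengine.functional as F

from megengine.core.autodiff.grad import Function, Grad


class Sigmoid(Function):
    # Same example as in the Function docstring.
    def forward(self, x):
        y = 1 / (1 + F.exp(-x))
        self.y = y
        return y

    def backward(self, dy):
        y = self.y
        return dy * y * (1 - y)


x = mge.Tensor([1.0, 2.0, 3.0])
grads = {}


def save_dx(dx):
    # Called during backward with the gradient computed for x.
    grads["x"] = dx


with Grad() as g:
    # Attach x to this gradient scope and register the callback.
    g.wrt(x, callback=save_dx)
    y = Sigmoid()(x).sum()
    # Seed the backward pass with d(sum)/dy = ones_like(y).
    g(y, F.ones_like(y))

print(grads["x"])  # gradient of y with respect to x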