
grad.py

# -*- coding: utf-8 -*-
import weakref

from .._imperative_rt import core2

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)


class Grad:
    # Groups of Grad managers that are currently entered; each group is a
    # list of weak references to Grad instances.
    stack = []
    # When True, Grad managers entered afterwards join the topmost group.
    grouping = False
    # Maps the underlying GradKey to its owning Grad wrapper.
    key2grad = weakref.WeakValueDictionary()

    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        Grad.key2grad[self._impl] = self
        _grad_manager_dict[self._name] = self
        self._group = [weakref.ref(self)]

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        group = [ref() for ref in self._group]
        # Pause the other managers in the group so only this one records
        # while its backward pass runs.
        for grad in group:
            if grad is self:
                continue
            grad.suppress()
        self._impl.backward(ys, dys)
        for grad in group:
            if grad is self:
                continue
            grad.resume()
        self._refkeeper = None
        return None

    def __enter__(self):
        ref = weakref.ref(self)
        self._impl.enter()
        if Grad.grouping:
            group = Grad.stack[-1]
            self._group = group
            group.append(ref)
        else:
            Grad.stack.append(self._group)
        return self

    def __exit__(self, _1, _2, _3):
        self._impl.exit()
        self._refkeeper = None
        del Grad.key2grad[self._impl]
        self._impl = None
        self._group.remove(weakref.ref(self))
        if len(self._group) == 0:
            Grad.stack.remove(self._group)

    @staticmethod
    def begin_group():
        assert not Grad.grouping
        Grad.grouping = True

    @staticmethod
    def end_group():
        group = Grad.stack[-1]
        assert len(group) > 0
        assert Grad.grouping
        Grad.grouping = False

    def suppress(self):
        # Temporarily stop this manager from recording operations.
        if self._impl is not None:
            self._impl.suppress()

    def resume(self):
        # Resume recording after a matching suppress().
        if self._impl is not None:
            self._impl.resume()
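

# Illustrative usage sketch: a minimal example of driving ``Grad`` directly.
# It assumes ``megengine`` is importable and that the ``wrt`` callback
# receives the computed gradient tensor as its single argument; the helper
# name ``_demo_grad_usage`` is hypothetical, exists only for illustration,
# and is never called at import time.
def _demo_grad_usage():
    from megengine import Tensor

    grads = []
    x = Tensor([1.0, 2.0, 3.0])
    with Grad() as grad:
        # Record gradients flowing into ``x``; deliver them via the callback.
        grad.wrt(x, callback=grads.append)
        y = (x * x).sum()
        # Run backward from ``y`` with a seed gradient of 1.
        grad(y, Tensor(1.0))
    return grads[0]  # expected to equal 2 * x under the assumptions above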
class Function:
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in the ``forward`` method, with gradient
    computation defined in the ``backward`` method.

    Each instance of ``Function`` should be used only once during forwarding.

    Examples:

        .. code-block::

            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
                    self.y = y
                    return y

                def backward(self, dy):
                    y = self.y
                    return dy * y * (1 - y)
    """

    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor representing the output
              of the function.
            * Positional arguments should all be Tensor.
        """
        raise NotImplementedError

    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be overridden by all subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * If some outputs are not related to the loss function, the corresponding
              values in ``output_grads`` would be ``None``.
            * This method should return a tuple containing the gradients of all inputs, in the same order
              as the ``inputs`` argument of :meth:`forward`. A single ``Tensor`` could be returned
              instead if there is only one input. To stop the propagation of some gradients,
              the corresponding returned values should be set to ``None``.
        """
        raise NotImplementedError
    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        from ...tensor import Tensor

        for arg in args:
            if not isinstance(arg, Tensor):
                raise TypeError(
                    "op Function expect type Tensor as inputs, got {}".format(type(arg))
                )

        grad_key = core2.get_grad_key(args)
        if grad_key is None:
            return self._default_rule(*args)

        grad = Grad.key2grad[grad_key]
        group = [ref() for ref in grad._group]
        origin_args = [Tensor(arg) for arg in args]

        for grad in group:
            grad.suppress()
        outputs, backward = self._grad_rule(*args)
        for grad in reversed(group):
            grad.resume()

        def normalized_backward(*output_grads):
            input_grads = backward(*output_grads)
            if isinstance(input_grads, Tensor) or input_grads is None:
                input_grads = (input_grads,)
            return input_grads

        if self.__single_output:
            outputs = (outputs,)
        outputs = core2.set_grad(normalized_backward, origin_args, outputs)
        if self.__single_output:
            (outputs,) = outputs
        return outputs

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
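

# Illustrative usage sketch: how a custom ``Function`` participates in
# autodiff when its inputs are attached to an active ``Grad``. It assumes
# ``megengine`` and ``megengine.functional`` are importable and that the
# ``wrt`` callback receives the computed gradient tensor; the helper name
# ``_demo_function_usage`` is hypothetical, exists only for illustration,
# and is never called at import time.
def _demo_function_usage():
    import megengine.functional as F
    from megengine import Tensor

    class Sigmoid(Function):
        def forward(self, x):
            y = 1 / (1 + F.exp(-x))
            self.y = y
            return y

        def backward(self, dy):
            y = self.y
            return dy * y * (1 - y)

    grads = []
    x = Tensor([0.0, 1.0])
    with Grad() as grad:
        grad.wrt(x, callback=grads.append)
        y = Sigmoid()(x)         # dispatches through Function.__call__
        grad(y, F.ones_like(y))  # backward with a seed gradient of ones
    return grads[0]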