
grad.py 4.6 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


class GradKey(core2.GradKey):
    def __init__(self, name=None):
        if name:
            self.name = name

    def backward(self, ys, dys):
        return core2.backward(self, ys, dys)


class Grad:
    def __init__(self, name=None):
        global _grad_count
        if name is None:
            name = "grad_%d" % _grad_count
            _grad_count += 1
        self._refkeeper = []
        self._impl = GradKey(name)
        _grad_manager_dict[self._name] = self

    @property
    def _priority(self):
        return self._impl.priority

    @_priority.setter
    def _priority(self, priority):
        self._impl.priority = priority

    @property
    def _name(self):
        return self._impl.name

    def _is_attached_to(self, tensor):
        return self._impl.is_attached_to(tensor)

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        self._impl.backward(ys, dys)
        self._refkeeper = None

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        self._refkeeper = None
        del self._impl
class Function(ops.PyOpBase):
    r"""Defines a block of operations with customizable differentiation.

    The computation should be defined in the ``forward`` method, with gradient
    computation defined in the ``backward`` method.

    Each instance of ``Function`` should be used only once during a forward pass.

    Examples:

        .. code-block::

            class Sigmoid(Function):
                def forward(self, x):
                    y = 1 / (1 + F.exp(-x))
                    self.y = y
                    return y

                def backward(self, dy):
                    y = self.y
                    return dy * y * (1 - y)

    A usage sketch combining ``Function`` with ``Grad`` is shown after the class
    definition below.
    """

    def forward(self, *args, **kwargs):
        r"""Applies operations to ``inputs`` and returns results. It must be overridden by all subclasses.

        Args:
            input: input tensors.

        Returns:
            a tuple of Tensor or a single Tensor.

        Note:
            * This method should return a tuple of Tensor or a single Tensor representing the output
              of the function.
            * positional arguments should all be Tensor.
        """
        raise NotImplementedError

    def backward(self, *output_grads):
        r"""Computes the gradient of the forward function. It must be overridden by all subclasses.

        Args:
            output_grads: gradients of outputs that are returned by :meth:`forward`.

        Note:
            * In case some tensors of the outputs are not related to the loss function, the corresponding
              values in ``output_grads`` would be ``None``.
            * This method should return a tuple containing the gradients of all inputs, in the same order
              as the ``inputs`` argument of :meth:`forward`. A ``Tensor`` could be returned
              instead if there is only one input. If users want to stop the propagation of some gradients,
              the corresponding returned values should be set to ``None``.
        """
        raise NotImplementedError

    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        ret = core2.apply(self, *args)
        if self.__single_output:
            (ret,) = ret
        return ret

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
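
Below is a minimal usage sketch, not part of grad.py itself. It assumes an installed megengine package in which this module is importable as megengine.core.autodiff.grad, that megengine.functional provides exp and ones_like, and that the callback passed to ``wrt`` receives the gradient tensor directly (``wrt`` forwards it to the underlying ``attach`` call). Treat it as an illustration under those assumptions, not as the library's documented API.

import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.core.autodiff.grad import Function, Grad


class Sigmoid(Function):
    def forward(self, x):
        # Save the output so backward can reuse it.
        y = 1 / (1 + F.exp(-x))
        self.y = y
        return y

    def backward(self, dy):
        # d(sigmoid)/dx = y * (1 - y), scaled by the incoming gradient.
        y = self.y
        return dy * y * (1 - y)


x = mge.Tensor(np.random.randn(4).astype("float32"))
grads = {}

with Grad() as grad:
    # Attach x; the callback stores the gradient once backward runs.
    grad.wrt(x, callback=lambda dx: grads.update(dx=dx))
    y = Sigmoid()(x)  # each Function instance serves a single forward pass
    grad(y, F.ones_like(y))  # ys, dys: runs backward and fires the callback

print(grads["dx"])  # numerically equal to y * (1 - y)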