
grad.py 6.2 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

import megengine as mge

from .._imperative_rt import core2, ops
from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context  # Tensor and get_context are used below
from . import builtin_op_utils
  22. """ Some notes:
  23. 1. Initialize the optimizer:
  24. for each trainable parameter:
  25. call wrt(param, callback)
  26. Each parameter tensor will be assciated with a Tracer object saved in Tensor._extra_data
  27. 2. Tracer has one member: node, which is a VariableNode
  28. 3. VariableNode has a OpNode member: opnode
  29. 4. OpNode has four members:
  30. a. id
  31. b. inputs, which is made of VariableNode
  32. c. outputs, which are weakref's to VariableNode
  33. d. backward: call back function
  34. e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist
  35. f. backward_allow_noinput: whether backward allow noinput
  36. """
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
    return c


def get_tensor(x):
    # use recursion to avoid infinite loop
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True
class OpNode(clearable):
    """ OpNode saves all the information to form the computational graph.
    """

    def __init__(self):
        self.id = None
        self.inputs = None  # Could be VariableNode
        self.outputs = None  # Could be VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """ VariableNode saves OpNode and callback.
    FIXME!!! Explain manager and owner
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is Grad type
        self.manager = weakref.ref(manager)
        # owner is Tensor type
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """ type(node) is VariableNode
        """
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0  # unreachable for OpDef/Function thanks to the registrations below


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False
@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(ctx.outputs) == len(output_need_grad)
    if not any(output_need_grad):
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
    if isinstance(op, RemoteSend):
        manager.remote_send_cache.append(opnode)
    opnode.backward = backward
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
    opnode.backward_allow_noinput = check_backward_allow_noinput(op)
    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None
class Grad:
    def __init__(self):
        self._impl = core2.GradKey()

    def wrt(self, *tensors, callback=None):
        for x in tensors:
            self._impl.attach(x, callback)
        return self

    def __call__(self, ys, dys):
        from collections.abc import Sequence

        if not isinstance(ys, Sequence):
            ys = [ys]
        if not isinstance(dys, Sequence):
            dys = [dys]
        core2.backward(self._impl, ys, dys)

    def __enter__(self):
        return self

    def __exit__(self, _1, _2, _3):
        del self._impl
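# A minimal sketch of how Grad is typically driven (a hedged example; ``x``,
# ``f`` and ``dy`` are placeholders, not part of this module):
#
#     with Grad() as grad:
#         grad.wrt(x, callback=lambda g: print("dx =", g))
#         y = f(x)
#         grad(y, dy)  # propagates dy backward; the callback receives dx
#
# Leaving the ``with`` block deletes the GradKey, releasing the recorded
# graph state.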
class Function(ops.PyOpBase):
    # Note: this definition shadows the Function imported from
    # ..tensor.function above; the registrations earlier in this file bind to
    # the imported class at decoration time.
    def _default_rule(self, *args):
        ret = self.forward(*args)
        self.__single_output = isinstance(ret, core2.Tensor)
        return ret

    def _grad_rule(self, *args):
        return self._default_rule(*args), self.backward

    def __call__(self, *args):
        ret = core2.apply(self, *args)
        if self.__single_output:
            (ret,) = ret
        return ret

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        self.__dict__.update(state)
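For reference, a minimal sketch of how a Function subclass is meant to be used (a hedged example, not part of this file; the import path and the Sigmoid math are illustrative assumptions):

    import megengine.functional as F
    from megengine.core.autodiff.grad import Function  # import path is an assumption

    class Sigmoid(Function):
        def forward(self, x):
            # cache the output; backward reuses it
            self.y = 1 / (1 + F.exp(-x))
            return self.y

        def backward(self, dy):
            # d(sigmoid)/dx = y * (1 - y)
            return dy * self.y * (1 - self.y)

    y = Sigmoid()(x)

Calling Sigmoid()(x) routes through core2.apply, which invokes _default_rule on the forward pass and pairs it with backward via _grad_rule when gradients are being recorded.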

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has GPU hardware and the driver installed. If you would like to try deep learning development on a cloud GPU computing platform, visit MegStudio.