
grad.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils

  21. """ Some notes:
  22. 1. Initialize the optimizer:
  23. for each trainable parameter:
  24. call wrt(param, callback)
  25. Each parameter tensor will be assciated with a Tracer object saved in Tensor._extra_data
  26. 2. Tracer has one member: node, which is a VariableNode
  27. 3. VariableNode has a OpNode member: opnode
  28. 4. OpNode has four members:
  29. a. id
  30. b. inputs, which is made of VariableNode
  31. c. outputs, which are weakref's to VariableNode
  32. d. backward: call back function
  33. e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist
  34. f. backward_allow_noinput: whether backward allow noinput
  35. """
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(mode="add"), a, b)
    return c


def get_tensor(x):
    # use recursion to avoid infinite loop
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class Grad:
    def __init__(self, name=None):
        if name is None:
            global _grad_count
            self._name = "grad_" + str(_grad_count)
            _grad_count += 1
        else:
            self._name = name
        assert self._name not in _grad_manager_dict, "grad manager name duplicated"
        _grad_manager_dict[self._name] = self

        # list of all x in partial(y) / partial(x)
        self.xs = []

        # contains weak references to all OpNodes created during forward;
        # an OpNode holds inputs, outputs and its backward function;
        # ops forms the computational graph
        self.ops = []

        self._enabled = True

    @property
    def name(self):
        return self._name

    def wrt(self, *args: Tensor, callback=None):
        """ Indicates that the loss is a function of the input tensors (usually the net's trainable parameters),
        i.e., d (loss) / d (Tensor) != 0.

        callback is used to perform additional operations after the gradient is obtained in backward,
        e.g., copying the grad to a particular place.

        A VariableNode will be created and saved in the tensor's _extra_data slot.
        """
        for x in map(get_tensor, args):
            v = self._new_variable(x, callback=callback)
            assert self not in x._extra_data
            x._extra_data[self] = Tracer(v)
            self.xs.append(v)

        return self

    def _new_variable(self, owner, opnode=None, callback=None):
        return VariableNode(self, owner, opnode=opnode, callback=callback)

    def _new_opnode(self, inputs, outputs):
        inputs = tuple(inputs)
        for i in inputs:
            assert i is None or isinstance(i, VariableNode)
        o = OpNode()
        o.inputs = inputs
        o.outputs = []
        tracers = []
        for i in outputs:
            assert isinstance(i, Tensor)
            v = self._new_variable(i, o)
            o.outputs.append(weakref.ref(v))
            tracers.append(Tracer(v))
        self.ops.append(weakref.ref(o))
        return o, tracers

    def copy(self):
        raise NotImplementedError

    def __enter__(self):
        return self

    def __exit__(self, *_):
        """clear all resources"""
        self._enabled = False
        for o in self.ops:
            o = o()
            if o:
                o.clear()

    def __call__(self, ys, dys):
        """ Runs backward and triggers the callbacks registered via wrt().

        :param ys: outputs of forward operators, e.g., the loss tensor
        :type ys: list of Tensor or TensorWrapperBase
        :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss,
            e.g., one for the loss itself
        :type dys: list of Tensor or TensorWrapperBase
        """
        assert self._enabled
        self._enabled = False

        def check_wrapper():
            if isinstance(dys, TensorWrapperBase):
                return type(dys)
            if isinstance(dys, TensorBase):
                return
            assert isinstance(dys, (tuple, list))
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)

        Wrapper = check_wrapper()

        def aslist(x):
            if isinstance(x, (Tensor, TensorWrapperBase)):
                x = [x]
            else:
                x = list(x)
            x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x]
            for i in x:
                assert isinstance(i, Tensor)
            return x

        ys = aslist(ys)
        dys = aslist(dys)
        assert len(ys) == len(dys)

        # ys is changed to a list of VariableNode, which contains more information
        # such as OpNode, callback, etc.
        ys = [i._extra_data[self].node for i in ys]

        # NOTE: callback is called only if grad is not None

        # the OpNode sequence in backward
        op_seq = []

        # VariableNode -> (i, j), where i is the time stamp in backward and j means the jth input
        last_written_to = {}

        def schedule():
            reached = set(ys)
            # i is the time stamp in backward
            i = 0
            for o in self.ops[::-1]:
                o = o()
                if o is None:
                    continue
                if not o.has_grad_fn(o, reached):
                    continue
                op_seq.append(o)
                for j, v in enumerate(o.inputs):
                    reached.add(v)
                    last_written_to[v] = i, j
                i += 1

        schedule()

        # VariableNode -> Tensor
        cache = {}

        def initialize():
            for y, dy in zip(ys, dys):
                cache[y] = dy
                if y not in last_written_to and y.callback:
                    y.callback(y.owner(), dy)

        initialize()

        # NOTE: None is used to mark that a node has been consumed

        for seqno, opnode in enumerate(op_seq):
            input_nodes = opnode.inputs
            output_nodes = [i() for i in opnode.outputs]
            backward = opnode.backward
            backward_allow_noinput = opnode.backward_allow_noinput
            opnode.clear()

            output_grads = []
            for i in output_nodes:
                if i is not None:
                    if i in cache:
                        assert cache[i] is not None
                        output_grads.append(cache[i])
                    else:
                        output_grads.append(None)
                    # read by backward, mark consumed
                    cache[i] = None
                else:
                    output_grads.append(None)

            if (
                any([grad is not None for grad in output_grads])
                or backward_allow_noinput
            ):
                input_grads = backward(*output_grads)
            else:
                input_grads = [None] * len(input_nodes)
            assert len(input_nodes) == len(input_grads)

            for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
                if v is None:
                    continue
                if v in cache:
                    assert cache[v]
                    if g is not None:
                        cache[v] = add(cache[v], g)
                elif g is not None:
                    cache[v] = g
                if last_written_to[v] == (seqno, i):
                    if v.callback:
                        v.callback(
                            v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                        )
                    if v.opnode is None:
                        # won't be read by backward, mark consumed
                        cache[v] = None

        for v in cache.values():
            assert v is None
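

# Illustration only, not part of the original file: a minimal, self-contained
# sketch of the accumulation scheme implemented by Grad.__call__ above. It
# assumes hypothetical op records exposing `inputs`, `outputs` and `backward`
# attributes, `op_seq` already in reverse (backward) order, and `cache` seeded
# with the output grads, mirroring initialize().
def _sketch_accumulate(op_seq, cache):
    for op in op_seq:
        # read the grad of each output; None means no grad flows through it
        output_grads = [cache.get(v) for v in op.outputs]
        input_grads = op.backward(*output_grads)
        for v, g in zip(op.inputs, input_grads):
            if g is None:
                continue
            # grads reaching the same variable from different ops are summed
            cache[v] = g if cache.get(v) is None else add(cache[v], g)
    return cache
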
class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True


class OpNode(clearable):
    """ OpNode saves all the information needed to form the computational graph.
    """

    def __init__(self):
        self.id = None
        self.inputs = None  # tuple of VariableNode
        self.outputs = None  # list of weakrefs to VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """ VariableNode saves an OpNode and a callback.
    manager is the Grad instance tracking this variable; owner is the Tensor
    it is associated with. Both are held by weak reference.
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is Grad type
        self.manager = weakref.ref(manager)
        # owner is Tensor type
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """ type(node) is VariableNode
        """
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False
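

# Illustration only, not part of the original file: per-op overrides hook in
# through functools.singledispatch registration. For a hypothetical OpDef
# subclass `MyOp` whose backward may run even when every output grad is None,
# one would register:
#
#     @check_backward_allow_noinput.register(MyOp)
#     def _(op: MyOp):
#         return True
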
@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(outputs) == len(output_need_grad)
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]

    opnode.backward_allow_noinput = check_backward_allow_noinput(op)
    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None

The MegEngine package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has GPU hardware installed along with its driver. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.