
grad.py 12 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

import megengine as mge

from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils

""" Some notes:
    1. Initialize the optimizer:
        for each trainable parameter:
            call wrt(param, callback)
        Each parameter tensor will be associated with a Tracer object saved in Tensor._extra_data
    2. Tracer has one member: node, which is a VariableNode
    3. VariableNode has an OpNode member: opnode
    4. OpNode has these members:
        a. id
        b. inputs, which is made of VariableNode
        c. outputs, which are weakrefs to VariableNode
        d. backward: callback function
        e. has_grad_fn: call has_grad_fn(opnode, reached) to check whether a grad exists
        f. backward_allow_noinput: whether backward allows no input
"""
_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
    return c


def get_tensor(x):
    # use recursion to avoid infinite loop
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class Grad:
    def __init__(self, name=None):
        if name is None:
            global _grad_count
            self._name = "grad_" + str(_grad_count)
            _grad_count += 1
        else:
            self._name = name
        assert self._name not in _grad_manager_dict, "grad manager name duplicated"
        _grad_manager_dict[self._name] = self

        # list of all x in partial(y) / partial(x)
        self.xs = []

        # contains weak references to all OpNodes created during forward;
        # an OpNode contains inputs, outputs and its backward function,
        # so ops forms the computational graph
        self.ops = []

        self._attached_tensors = weakref.WeakSet()
        self._enabled = True

    @property
    def name(self):
        return self._name

    def wrt(self, *args: Tensor, callback=None):
        """Indicates that the loss is a function of the input tensors (usually the
        net's trainable parameters), i.e., d(loss) / d(Tensor) != 0.

        callback is used to perform additional operations after the gradient is
        obtained in backward, e.g., copying the grad to a particular place.

        A VariableNode will be created and saved in the tensor's _extra_data slot.
        """
        for x in map(get_tensor, args):
            v = self._new_variable(x, callback=callback)
            assert self not in x._extra_data
            x._extra_data[self] = Tracer(v)
            self.xs.append(v)

        return self

    def _new_variable(self, owner, opnode=None, callback=None):
        self._attached_tensors.add(owner)
        return VariableNode(self, owner, opnode=opnode, callback=callback)

    def _new_opnode(self, inputs, outputs):
        inputs = tuple(inputs)
        for i in inputs:
            assert i is None or isinstance(i, VariableNode)
        o = OpNode()
        o.inputs = inputs
        o.outputs = []
        tracers = []
        for i in outputs:
            assert isinstance(i, Tensor)
            v = self._new_variable(i, o)
            o.outputs.append(weakref.ref(v))
            tracers.append(Tracer(v))
        self.ops.append(weakref.ref(o))
        return o, tracers

    def copy(self):
        raise NotImplementedError

    def __enter__(self):
        return self

    def _exit(self):
        """clear all resources"""
        self._enabled = False
        for o in self.ops:
            o = o()
            if o:
                o.clear()
        for i in self._attached_tensors:
            i._extra_data.pop(self, None)

    def __exit__(self, *_):
        self._exit()

    def __call__(self, ys, dys):
        """Performs the backward pass, propagating dys through the recorded graph.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :type ys: list of Tensor or TensorWrapperBase
        :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss,
            e.g., one for the loss itself
        :type dys: list of Tensor or TensorWrapperBase
        """
        assert self._enabled
        self._enabled = False

        def check_wrapper():
            if isinstance(dys, TensorWrapperBase):
                return type(dys)
            if isinstance(dys, TensorBase):
                return
            assert isinstance(dys, (tuple, list))
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)
            # use Tensor as default wrapper
            return mge.Tensor

        Wrapper = check_wrapper()

        def aslist(x):
            if isinstance(x, (Tensor, TensorWrapperBase)):
                x = [x]
            else:
                x = list(x)
            x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x]
            for i in x:
                assert isinstance(i, Tensor)
            return x

        ys = aslist(ys)
        dys = aslist(dys)
        assert len(ys) == len(dys)

        ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]

        ys = [y for i, y in enumerate(ys) if i in ids]
        dys = [dy for i, dy in enumerate(dys) if i in ids]

        # ys is changed to a list of VariableNode which contains more information
        # such as OpNode, callback, etc.
        ys = [i._extra_data[self].node for i in ys]

        # NOTE: callback is called only if grad is not None

        # the OpNode sequence in backward
        op_seq = []

        # VariableNode -> (i, j), where i is the time stamp in backward and j means the jth input
        last_written_to = {}

        def schedule():
            reached = set(ys)
            # i is the time stamp in backward
            i = 0
            for o in self.ops[::-1]:
                o = o()
                if o is None:
                    continue
                if not o.has_grad_fn(o, reached):
                    continue
                op_seq.append(o)
                for j, v in enumerate(o.inputs):
                    reached.add(v)
                    last_written_to[v] = i, j
                i += 1

        schedule()

        # VariableNode -> Tensor
        cache = {}

        def initialize():
            for y, dy in zip(ys, dys):
                cache[y] = dy
                if y not in last_written_to and y.callback:
                    y.callback(y.owner(), dy)

        initialize()

        # NOTE: None is used to mark a node as consumed

        for seqno, opnode in enumerate(op_seq):
            input_nodes = opnode.inputs
            output_nodes = [i() for i in opnode.outputs]
            backward = opnode.backward
            backward_allow_noinput = opnode.backward_allow_noinput
            opnode.clear()

            output_grads = []
            for i in output_nodes:
                if i is not None:
                    if i in cache:
                        assert cache[i] is not None
                        output_grads.append(cache[i])
                    else:
                        output_grads.append(None)
                    # read by backward, mark consumed
                    cache[i] = None
                else:
                    output_grads.append(None)
            if (
                any([grad is not None for grad in output_grads])
                or backward_allow_noinput
            ):
                input_grads = backward(*output_grads)
            else:
                input_grads = [None] * len(input_nodes)

            assert len(input_nodes) == len(input_grads)
            for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
                if v is None:
                    continue
                if v in cache:
                    assert cache[v]
                    if g is not None:
                        cache[v] = add(cache[v], g)
                elif g is not None:
                    cache[v] = g
                if last_written_to[v] == (seqno, i):
                    if v.callback:
                        v.callback(
                            v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                        )
                    if v.opnode is None:
                        # won't be read by backward, mark consumed
                        cache[v] = None

        for v in cache.values():
            assert v is None

        self._exit()

    def __del__(self):
        self._exit()


class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True


class OpNode(clearable):
    """OpNode saves all the information needed to form the computational graph."""

    def __init__(self):
        self.id = None
        self.inputs = None  # Could be VariableNode
        self.outputs = None  # Could be VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """VariableNode saves an OpNode and a callback.
    FIXME!!! Explain manager and owner
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is of type Grad
        self.manager = weakref.ref(manager)
        # owner is of type Tensor
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """type(node) is VariableNode"""
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False


@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(ctx.outputs) == len(output_need_grad)
    if not any(output_need_grad):
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
    opnode.backward = backward
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
    opnode.backward_allow_noinput = check_backward_allow_noinput(op)
    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None
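
For context, here is a minimal usage sketch pieced together from the docstrings above. It is illustrative only: the import path, the wrapper-level tensor operations, and the callback wiring are assumptions rather than anything this file defines.

    import numpy as np
    import megengine as mge
    from megengine.core.autodiff.grad import Grad  # path assumed from the relative imports

    w = mge.Tensor(np.random.randn(3).astype("float32"))
    grads = {}

    with Grad() as grad:
        # the gradient d(loss)/d(w) is delivered to the callback during backward
        grad.wrt(w, callback=lambda tensor, g: grads.__setitem__("w", g))
        loss = (w * w).sum()
        # dy is one for the loss itself
        grad(loss, mge.Tensor(np.ones((1,), dtype="float32")))

    print(grads["w"])  # expected to be 2 * w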

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no need to distinguish between CPU and GPU builds. If you want to run GPU programs, make sure the machine has a GPU device and that its driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
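
For reference, the install command documented by MegEngine at the time looked roughly like the following (the exact package index URL is an assumption and may have changed):

    python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html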