
grad.py 12 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

import megengine as mge

from ..ops.builtin import Elemwise, OpDef, RemoteSend
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils
  22. """ Some notes:
  23. 1. Initialize the optimizer:
  24. for each trainable parameter:
  25. call wrt(param, callback)
  26. Each parameter tensor will be assciated with a Tracer object saved in Tensor._extra_data
  27. 2. Tracer has one member: node, which is a VariableNode
  28. 3. VariableNode has a OpNode member: opnode
  29. 4. OpNode has four members:
  30. a. id
  31. b. inputs, which is made of VariableNode
  32. c. outputs, which are weakref's to VariableNode
  33. d. backward: call back function
  34. e. has_grad_fn: call has_grad_fn(opnode, reached) to check grad exist
  35. f. backward_allow_noinput: whether backward allow noinput
  36. """

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(Elemwise.Mode.ADD), a, b)
    return c


def get_tensor(x):
    # use recursion to avoid infinite loop
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class Grad:
    def __init__(self, name=None):
        if name is None:
            global _grad_count
            self._name = "grad_" + str(_grad_count)
            _grad_count += 1
        else:
            self._name = name
        assert self._name not in _grad_manager_dict, "grad manager name duplicated"
        _grad_manager_dict[self._name] = self

        # list of all x in partial(y) / partial(x)
        self.xs = []

        # contains weak references of all OpNode during forward
        # OpNode contains inputs, outputs and its backward
        # ops forms the computational graph
        self.ops = []

        # save remote_send output for backward
        self.remote_send_cache = []

        self._attached_tensors = weakref.WeakSet()
        self._enabled = True

    @property
    def name(self):
        return self._name

    def wrt(self, *args: Tensor, callback=None):
        """ Indicates the loss is a function of the input tensors (usually the net trainable parameters),
        i.e., d (loss) / d (Tensor) != 0

        callback is used to perform additional operations after the gradient is obtained in backward,
        e.g., copy the grad to a particular place.

        A VariableNode will be created and saved in the tensor's _extra_data slot.
        """
        for x in map(get_tensor, args):
            v = self._new_variable(x, callback=callback)
            assert self not in x._extra_data
            x._extra_data[self] = Tracer(v)
            self.xs.append(v)

        return self

    def _new_variable(self, owner, opnode=None, callback=None):
        self._attached_tensors.add(owner)
        return VariableNode(self, owner, opnode=opnode, callback=callback)

    def _new_opnode(self, inputs, outputs):
        inputs = tuple(inputs)
        for i in inputs:
            assert i is None or isinstance(i, VariableNode)
        o = OpNode()
        o.inputs = inputs
        o.outputs = []
        tracers = []
        for i in outputs:
            assert isinstance(i, Tensor)
            v = self._new_variable(i, o)
            o.outputs.append(weakref.ref(v))
            tracers.append(Tracer(v))
        self.ops.append(weakref.ref(o))
        return o, tracers

    def copy(self):
        raise NotImplementedError

    def __enter__(self):
        return self

    def _exit(self):
        """clear all resources"""
        self._enabled = False
        for o in self.ops:
            o = o()
            if o:
                o.clear()
        for i in self._attached_tensors:
            i._extra_data.pop(self, None)
        self.remote_send_cache = []

    def __exit__(self, *_):
        self._exit()

    def __call__(self, ys, dys):
        """ Defines Grad().

        :param ys: outputs of forward operators, e.g., the loss tensor
        :type ys: list of Tensor or TensorWrapperBase
        :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss,
            e.g., one for the loss itself
        :type dys: list of Tensor or TensorWrapperBase
        """
        assert self._enabled
        self._enabled = False

        def check_wrapper():
            if isinstance(dys, TensorWrapperBase):
                return type(dys)
            if isinstance(dys, TensorBase):
                return
            assert isinstance(dys, (tuple, list))
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)
            # use Tensor as default wrapper
            return mge.Tensor

        Wrapper = check_wrapper()

        def aslist(x):
            if isinstance(x, (Tensor, TensorWrapperBase)):
                x = [x]
            else:
                x = list(x)
            x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x]
            for i in x:
                assert isinstance(i, Tensor)
            return x

        ys = aslist(ys)
        dys = aslist(dys)
        assert len(ys) == len(dys)

        ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
        ys = [y for i, y in enumerate(ys) if i in ids]
        dys = [dy for i, dy in enumerate(dys) if i in ids]

        # ys is changed to a list of VariableNode which contains more information
        # such as OpNode, callback, etc.
        ys = [i._extra_data[self].node for i in ys]

        # NOTE: callback is called only if grad is not None

        # the OpNode sequence in backward
        op_seq = []

        # VariableNode -> (i, j), where i is the time stamp in backward and j means the jth input
        last_written_to = {}

        def schedule():
            reached = set(ys)
            # i is the time stamp in backward
            i = 0
            for o in self.ops[::-1]:
                o = o()
                if o is None:
                    continue
                if not o.has_grad_fn(o, reached):
                    continue
                op_seq.append(o)
                for j, v in enumerate(o.inputs):
                    reached.add(v)
                    last_written_to[v] = i, j
                i += 1

        schedule()

        # VariableNode -> Tensor
        cache = {}

        def initialize():
            for y, dy in zip(ys, dys):
                cache[y] = dy
                if y not in last_written_to and y.callback:
                    y.callback(y.owner(), dy)

        initialize()

        # NOTE: None is used to mark a node as consumed

        for seqno, opnode in enumerate(op_seq):
            input_nodes = opnode.inputs
            output_nodes = [i() for i in opnode.outputs]
            backward = opnode.backward
            backward_allow_noinput = opnode.backward_allow_noinput

            opnode.clear()

            output_grads = []
            for i in output_nodes:
                if i is not None:
                    if i in cache:
                        assert cache[i] is not None
                        output_grads.append(cache[i])
                    else:
                        output_grads.append(None)
                    # read by backward, mark consumed
                    cache[i] = None
                else:
                    output_grads.append(None)
            if (
                any([grad is not None for grad in output_grads])
                or backward_allow_noinput
            ):
                input_grads = backward(*output_grads)
            else:
                input_grads = [None] * len(input_nodes)

            assert len(input_nodes) == len(input_grads)
            for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
                if v is None:
                    continue
                if v in cache:
                    assert cache[v]
                    if g is not None:
                        cache[v] = add(cache[v], g)
                elif g is not None:
                    cache[v] = g
                if last_written_to[v] == (seqno, i):
                    if v.callback:
                        v.callback(
                            v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                        )
                    if v.opnode is None:
                        # won't be read by backward, mark consumed
                        cache[v] = None

        for v in cache.values():
            assert v is None

        self._exit()

    def __del__(self):
        self._exit()


class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True


class OpNode(clearable):
    """ OpNode saves all the information to form the computational graph.
    """

    def __init__(self):
        self.id = None
        self.inputs = None  # Could be VariableNode
        self.outputs = None  # Could be VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """ VariableNode saves OpNode and callback.
    FIXME!!! Explain manager and owner
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is Grad type
        self.manager = weakref.ref(manager)
        # owner is Tensor type
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """ type(node) is VariableNode
        """
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False


@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(ctx.outputs) == len(output_need_grad)
    if not any(output_need_grad):
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)
    if isinstance(op, RemoteSend):
        manager.remote_send_cache.append(opnode)
    opnode.backward = backward
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]
    opnode.backward_allow_noinput = check_backward_allow_noinput(op)
    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None
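
The Grad class above is driven in three steps, per its wrt and __call__ docstrings: attach the tensors you want gradients for with wrt(), run the forward computation while the tracer records OpNodes, then call the manager with the loss and its delta dy. The following minimal sketch illustrates that pattern; the parameter w, input x, the loss expression, and the save_grad callback are hypothetical placeholders, and whether it runs as-is depends on the MegEngine version this internal module ships with.

import numpy as np
import megengine as mge

grads = {}

def save_grad(name):
    # callback invoked during backward with (owner_tensor, grad)
    def callback(tensor, grad):
        grads[name] = grad
    return callback

w = mge.Tensor(np.random.randn(3).astype("float32"))  # hypothetical parameter
x = mge.Tensor(np.random.randn(3).astype("float32"))  # hypothetical input

with Grad(name="example") as grad:
    grad.wrt(w, callback=save_grad("w"))   # declare that d(loss)/d(w) is wanted
    loss = (w * x).sum()                   # hypothetical forward computation
    grad(loss, mge.Tensor(np.ones((), dtype="float32")))  # backward with dy = 1

# grads["w"] now holds d(loss)/d(w), delivered through the callback above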

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
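
As a quick sanity check after installation, the snippet below (a small sketch assuming only the standard megengine import) reports whether a CUDA-capable GPU is visible; on a CPU-only machine it simply prints False and computation falls back to CPU.

import megengine as mge

# True if a CUDA-capable GPU and its driver are visible to MegEngine
print("CUDA available:", mge.is_cuda_available())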