
grad.py 12 kB

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import functools
import heapq
import itertools
import typing
import weakref

import numpy as np

from ..ops.builtin import Elemwise, OpDef
from ..ops.special import Const
from ..tensor.core import TensorBase, TensorWrapperBase, apply
from ..tensor.function import Function
from ..tensor.tensor import Tensor, get_context
from . import builtin_op_utils

""" Some notes:

1. Initialize the optimizer:
       for each trainable parameter:
           call wrt(param, callback)
   Each parameter tensor will be associated with a Tracer object saved in Tensor._extra_data
2. Tracer has one member: node, which is a VariableNode
3. VariableNode has an OpNode member: opnode
4. OpNode has these members:
   a. id
   b. inputs, which is made of VariableNode
   c. outputs, which are weakref's to VariableNode
   d. backward: callback function
   e. has_grad_fn: call has_grad_fn(opnode, reached) to check whether a grad exists
   f. backward_allow_noinput: whether backward allows no input
"""
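
# A rough usage sketch of the workflow described in the notes above. This is
# illustrative only: `params`, `save_grad`, `forward` and `ones_like` are
# placeholders, not names defined in this module.
#
#     grad = Grad()
#     for p in params:                      # trainable parameters (Tensor-like)
#         grad.wrt(p, callback=save_grad)   # attaches a Tracer to p._extra_data
#     loss = forward(*params)               # applying ops to traced tensors records OpNodes
#     grad(loss, ones_like(loss))           # backward pass; callbacks receive d(loss)/d(p)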

_grad_count = 0
_grad_manager_dict = weakref.WeakValueDictionary()


def get_grad_managers():
    return [_grad_manager_dict[key] for key in _grad_manager_dict]


def add(a, b):
    (c,) = apply(Elemwise(mode="add"), a, b)
    return c


def get_tensor(x):
    # use recursion to avoid infinite loop
    if isinstance(x, Tensor):
        return x
    try:
        x = x.__wrapped__
    except AttributeError:
        raise TypeError(type(x))
    return get_tensor(x)


class Grad:
    def __init__(self, name=None):
        if name is None:
            global _grad_count
            self._name = "grad_" + str(_grad_count)
            _grad_count += 1
        else:
            self._name = name
        assert self._name not in _grad_manager_dict, "grad manager name duplicated"
        _grad_manager_dict[self._name] = self

        # list of all x in partial(y) / partial(x)
        self.xs = []

        # contains weak references of all OpNodes created during forward
        # OpNode contains inputs, outputs and its backward
        # ops form the computational graph
        self.ops = []

        self._enabled = True

    @property
    def name(self):
        return self._name

    def wrt(self, *args: Tensor, callback=None):
        """ Indicates that the loss is a function of the input tensors (usually the net's trainable parameters),
        i.e., d(loss) / d(Tensor) != 0

        callback is used to perform additional operations after the gradient is obtained in backward,
        e.g., copy the grad to a particular place.

        A VariableNode will be created and saved in the tensor's _extra_data slot.
        """
        for x in map(get_tensor, args):
            v = self._new_variable(x, callback=callback)
            assert self not in x._extra_data
            x._extra_data[self] = Tracer(v)
            self.xs.append(v)

        return self

    def _new_variable(self, owner, opnode=None, callback=None):
        return VariableNode(self, owner, opnode=opnode, callback=callback)

    def _new_opnode(self, inputs, outputs):
        inputs = tuple(inputs)
        for i in inputs:
            assert i is None or isinstance(i, VariableNode)
        o = OpNode()
        o.inputs = inputs
        o.outputs = []
        tracers = []
        for i in outputs:
            assert isinstance(i, Tensor)
            v = self._new_variable(i, o)
            o.outputs.append(weakref.ref(v))
            tracers.append(Tracer(v))
        self.ops.append(weakref.ref(o))
        return o, tracers

    def copy(self):
        raise NotImplementedError

    def __enter__(self):
        return self

    def __exit__(self, *_):
        """clear all resources"""
        self._enabled = False
        for o in self.ops:
            o = o()
            if o:
                o.clear()

    def __call__(self, ys, dys):
        """ Performs the backward pass: computes gradients of ys w.r.t. the tensors
        registered via wrt, and delivers them through the registered callbacks.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :type ys: list of Tensor or TensorWrapperBase
        :param dys: delta of outputs, physically equivalent to sensitivity of outputs to the loss,
            e.g., one for the loss itself
        :type dys: list of Tensor or TensorWrapperBase
        """
        assert self._enabled
        self._enabled = False

        def check_wrapper():
            if isinstance(dys, TensorWrapperBase):
                return type(dys)
            if isinstance(dys, TensorBase):
                return
            assert isinstance(dys, (tuple, list))
            for i in dys:
                if isinstance(i, TensorWrapperBase):
                    return type(i)

        Wrapper = check_wrapper()

        def aslist(x):
            if isinstance(x, (Tensor, TensorWrapperBase)):
                x = [x]
            else:
                x = list(x)
            x = [i.__wrapped__ if isinstance(i, TensorWrapperBase) else i for i in x]
            for i in x:
                assert isinstance(i, Tensor)
            return x

        ys = aslist(ys)
        dys = aslist(dys)
        assert len(ys) == len(dys)

        ids = [i for i, y in enumerate(ys) if self in y._extra_data.keys()]
        ys = [y for i, y in enumerate(ys) if i in ids]
        dys = [dy for i, dy in enumerate(dys) if i in ids]

        # ys is changed to a list of VariableNode which contains more information
        # such as OpNode, callback, etc.
        ys = [i._extra_data[self].node for i in ys]

        # NOTE: callback is called only if grad is not None

        # the OpNode sequence in backward
        op_seq = []

        # VariableNode -> (i, j), where i is the time stamp in backward and j means the jth input
        last_written_to = {}

        def schedule():
            reached = set(ys)
            # i is the time stamp in backward
            i = 0
            for o in self.ops[::-1]:
                o = o()
                if o is None:
                    continue
                if not o.has_grad_fn(o, reached):
                    continue
                op_seq.append(o)
                for j, v in enumerate(o.inputs):
                    reached.add(v)
                    last_written_to[v] = i, j
                i += 1

        schedule()

        # VariableNode -> Tensor
        cache = {}

        def initialize():
            for y, dy in zip(ys, dys):
                cache[y] = dy
                if y not in last_written_to and y.callback:
                    y.callback(y.owner(), dy)

        initialize()

        # NOTE: None is used to mark a node as consumed
        for seqno, opnode in enumerate(op_seq):
            input_nodes = opnode.inputs
            output_nodes = [i() for i in opnode.outputs]
            backward = opnode.backward
            backward_allow_noinput = opnode.backward_allow_noinput

            opnode.clear()

            output_grads = []
            for i in output_nodes:
                if i is not None:
                    if i in cache:
                        assert cache[i] is not None
                        output_grads.append(cache[i])
                    else:
                        output_grads.append(None)
                    # read by backward, mark consumed
                    cache[i] = None
                else:
                    output_grads.append(None)
            if (
                any([grad is not None for grad in output_grads])
                or backward_allow_noinput
            ):
                input_grads = backward(*output_grads)
            else:
                input_grads = [None] * len(input_nodes)

            assert len(input_nodes) == len(input_grads)
            for i, (v, g) in enumerate(zip(input_nodes, input_grads)):
                if v is None:
                    continue
                if v in cache:
                    assert cache[v]
                    if g is not None:
                        cache[v] = add(cache[v], g)
                elif g is not None:
                    cache[v] = g
                if last_written_to[v] == (seqno, i):
                    if v.callback:
                        v.callback(
                            v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                        )
                    if v.opnode is None:
                        # won't be read by backward, mark consumed
                        cache[v] = None

        for v in cache.values():
            assert v is None


class clearable:
    __cleared = False

    def __bool__(self):
        return not self.__cleared

    def clear(self):
        self.__dict__.clear()
        self.__cleared = True


class OpNode(clearable):
    """ OpNode saves all the information needed to form the computational graph.
    """

    def __init__(self):
        self.id = None
        self.inputs = None  # tuple of VariableNode (or None)
        self.outputs = None  # list of weakrefs to VariableNode
        self.backward = None
        self.has_grad_fn = None
        self.backward_allow_noinput = False


class VariableNode(clearable):
    """ VariableNode saves an OpNode and a callback.
    manager is the Grad instance that created this node; owner is the Tensor it is attached to.
    """

    def __init__(self, manager, owner, opnode=None, callback=None):
        # manager is of Grad type
        self.manager = weakref.ref(manager)
        # owner is of Tensor type
        self.owner = weakref.ref(owner)
        self.opnode = opnode
        self.callback = callback


class Tracer(clearable, TensorBase):
    def __init__(self, node=None):
        """ type(node) is VariableNode
        """
        self.node = node


@functools.singledispatch
def check_backward_allow_noinput(op: OpDef):
    return False


@functools.singledispatch
def get_op_has_grad_fn(op: OpDef):
    assert 0


@get_op_has_grad_fn.register(OpDef)
def _(op: OpDef):
    return default_has_grad_fn


@get_op_has_grad_fn.register(Function)
def _(op: Function):
    return default_has_grad_fn


def default_has_grad_fn(opnode, reached):
    for v in opnode.outputs:
        if v() in reached:
            return True
    return False


@apply.register()
def tracer_apply(op: (OpDef, Function), *args: typing.Optional[Tracer]):
    args = tuple(i if isinstance(i, Tracer) else None for i in args)
    input_requires_grad = list(map(bool, args))
    if not any(input_requires_grad):
        return

    ctx = get_context()
    manager = None
    assert len(ctx.inputs) == len(args)
    for i, j in zip(ctx.inputs, args):
        if j:
            j = j.node
            assert i is j.owner()
            if manager is None:
                manager = j.manager()
                assert manager
            else:
                assert manager is j.manager()

    if not manager._enabled:
        return

    opnode, outputs = manager._new_opnode([i and i.node for i in args], ctx.outputs)

    # register backward method
    # tuple of backward functions corresponding to dy / dx_i
    # None means y is not a function of x_i
    opnode.backward, output_need_grad = builtin_op_utils.builtin_op_get_backward_fn(
        op, ctx.inputs, ctx.outputs, input_requires_grad
    )
    assert len(outputs) == len(output_need_grad)
    outputs = [x if y else None for (x, y) in zip(outputs, output_need_grad)]

    opnode.backward_allow_noinput = check_backward_allow_noinput(op)
    opnode.has_grad_fn = get_op_has_grad_fn(op)

    return tuple(outputs)


@apply.register()
def _(op: Const, *_: typing.Optional[Tracer]):
    return None
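

# Sketch of how a custom op type could hook into the singledispatch helpers
# above. `MyOp` is a hypothetical OpDef subclass used purely for illustration;
# registering it like this would let its backward run even when every output
# grad is None, while keeping the default reachability check.
#
#     @check_backward_allow_noinput.register(MyOp)
#     def _(op: MyOp):
#         return True
#
#     @get_op_has_grad_fn.register(MyOp)
#     def _(op: MyOp):
#         return default_has_grad_fn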

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build. If you want to run GPU programs, make sure the machine has GPU hardware and the proper drivers installed. If you would like to experience deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.