You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

grad.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import functools
  10. import heapq
  11. import itertools
  12. import typing
  13. import weakref
  14. import numpy as np
  15. import megengine as mge
  16. from .._imperative_rt import core2, ops
  17. from ..ops.builtin import Elemwise, OpDef, RemoteSend
  18. from ..ops.special import Const
  19. """ Some notes:
  20. 1. Initialize the optimizer:
  21. for each trainable parameter:
  22. call wrt(param, callback)
  23. Each parameter tensor will be associated with a Tracer object saved in Tensor._extra_data
  24. 2. Tracer has one member: node, which is a VariableNode
  25. 3. VariableNode has a OpNode member: opnode
  26. 4. OpNode has six members:
  27. a. id
  28. b. inputs, which is made of VariableNode
  29. c. outputs, which are weakref's to VariableNode
  30. d. backward: call back function
  31. e. has_grad_fn: call has_grad_fn(opnode, reached) to check whether a gradient exists
  32. f. backward_allow_noinput: whether backward allow noinput
  33. """
  34. _grad_count = 0
  35. _grad_manager_dict = weakref.WeakValueDictionary()
  36. def get_grad_managers():
  37. return [_grad_manager_dict[key] for key in _grad_manager_dict]
  38. class GradKey(core2.GradKey):
  39. def __init__(self, name=None):
  40. if name:
  41. self.name = name
  42. def backward(self, ys, dys):
  43. return core2.backward(self, ys, dys)
  44. class Grad:
  45. def __init__(self, name=None):
  46. global _grad_count
  47. if name is None:
  48. name = "grad_%d" % _grad_count
  49. _grad_count += 1
  50. self._refkeeper = []
  51. self._impl = GradKey(name)
  52. _grad_manager_dict[self._name] = self
  53. @property
  54. def _name(self):
  55. return self._impl.name
  56. def _is_attached_to(self, tensor):
  57. return self._impl.is_attached_to(tensor)
  58. def wrt(self, *tensors, callback=None):
  59. for x in tensors:
  60. self._impl.attach(x, callback)
  61. return self
  62. def __call__(self, ys, dys):
  63. from collections.abc import Sequence
  64. if not isinstance(ys, Sequence):
  65. ys = [ys]
  66. if not isinstance(dys, Sequence):
  67. dys = [dys]
  68. self._impl.backward(ys, dys)
  69. self._refkeeper = None
  70. def __enter__(self):
  71. return self
  72. def __exit__(self, _1, _2, _3):
  73. self._refkeeper = None
  74. del self._impl
  75. class Function(ops.PyOpBase):
  76. def _default_rule(self, *args):
  77. ret = self.forward(*args)
  78. self.__single_output = isinstance(ret, core2.Tensor)
  79. return ret
  80. def _grad_rule(self, *args):
  81. return self._default_rule(*args), self.backward
  82. def __call__(self, *args):
  83. ret = core2.apply(self, *args)
  84. if self.__single_output:
  85. (ret,) = ret
  86. return ret
  87. def __getstate__(self):
  88. return self.__dict__
  89. def __setstate__(self, state):
  90. self.__dict__.update(state)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台