You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import copy
  10. import weakref
  11. from importlib import import_module
  12. from typing import Any, Dict, List, Tuple, Type
  13. import numpy
  14. from .. import get_logger
  15. from ..core._imperative_rt.core2 import Tensor as RawTensor
  16. from ..module import Module
  17. from ..quantization.utils import QParams
  18. from ..tensor import Tensor
  19. from .utils import _check_obj_attr
  20. logger = get_logger(__name__)
  21. class Node:
  22. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  23. They are inputs/outputs of Expr (the operations on variables).
  24. """
  25. expr = None # type: Expr
  26. r"""The Expr which produces the Node."""
  27. __total_id = 0 # type: int
  28. _id = None # type: int
  29. _top_graph = None # type: weakref.ReferenceType
  30. _format_spec = "" # type: str
  31. def __init__(self, expr, name: str, qualname: str):
  32. self.expr = expr
  33. self.users = [] # List[Expr]
  34. self._id = Node.__total_id
  35. Node.__total_id += 1
  36. self._name = name
  37. self._qualname = qualname
  38. self.actual_node = [] # type: List[Node]
  39. def __repr__(self):
  40. format_spec = Node._format_spec
  41. return self.__format__(format_spec)
  42. def __format__(self, format_spec: str) -> str:
  43. if not format_spec:
  44. format_spec = Node._format_spec
  45. name = self._name
  46. if name is None:
  47. name = ""
  48. if format_spec in ["i", "p", "ip", "pi"]:
  49. if "p" in format_spec:
  50. prefix_name = self.top_graph._name
  51. name = "{}_{}".format(prefix_name, name)
  52. if "i" in format_spec:
  53. name = "%{}_{}".format(self._id, name)
  54. return name
  55. else:
  56. return name if name else ("%d" % self._id)
  57. @property
  58. def name(self):
  59. r"""Return the name of this Node."""
  60. return self._name
  61. @name.setter
  62. def name(self, new_name: str):
  63. r"""Set a new name to this Node."""
  64. graph = self.top_graph
  65. assert graph is not None, "The parent graph of this Node cannot be None."
  66. assert graph._namespace.used_names.get(new_name, None) is None, (
  67. "The name(%s) is already in use. Please try a different one again."
  68. % (new_name)
  69. )
  70. self._name = graph._namespace.create_unique_name(new_name, self)
  71. @property
  72. def qualname(self):
  73. r"""Get the `qualname` of this Node. The `qualname` can be used to get the
  74. submodule from the traced Module or Module.
  75. Example:
  76. .. code-block::
  77. import megengine.module as M
  78. import megengine.functional as F
  79. import megengine.traced_module as tm
  80. import megengine as mge
  81. class block(M.Module):
  82. def __init__(self):
  83. super().__init__()
  84. self.param = mge.Tensor([1.])
  85. self.relu = M.ReLU()
  86. def forward(self, x):
  87. x = x + self.param
  88. return self.relu(F.relu(x))
  89. class module(M.Module):
  90. def __init__(self):
  91. super().__init__()
  92. self.block = block()
  93. def forward(self, x):
  94. x = self.block(x)
  95. return x
  96. net = module()
  97. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  98. traced_net = traced_net.flatten()
  99. out_node = traced_net.graph.outputs[0]
  100. # qualname : "module.block.relu.[out]"
  101. qualname = out_node.qualname
  102. # qualname : "block.relu"
  103. qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]
  104. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  105. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  106. """
  107. return self._qualname
  108. @property
  109. def top_graph(self):
  110. r"""Get the parent graph of this Node."""
  111. if self._top_graph:
  112. return self._top_graph()
  113. return None
  114. @classmethod
  115. def _set_format_spec(cls, str):
  116. old_format_spec = cls._format_spec
  117. cls._format_spec = str
  118. return old_format_spec
  119. @classmethod
  120. def _get_next_id(cls):
  121. return cls.__total_id
  122. @classmethod
  123. def _set_next_id(cls, id: int = 0):
  124. assert isinstance(id, int)
  125. cls.__total_id = id
  126. def __copy__(self):
  127. cls = self.__class__
  128. result = cls.__new__(cls)
  129. result.__dict__.update(self.__dict__)
  130. return result
  131. def __deepcopy__(self, memo):
  132. cls = self.__class__
  133. result = cls.__new__(cls)
  134. state = {}
  135. memo[id(self)] = result
  136. for k, v in self.__dict__.items():
  137. if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
  138. state[k] = copy.deepcopy(v, memo)
  139. result.__dict__.update(state)
  140. return result
  141. class ModuleNode(Node):
  142. r"""``ModuleNode`` represents the Module objects."""
  143. module_type = Module # type: Type[Module]
  144. r"""The type of the Module correspending to the ModuleNode."""
  145. _owner = None # type: weakref.ReferenceType
  146. def __init__(self, expr, name: str = None, qualname: str = None):
  147. super().__init__(expr, name, qualname)
  148. def __getstate__(self):
  149. state = {
  150. "expr": self.expr,
  151. "users": self.users,
  152. "_id": self._id,
  153. "_name": self._name,
  154. "_qualname": self._qualname,
  155. "module_type": (self.module_type.__module__, self.module_type.__qualname__),
  156. }
  157. _check_obj_attr(state)
  158. return state
  159. def __setstate__(self, state):
  160. if "_orig_name" in state:
  161. state["_qualname"] = state.pop("_orig_name")
  162. self.__dict__.update(state)
  163. try:
  164. if isinstance(self.module_type, tuple):
  165. mname, classname = self.module_type
  166. mtype = getattr(import_module(mname), classname)
  167. self.module_type = mtype
  168. except Exception:
  169. pass
  170. @property
  171. def owner(self):
  172. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  173. """
  174. if self._owner:
  175. return self._owner()
  176. return None
  177. class TensorNode(Node):
  178. r"""``TensorNode`` represents the Tensor objects."""
  179. _shape = None # type: Tuple[int]
  180. _dtype = None # type: numpy.dtype
  181. _qparams = None # type: QParams
  182. _device = None
  183. _value = None # type: Tensor
  184. def __init__(
  185. self,
  186. expr: "Expr",
  187. name: str = None,
  188. qualname: str = None,
  189. shape: Tuple[int] = None,
  190. dtype: numpy.dtype = None,
  191. qparams: QParams = None,
  192. ):
  193. super().__init__(expr, name, qualname)
  194. self._shape = shape
  195. self._dtype = shape
  196. self._qparams = qparams
  197. def __getstate__(self):
  198. state = {
  199. "expr": self.expr,
  200. "users": self.users,
  201. "_id": self._id,
  202. "_qparams": self._qparams,
  203. "_shape": self._shape,
  204. "_dtype": self._dtype,
  205. "_device": self._device,
  206. "_name": self._name,
  207. "_qualname": self._qualname,
  208. }
  209. _check_obj_attr(state)
  210. return state
  211. def __setstate__(self, state):
  212. if "_orig_name" in state:
  213. qualname = state.pop("_orig_name")
  214. modulepath, comma, qualname = qualname.rpartition(".")
  215. expr_name = state["expr"].__class__.__name__
  216. if expr_name not in ["GetAttr"]:
  217. qualname = "[{}]".format(qualname)
  218. if comma:
  219. qualname = "{}.{}".format(modulepath, qualname)
  220. state["_qualname"] = qualname
  221. self.__dict__.update(state)
  222. @property
  223. def shape(self):
  224. r"""Get the shape of this Node."""
  225. return self._shape
  226. @shape.setter
  227. def shape(self, shape):
  228. self._shape = shape
  229. @property
  230. def dtype(self):
  231. r"""Get the dtype of this Node."""
  232. return self._dtype
  233. @dtype.setter
  234. def dtype(self, dtype):
  235. self._dtype = dtype
  236. @property
  237. def device(self):
  238. r"""Get the device of this Node pointed Tensor."""
  239. return self._device
  240. @device.setter
  241. def device(self, device):
  242. self._device = device
  243. @property
  244. def qparams(self):
  245. r"""Get the :class:`QParams` of this Node."""
  246. return self._qparams
  247. @qparams.setter
  248. def qparams(self, qparams):
  249. self._qparams = qparams
  250. @property
  251. def value(self):
  252. r"""Get the bound Tensor of this Node."""
  253. return self._value
  254. @value.setter
  255. def value(self, value):
  256. r"""Bind a :class:`Tensor` to this Node."""
  257. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  258. setattr(value, "_NodeMixin__node", None)
  259. self._value = value
  260. class NodeMixin(abc.ABC):
  261. __node = None
  262. @abc.abstractmethod
  263. def _record_wrapped_nodes(self, node):
  264. # record the nodes which had been bound to this NodeMixin
  265. pass
  266. @classmethod
  267. def _record_tensornode_property(cls, node, value):
  268. assert isinstance(node, TensorNode)
  269. assert isinstance(value, RawTensor)
  270. if isinstance(value, RawTensor):
  271. try:
  272. node._dtype = value.dtype
  273. except RuntimeError:
  274. node._dtype = None
  275. node._shape = (
  276. value._tuple_shape if isinstance(value, Tensor) else value.shape
  277. )
  278. node._device = value.device
  279. if hasattr(value, "_qparams") and value._qparams is not None:
  280. node._qparams = value.qparams
  281. @classmethod
  282. def wrap(cls, value, node):
  283. if isinstance(value, (NodeMixin, RawTensor)):
  284. if isinstance(node, Node):
  285. if isinstance(value, RawTensor):
  286. cls._record_tensornode_property(node, value)
  287. if isinstance(value, NodeMixin):
  288. value._record_wrapped_nodes(node)
  289. setattr(value, "_NodeMixin__node", node)
  290. else:
  291. assert callable(node)
  292. n = node()
  293. assert isinstance(n, Node)
  294. if isinstance(value, RawTensor):
  295. cls._record_tensornode_property(n, value)
  296. if isinstance(value, NodeMixin):
  297. value._record_wrapped_nodes(n)
  298. setattr(value, "_NodeMixin__node", n)
  299. @classmethod
  300. def wrap_safe(cls, value, node):
  301. assert isinstance(value, (NodeMixin, RawTensor))
  302. if isinstance(value, RawTensor):
  303. cls._record_tensornode_property(node, value)
  304. setattr(value, "_NodeMixin__node", node)
  305. if isinstance(value, NodeMixin):
  306. value._record_wrapped_nodes(node)
  307. @classmethod
  308. def get(cls, value, *default):
  309. return getattr(value, "_NodeMixin__node", *default)
  310. @classmethod
  311. def get_wrapped_type(cls, value):
  312. if isinstance(value, RawTensor):
  313. return TensorNode
  314. if isinstance(value, (Module, NodeMixin)):
  315. return ModuleNode
  316. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台