You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import copy
  10. import weakref
  11. from importlib import import_module
  12. from typing import Any, Dict, List, Tuple, Type
  13. import numpy
  14. from .. import get_logger
  15. from ..core._imperative_rt.core2 import Tensor as RawTensor
  16. from ..module import Module
  17. from ..quantization.utils import QParams
  18. from ..tensor import Tensor
  19. from .utils import _check_obj_attr
# Module-level logger, obtained from the package's ``get_logger`` helper.
logger = get_logger(__name__)
  21. class Node:
  22. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  23. They are inputs/outputs of Expr (the operations on variables).
  24. """
  25. expr = None # type: Expr
  26. r"""The Expr which produces the Node."""
  27. __total_id = 0 # type: int
  28. _id = None # type: int
  29. _top_graph = None # type: weakref.ReferenceType
  30. _format_spec = "" # type: str
  31. def __init__(self, expr, name: str, qualname: str):
  32. self.expr = expr
  33. self.users = [] # List[Expr]
  34. self._id = Node.__total_id
  35. Node.__total_id += 1
  36. self._name = name
  37. self._qualname = qualname
  38. self.actual_node = [] # type: List[Node]
  39. def __repr__(self):
  40. format_spec = Node._format_spec
  41. return self.__format__(format_spec)
  42. def __format__(self, format_spec: str) -> str:
  43. if not format_spec:
  44. format_spec = Node._format_spec
  45. name = self._name
  46. if name is None:
  47. name = ""
  48. if format_spec in ["i", "p", "ip", "pi"]:
  49. if "p" in format_spec:
  50. prefix_name = self.top_graph._name
  51. name = "{}_{}".format(prefix_name, name)
  52. if "i" in format_spec:
  53. name = "%{}_{}".format(self._id, name)
  54. return name
  55. else:
  56. return name if name else ("%d" % self._id)
  57. @property
  58. def name(self):
  59. r"""Return the name of this Node."""
  60. return self._name
  61. @name.setter
  62. def name(self, new_name: str):
  63. r"""Set a new name to this Node."""
  64. graph = self.top_graph
  65. assert graph is not None, "The parent graph of this Node cannot be None."
  66. assert graph._namespace.used_names.get(new_name, None) is None, (
  67. "The name(%s) is already in use. Please try a different one again."
  68. % (new_name)
  69. )
  70. graph._namespace.unassociate_name_with_obj(self)
  71. self._name = graph._namespace.create_unique_name(new_name, self)
  72. @property
  73. def qualname(self):
  74. r"""Get the `qualname` of this Node. The `qualname` can be used to get the
  75. submodule from the traced Module or Module.
  76. Example:
  77. .. code-block::
  78. import megengine.module as M
  79. import megengine.functional as F
  80. import megengine.traced_module as tm
  81. import megengine as mge
  82. class block(M.Module):
  83. def __init__(self):
  84. super().__init__()
  85. self.param = mge.Tensor([1.])
  86. self.relu = M.ReLU()
  87. def forward(self, x):
  88. x = x + self.param
  89. return self.relu(F.relu(x))
  90. class module(M.Module):
  91. def __init__(self):
  92. super().__init__()
  93. self.block = block()
  94. def forward(self, x):
  95. x = self.block(x)
  96. return x
  97. net = module()
  98. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  99. traced_net = traced_net.flatten()
  100. out_node = traced_net.graph.outputs[0]
  101. # qualname : "module.block.relu.[out]"
  102. qualname = out_node.qualname
  103. # qualname : "block.relu"
  104. qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]
  105. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  106. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  107. """
  108. return self._qualname
  109. @property
  110. def top_graph(self):
  111. r"""Get the parent graph of this Node."""
  112. if self._top_graph:
  113. return self._top_graph()
  114. return None
  115. @classmethod
  116. def _set_format_spec(cls, str):
  117. old_format_spec = cls._format_spec
  118. cls._format_spec = str
  119. return old_format_spec
  120. @classmethod
  121. def _get_next_id(cls):
  122. return cls.__total_id
  123. @classmethod
  124. def _set_next_id(cls, id: int = 0):
  125. assert isinstance(id, int)
  126. cls.__total_id = id
  127. def __copy__(self):
  128. cls = self.__class__
  129. result = cls.__new__(cls)
  130. result.__dict__.update(self.__dict__)
  131. return result
  132. def __deepcopy__(self, memo):
  133. cls = self.__class__
  134. result = cls.__new__(cls)
  135. state = {}
  136. memo[id(self)] = result
  137. for k, v in self.__dict__.items():
  138. if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
  139. state[k] = copy.deepcopy(v, memo)
  140. result.__dict__.update(state)
  141. return result
  142. class ModuleNode(Node):
  143. r"""``ModuleNode`` represents the Module objects."""
  144. module_type = Module # type: Type[Module]
  145. r"""The type of the Module correspending to the ModuleNode."""
  146. _owner = None # type: weakref.ReferenceType
  147. def __init__(self, expr, name: str = None, qualname: str = None):
  148. super().__init__(expr, name, qualname)
  149. def __getstate__(self):
  150. state = {
  151. "expr": self.expr,
  152. "users": self.users,
  153. "_id": self._id,
  154. "_name": self._name,
  155. "_qualname": self._qualname,
  156. "module_type": (self.module_type.__module__, self.module_type.__qualname__),
  157. }
  158. _check_obj_attr(state)
  159. return state
  160. def __setstate__(self, state):
  161. if "_orig_name" in state:
  162. state["_qualname"] = state.pop("_orig_name")
  163. self.__dict__.update(state)
  164. try:
  165. if isinstance(self.module_type, tuple):
  166. mname, classname = self.module_type
  167. mtype = getattr(import_module(mname), classname)
  168. self.module_type = mtype
  169. except Exception:
  170. pass
  171. @property
  172. def owner(self):
  173. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  174. """
  175. if self._owner:
  176. return self._owner()
  177. return None
  178. class TensorNode(Node):
  179. r"""``TensorNode`` represents the Tensor objects."""
  180. _shape = None # type: Tuple[int]
  181. _dtype = None # type: numpy.dtype
  182. _qparams = None # type: QParams
  183. _device = None
  184. _value = None # type: Tensor
  185. def __init__(
  186. self,
  187. expr: "Expr",
  188. name: str = None,
  189. qualname: str = None,
  190. shape: Tuple[int] = None,
  191. dtype: numpy.dtype = None,
  192. qparams: QParams = None,
  193. ):
  194. super().__init__(expr, name, qualname)
  195. self._shape = shape
  196. self._dtype = shape
  197. self._qparams = qparams
  198. def __getstate__(self):
  199. state = {
  200. "expr": self.expr,
  201. "users": self.users,
  202. "_id": self._id,
  203. "_qparams": self._qparams,
  204. "_shape": self._shape,
  205. "_dtype": self._dtype,
  206. "_device": self._device,
  207. "_name": self._name,
  208. "_qualname": self._qualname,
  209. }
  210. _check_obj_attr(state)
  211. return state
  212. def __setstate__(self, state):
  213. if "_orig_name" in state:
  214. qualname = state.pop("_orig_name")
  215. modulepath, comma, qualname = qualname.rpartition(".")
  216. expr_name = state["expr"].__class__.__name__
  217. if expr_name not in ["GetAttr"]:
  218. qualname = "[{}]".format(qualname)
  219. if comma:
  220. qualname = "{}.{}".format(modulepath, qualname)
  221. state["_qualname"] = qualname
  222. self.__dict__.update(state)
  223. @property
  224. def shape(self):
  225. r"""Get the shape of this Node."""
  226. return self._shape
  227. @shape.setter
  228. def shape(self, shape):
  229. self._shape = shape
  230. @property
  231. def dtype(self):
  232. r"""Get the dtype of this Node."""
  233. return self._dtype
  234. @dtype.setter
  235. def dtype(self, dtype):
  236. self._dtype = dtype
  237. @property
  238. def device(self):
  239. r"""Get the device of this Node pointed Tensor."""
  240. return self._device
  241. @device.setter
  242. def device(self, device):
  243. self._device = device
  244. @property
  245. def qparams(self):
  246. r"""Get the :class:`QParams` of this Node."""
  247. return self._qparams
  248. @qparams.setter
  249. def qparams(self, qparams):
  250. self._qparams = qparams
  251. @property
  252. def value(self):
  253. r"""Get the bound Tensor of this Node."""
  254. return self._value
  255. @value.setter
  256. def value(self, value):
  257. r"""Bind a :class:`Tensor` to this Node."""
  258. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  259. setattr(value, "_NodeMixin__node", None)
  260. self._value = value
  261. class NodeMixin(abc.ABC):
  262. __node = None
  263. @abc.abstractmethod
  264. def _record_wrapped_nodes(self, node):
  265. # record the nodes which had been bound to this NodeMixin
  266. pass
  267. @classmethod
  268. def _record_tensornode_property(cls, node, value):
  269. assert isinstance(node, TensorNode)
  270. assert isinstance(value, RawTensor)
  271. if isinstance(value, RawTensor):
  272. try:
  273. node._dtype = value.dtype
  274. except RuntimeError:
  275. node._dtype = None
  276. node._shape = (
  277. value._tuple_shape if isinstance(value, Tensor) else value.shape
  278. )
  279. node._device = value.device
  280. if hasattr(value, "_qparams") and value._qparams is not None:
  281. node._qparams = value.qparams
  282. @classmethod
  283. def wrap(cls, value, node):
  284. if isinstance(value, (NodeMixin, RawTensor)):
  285. if isinstance(node, Node):
  286. if isinstance(value, RawTensor):
  287. cls._record_tensornode_property(node, value)
  288. if isinstance(value, NodeMixin):
  289. value._record_wrapped_nodes(node)
  290. setattr(value, "_NodeMixin__node", node)
  291. else:
  292. assert callable(node)
  293. n = node()
  294. assert isinstance(n, Node)
  295. if isinstance(value, RawTensor):
  296. cls._record_tensornode_property(n, value)
  297. if isinstance(value, NodeMixin):
  298. value._record_wrapped_nodes(n)
  299. setattr(value, "_NodeMixin__node", n)
  300. @classmethod
  301. def wrap_safe(cls, value, node):
  302. assert isinstance(value, (NodeMixin, RawTensor))
  303. if isinstance(value, RawTensor):
  304. cls._record_tensornode_property(node, value)
  305. setattr(value, "_NodeMixin__node", node)
  306. if isinstance(value, NodeMixin):
  307. value._record_wrapped_nodes(node)
  308. @classmethod
  309. def get(cls, value, *default):
  310. return getattr(value, "_NodeMixin__node", *default)
  311. @classmethod
  312. def get_wrapped_type(cls, value):
  313. if isinstance(value, RawTensor):
  314. return TensorNode
  315. if isinstance(value, (Module, NodeMixin)):
  316. return ModuleNode
  317. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。