You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import weakref
  10. from typing import Any, Dict, List, Tuple, Type
  11. import numpy
  12. from .. import get_logger
  13. from ..core._imperative_rt.core2 import Tensor as RawTensor
  14. from ..module import Module
  15. from ..tensor import Tensor
  16. logger = get_logger(__name__)  # module-level logger, named after this module via __name__
  class Node:
      r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.

      They are inputs/outputs of Expr (the operations on variables).
      """

      expr = None  # type: Expr
      r"""The Expr which produces the Node."""
      # Global id counter shared by every Node subclass. It is always accessed
      # as ``Node.__total_id`` (never through ``cls``/``self``) so that a
      # subclass can never end up holding a private copy that ``__init__``
      # would ignore.
      __total_id = 0  # type: int
      _id = None  # type: int
      _top_graph = None  # type: weakref.ReferenceType
      # Default spec used by ``__repr__`` and by ``__format__`` when it is
      # given an empty spec; likewise always accessed through ``Node``.
      _format_spec = ""  # type: str

      def __init__(self, expr, name: str, qualname: str):
          r"""Create a node and give it the next global id.

          Args:
              expr: the Expr that produces this node.
              name: display name of the node (may be ``None``).
              qualname: qualified name of the node, see :attr:`qualname`.
          """
          self.expr = expr
          self.users = []  # List[Expr]
          self._id = Node.__total_id
          Node.__total_id += 1
          self._name = name
          self._qualname = qualname
          self.actual_node = []  # type: List[Node]

      def __repr__(self):
          return self.__format__(Node._format_spec)

      def __format__(self, format_spec: str) -> str:
          r"""Format the node's name.

          Spec letters (only the combinations ``i``, ``p``, ``ip``, ``pi``):

          * ``p`` prefixes the parent graph's name (skipped when the node has
            no parent graph);
          * ``i`` prefixes ``%<id>_``.

          Any other non-empty spec returns the name, or the id when the name is
          empty. An empty spec falls back to ``Node._format_spec``.
          """
          if not format_spec:
              format_spec = Node._format_spec
          name = self._name
          if name is None:
              name = ""
          if format_spec in ["i", "p", "ip", "pi"]:
              if "p" in format_spec:
                  graph = self.top_graph
                  # A detached node has no graph prefix; previously this raised
                  # AttributeError and made repr()/format() crash.
                  if graph is not None:
                      name = "{}_{}".format(graph._name, name)
              if "i" in format_spec:
                  name = "%{}_{}".format(self._id, name)
              return name
          return name if name else ("%d" % self._id)

      @property
      def name(self):
          r"""Return the name of this Node."""
          return self._name

      @name.setter
      def name(self, new_name: str):
          r"""Set a new name to this Node.

          Raises:
              AssertionError: if the node has no parent graph, or if
                  ``new_name`` is already in the graph namespace's
                  ``used_names``.
          """
          graph = self.top_graph
          assert graph is not None, "The parent graph of this Node cannot be None."
          assert new_name not in graph._namespace.used_names, (
              "The name(%s) is already in use. Please try a different one again."
              % (new_name)
          )
          # NOTE(review): presumably create_unique_name also reserves the name
          # in the namespace; the old name does not appear to be released.
          new_name = graph._namespace.create_unique_name(new_name)
          self._name = new_name

      @property
      def qualname(self):
          r"""Get the `qualname` of this Node. The `qualname` can be used to get the
          submodule from the traced Module or Module.

          Example:
              .. code-block::

                  import megengine.module as M
                  import megengine.functional as F
                  import megengine.traced_module as tm
                  import megengine as mge

                  class block(M.Module):
                      def __init__(self):
                          super().__init__()
                          self.param = mge.Tensor([1.])
                          self.relu = M.ReLU()

                      def forward(self, x):
                          x = x + self.param
                          return self.relu(F.relu(x))

                  class module(M.Module):
                      def __init__(self):
                          super().__init__()
                          self.block = block()

                      def forward(self, x):
                          x = self.block(x)
                          return x

                  net = module()
                  traced_net = tm.trace_module(net, mge.Tensor([0.]))
                  traced_net = traced_net.flatten()
                  out_node = traced_net.graph.outputs[0]

                  # qualname : "module.block.relu.[out]"
                  qualname = out_node.qualname
                  # qualname : "block.relu"
                  qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

                  assert qualname in list(map(lambda x: x[0], net.named_modules()))
                  assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
          """
          return self._qualname

      @property
      def top_graph(self):
          r"""Get the parent graph of this Node, or ``None`` if there is none."""
          if self._top_graph:
              return self._top_graph()
          return None

      @classmethod
      def _set_format_spec(cls, str):
          r"""Set the default format spec for all nodes and return the previous one.

          The spec is stored on ``Node`` itself: ``__repr__``/``__format__``
          read ``Node._format_spec``, so storing it on a subclass would have no
          effect. (The parameter name shadows the builtin but is kept for
          keyword-argument compatibility.)
          """
          old_format_spec = Node._format_spec
          Node._format_spec = str
          return old_format_spec

      @classmethod
      def _get_next_id(cls):
          r"""Return the id the next created node will receive."""
          return Node.__total_id

      @classmethod
      def _set_next_id(cls, id: int = 0):
          r"""Reset the global id counter used by ``__init__``.

          Always writes ``Node``'s counter, even when called on a subclass.
          """
          assert isinstance(id, int)
          Node.__total_id = id
  class ModuleNode(Node):
      r"""``ModuleNode`` represents the Module objects."""

      module_type = Module  # type: Type[Module]
      r"""The type of the Module corresponding to the ModuleNode."""
      # Weak reference to the Module this node stands for; read via ``owner``.
      _owner = None  # type: weakref.ReferenceType

      def __init__(self, expr, name: str = None, qualname: str = None):
          # Unlike on ``Node``, ``name`` and ``qualname`` are optional here.
          super().__init__(expr, name, qualname)

      def __getstate__(self):
          r"""Return the picklable state of this node.

          ``_owner`` and ``actual_node`` are not saved, so an unpickled
          ``ModuleNode`` reports ``owner`` as ``None``.
          """
          saved = ("expr", "users", "_id", "_name", "_qualname", "module_type")
          return {attr: getattr(self, attr) for attr in saved}

      def __setstate__(self, state):
          r"""Restore a node from the state produced by ``__getstate__``.

          A state that stores the qualname under ``_orig_name`` (presumably an
          older pickle format) has that key renamed to ``_qualname`` first.
          """
          if "_orig_name" in state:
              legacy_qualname = state.pop("_orig_name")
              state["_qualname"] = legacy_qualname
          vars(self).update(state)

      @property
      def owner(self):
          r"""Get the ``Module`` corresponding to this ``ModuleNode``.

          Returns ``None`` when no owner reference is set or when the weakly
          referenced Module no longer exists.
          """
          owner_ref = self._owner
          return owner_ref() if owner_ref else None
  class TensorNode(Node):
      r"""``TensorNode`` represents the Tensor objects."""

      _shape = None  # type: Tuple[int, ...]
      _dtype = None  # type: numpy.dtype
      _qparams = None
      _device = None
      _value = None  # type: Tensor

      def __getstate__(self):
          r"""Return the picklable state of this node.

          The bound tensor (``_value``) and ``actual_node`` are not included.
          """
          saved = (
              "expr",
              "users",
              "_id",
              "_qparams",
              "_shape",
              "_dtype",
              "_device",
              "_name",
              "_qualname",
          )
          return {attr: getattr(self, attr) for attr in saved}

      def __setstate__(self, state):
          r"""Restore a node from the state produced by ``__getstate__``.

          A legacy ``_orig_name`` entry of the form ``"<path>.<leaf>"`` is
          converted into ``_qualname``. The leaf is wrapped in square brackets
          unless the producing expr's class is named ``GetAttr``.
          """
          if "_orig_name" in state:
              module_path, sep, leaf = state.pop("_orig_name").rpartition(".")
              producer = state["expr"].__class__.__name__
              if producer != "GetAttr":
                  leaf = "[{}]".format(leaf)
              state["_qualname"] = "{}.{}".format(module_path, leaf) if sep else leaf
          self.__dict__.update(state)

      @property
      def shape(self):
          r"""Shape recorded on this Node (``None`` if never set)."""
          return self._shape

      @shape.setter
      def shape(self, shape):
          self._shape = shape

      @property
      def dtype(self):
          r"""Dtype recorded on this Node (``None`` if never set)."""
          return self._dtype

      @dtype.setter
      def dtype(self, dtype):
          self._dtype = dtype

      @property
      def device(self):
          r"""Device of the Tensor this Node points to (``None`` if never set)."""
          return self._device

      @device.setter
      def device(self, device):
          self._device = device

      @property
      def qparams(self):
          r"""The :class:`QParams` recorded on this Node (``None`` if never set)."""
          return self._qparams

      @qparams.setter
      def qparams(self, qparams):
          self._qparams = qparams

      @property
      def value(self):
          r"""The Tensor currently bound to this Node (``None`` if unbound)."""
          return self._value

      @value.setter
      def value(self, value):
          r"""Bind a :class:`Tensor` to this Node.

          If ``value`` is a ``RawTensor`` that is still linked to a node
          through ``_NodeMixin__node``, that link is cleared before binding.
          """
          linked = isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None
          if linked:
              setattr(value, "_NodeMixin__node", None)
          self._value = value
  class NodeMixin(abc.ABC):
      r"""Mixin for objects that remember the :class:`Node` they are bound to.

      The bound node lives in the ``_NodeMixin__node`` attribute. ``wrap`` and
      ``wrap_safe`` set that same attribute on ``RawTensor`` values too.
      """

      __node = None

      @abc.abstractmethod
      def _record_wrapped_nodes(self, node):
          # Subclasses keep track of every node that has been bound to them.
          pass

      @classmethod
      def _record_tensornode_property(cls, node, value):
          # Copy dtype, shape and device (plus qparams when present) from the
          # tensor ``value`` onto the TensorNode ``node``.
          assert isinstance(node, TensorNode)
          assert isinstance(value, RawTensor)
          if not isinstance(value, RawTensor):
              return
          node._dtype = value.dtype
          if isinstance(value, Tensor):
              node._shape = value._tuple_shape
          else:
              node._shape = value.shape
          node._device = value.device
          if hasattr(value, "_qparams") and value._qparams is not None:
              node._qparams = value.qparams

      @classmethod
      def wrap(cls, value, node):
          # ``node`` is either a Node or a zero-argument callable that returns
          # one; the callable is invoked only when ``value`` is a NodeMixin or a
          # RawTensor. Values of any other type are left untouched.
          if not isinstance(value, (NodeMixin, RawTensor)):
              return
          if not isinstance(node, Node):
              assert callable(node)
              node = node()
              assert isinstance(node, Node)
          if isinstance(value, RawTensor):
              cls._record_tensornode_property(node, value)
          if isinstance(value, NodeMixin):
              value._record_wrapped_nodes(node)
          setattr(value, "_NodeMixin__node", node)

      @classmethod
      def wrap_safe(cls, value, node):
          # ``node`` is used as given (no callable form). Unlike ``wrap``, the
          # node attribute is set *before* ``_record_wrapped_nodes`` runs.
          assert isinstance(value, (NodeMixin, RawTensor))
          if isinstance(value, RawTensor):
              cls._record_tensornode_property(node, value)
          setattr(value, "_NodeMixin__node", node)
          if isinstance(value, NodeMixin):
              value._record_wrapped_nodes(node)

      @classmethod
      def get(cls, value, *default):
          # Same contract as ``getattr``: without a default, a missing node
          # raises AttributeError.
          return getattr(value, "_NodeMixin__node", *default)

      @classmethod
      def get_wrapped_type(cls, value):
          # Pick the Node subclass that should represent ``value``.
          if isinstance(value, RawTensor):
              node_type = TensorNode
          elif isinstance(value, (Module, NodeMixin)):
              node_type = ModuleNode
          else:
              node_type = Node
          return node_type

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。