You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

node.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import weakref
  10. from typing import Any, Dict, List, Tuple, Type
  11. import numpy
  12. from .. import get_logger
  13. from ..core._imperative_rt.core2 import Tensor as RawTensor
  14. from ..module import Module
  15. from ..tensor import Tensor
# Module-level logger named after this module, created through the
# package-level ``get_logger`` helper imported above.
logger = get_logger(__name__)
  17. class Node:
  18. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  19. They are inputs/outputs of Expr (the operations on variables).
  20. """
  21. expr = None # type: Expr
  22. r"""The Expr which produces the Node."""
  23. __total_id = 0 # type: int
  24. _id = None # type: int
  25. _top_graph = None # type: weakref.ReferenceType
  26. _format_spec = "" # type: str
  27. def __init__(self, expr, name: str, qualname: str):
  28. self.expr = expr
  29. self.users = [] # List[Expr]
  30. self._id = Node.__total_id
  31. Node.__total_id += 1
  32. self._name = name
  33. self._qualname = qualname
  34. self.actual_node = [] # type: List[Node]
  35. def __repr__(self):
  36. format_spec = Node._format_spec
  37. return self.__format__(format_spec)
  38. def __format__(self, format_spec: str) -> str:
  39. if not format_spec:
  40. format_spec = Node._format_spec
  41. name = self._name
  42. if name is None:
  43. name = ""
  44. if format_spec in ["i", "p", "ip", "pi"]:
  45. if "p" in format_spec:
  46. prefix_name = self.top_graph._name
  47. name = "{}_{}".format(prefix_name, name)
  48. if "i" in format_spec:
  49. name = "%{}_{}".format(self._id, name)
  50. return name
  51. else:
  52. return name if name else ("%d" % self._id)
  53. @property
  54. def name(self):
  55. r"""Return the name of this Node."""
  56. return self._name
  57. @name.setter
  58. def name(self, new_name: str):
  59. r"""Set a new name to this Node."""
  60. graph = self.top_graph
  61. assert graph is not None, "The parent graph of this Node cannot be None."
  62. assert new_name not in graph._namespace.used_names, (
  63. "The name(%s) is already in use. Please try a different one again."
  64. % (new_name)
  65. )
  66. self._name = graph._namespace.create_unique_name(new_name, self)
  67. @property
  68. def qualname(self):
  69. r"""Get the `qualname` of this Node. The `qualname` can be used to get the
  70. submodule from the traced Module or Module.
  71. Example:
  72. .. code-block::
  73. import megengine.module as M
  74. import megengine.functional as F
  75. import megengine.traced_module as tm
  76. import megengine as mge
  77. class block(M.Module):
  78. def __init__(self):
  79. super().__init__()
  80. self.param = mge.Tensor([1.])
  81. self.relu = M.ReLU()
  82. def forward(self, x):
  83. x = x + self.param
  84. return self.relu(F.relu(x))
  85. class module(M.Module):
  86. def __init__(self):
  87. super().__init__()
  88. self.block = block()
  89. def forward(self, x):
  90. x = self.block(x)
  91. return x
  92. net = module()
  93. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  94. traced_net = traced_net.flatten()
  95. out_node = traced_net.graph.outputs[0]
  96. # qualname : "module.block.relu.[out]"
  97. qualname = out_node.qualname
  98. # qualname : "block.relu"
  99. qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]
  100. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  101. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  102. """
  103. return self._qualname
  104. @property
  105. def top_graph(self):
  106. r"""Get the parent graph of this Node."""
  107. if self._top_graph:
  108. return self._top_graph()
  109. return None
  110. @classmethod
  111. def _set_format_spec(cls, str):
  112. old_format_spec = cls._format_spec
  113. cls._format_spec = str
  114. return old_format_spec
  115. @classmethod
  116. def _get_next_id(cls):
  117. return cls.__total_id
  118. @classmethod
  119. def _set_next_id(cls, id: int = 0):
  120. assert isinstance(id, int)
  121. cls.__total_id = id
  122. class ModuleNode(Node):
  123. r"""``ModuleNode`` represents the Module objects."""
  124. module_type = Module # type: Type[Module]
  125. r"""The type of the Module correspending to the ModuleNode."""
  126. _owner = None # type: weakref.ReferenceType
  127. def __init__(self, expr, name: str = None, qualname: str = None):
  128. super().__init__(expr, name, qualname)
  129. def __getstate__(self):
  130. return {
  131. "expr": self.expr,
  132. "users": self.users,
  133. "_id": self._id,
  134. "_name": self._name,
  135. "_qualname": self._qualname,
  136. "module_type": self.module_type,
  137. }
  138. def __setstate__(self, state):
  139. if "_orig_name" in state:
  140. state["_qualname"] = state.pop("_orig_name")
  141. self.__dict__.update(state)
  142. @property
  143. def owner(self):
  144. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  145. """
  146. if self._owner:
  147. return self._owner()
  148. return None
  149. class TensorNode(Node):
  150. r"""``TensorNode`` represents the Tensor objects."""
  151. _shape = None # type: Tuple[int]
  152. _dtype = None # type: numpy.dtype
  153. _qparams = None
  154. _device = None
  155. _value = None # type: Tensor
  156. def __getstate__(self):
  157. return {
  158. "expr": self.expr,
  159. "users": self.users,
  160. "_id": self._id,
  161. "_qparams": self._qparams,
  162. "_shape": self._shape,
  163. "_dtype": self._dtype,
  164. "_device": self._device,
  165. "_name": self._name,
  166. "_qualname": self._qualname,
  167. }
  168. def __setstate__(self, state):
  169. if "_orig_name" in state:
  170. qualname = state.pop("_orig_name")
  171. modulepath, comma, qualname = qualname.rpartition(".")
  172. expr_name = state["expr"].__class__.__name__
  173. if expr_name not in ["GetAttr"]:
  174. qualname = "[{}]".format(qualname)
  175. if comma:
  176. qualname = "{}.{}".format(modulepath, qualname)
  177. state["_qualname"] = qualname
  178. self.__dict__.update(state)
  179. @property
  180. def shape(self):
  181. r"""Get the shape of this Node."""
  182. return self._shape
  183. @shape.setter
  184. def shape(self, shape):
  185. self._shape = shape
  186. @property
  187. def dtype(self):
  188. r"""Get the dtype of this Node."""
  189. return self._dtype
  190. @dtype.setter
  191. def dtype(self, dtype):
  192. self._dtype = dtype
  193. @property
  194. def device(self):
  195. r"""Get the device of this Node pointed Tensor."""
  196. return self._device
  197. @device.setter
  198. def device(self, device):
  199. self._device = device
  200. @property
  201. def qparams(self):
  202. r"""Get the :class:`QParams` of this Node."""
  203. return self._qparams
  204. @qparams.setter
  205. def qparams(self, qparams):
  206. self._qparams = qparams
  207. @property
  208. def value(self):
  209. r"""Get the bound Tensor of this Node."""
  210. return self._value
  211. @value.setter
  212. def value(self, value):
  213. r"""Bind a :class:`Tensor` to this Node."""
  214. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  215. setattr(value, "_NodeMixin__node", None)
  216. self._value = value
  217. class NodeMixin(abc.ABC):
  218. __node = None
  219. @abc.abstractmethod
  220. def _record_wrapped_nodes(self, node):
  221. # record the nodes which had been bound to this NodeMixin
  222. pass
  223. @classmethod
  224. def _record_tensornode_property(cls, node, value):
  225. assert isinstance(node, TensorNode)
  226. assert isinstance(value, RawTensor)
  227. if isinstance(value, RawTensor):
  228. node._dtype = value.dtype
  229. node._shape = (
  230. value._tuple_shape if isinstance(value, Tensor) else value.shape
  231. )
  232. node._device = value.device
  233. if hasattr(value, "_qparams") and value._qparams is not None:
  234. node._qparams = value.qparams
  235. @classmethod
  236. def wrap(cls, value, node):
  237. if isinstance(value, (NodeMixin, RawTensor)):
  238. if isinstance(node, Node):
  239. if isinstance(value, RawTensor):
  240. cls._record_tensornode_property(node, value)
  241. if isinstance(value, NodeMixin):
  242. value._record_wrapped_nodes(node)
  243. setattr(value, "_NodeMixin__node", node)
  244. else:
  245. assert callable(node)
  246. n = node()
  247. assert isinstance(n, Node)
  248. if isinstance(value, RawTensor):
  249. cls._record_tensornode_property(n, value)
  250. if isinstance(value, NodeMixin):
  251. value._record_wrapped_nodes(n)
  252. setattr(value, "_NodeMixin__node", n)
  253. @classmethod
  254. def wrap_safe(cls, value, node):
  255. assert isinstance(value, (NodeMixin, RawTensor))
  256. if isinstance(value, RawTensor):
  257. cls._record_tensornode_property(node, value)
  258. setattr(value, "_NodeMixin__node", node)
  259. if isinstance(value, NodeMixin):
  260. value._record_wrapped_nodes(node)
  261. @classmethod
  262. def get(cls, value, *default):
  263. return getattr(value, "_NodeMixin__node", *default)
  264. @classmethod
  265. def get_wrapped_type(cls, value):
  266. if isinstance(value, RawTensor):
  267. return TensorNode
  268. if isinstance(value, (Module, NodeMixin)):
  269. return ModuleNode
  270. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。