You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import weakref
  10. from typing import Any, Dict, List, Tuple, Type
  11. import numpy
  12. from .. import get_logger
  13. from ..core._imperative_rt.core2 import Tensor as RawTensor
  14. from ..module import Module
  15. from ..tensor import Tensor
# Module-level logger named after this module. Nothing in the visible code
# references it; presumably it is used elsewhere in this module (verify).
logger = get_logger(__name__)
  17. class Node:
  18. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  19. They are inputs/outputs of Expr (the operations on variables).
  20. """
  21. expr = None # type: Expr
  22. r"""The Expr which produces the Node."""
  23. __total_id = 0 # type: int
  24. _id = None # type: int
  25. _top_graph = None # type: weakref.ReferenceType
  26. _name = None # type: str
  27. _orig_name = None # type: str
  28. _format_spec = "" # type: str
  29. def __init__(self, expr: "Expr", name: str, orig_name: str):
  30. self.expr = expr
  31. self.users = [] # List[Expr]
  32. self._id = Node.__total_id
  33. Node.__total_id += 1
  34. self._name = name
  35. self._orig_name = orig_name
  36. self.actual_node = [] # type: List[Node]
  37. def __repr__(self):
  38. format_spec = Node._format_spec
  39. return self.__format__(format_spec)
  40. def __format__(self, format_spec: str) -> str:
  41. if not format_spec:
  42. format_spec = Node._format_spec
  43. name = self._name
  44. if name is None:
  45. name = ""
  46. if format_spec in ["i", "p", "ip", "pi"]:
  47. if "p" in format_spec:
  48. graph = self.top_graph
  49. prefix_name = ""
  50. if graph is not None:
  51. prefix_name = graph._name
  52. if graph._prefix_name:
  53. prefix_name = "{}_{}".format(
  54. graph._prefix_name, prefix_name.lstrip("_")
  55. )
  56. if name:
  57. name = "_" + name.lstrip("_")
  58. name = "{}{}".format(prefix_name, name)
  59. if "i" in format_spec:
  60. if name:
  61. name = "_" + name.lstrip("_")
  62. name = "%{}{}".format(self._id, name)
  63. return name
  64. else:
  65. return name if name else ("%d" % self._id)
  66. @property
  67. def name(self):
  68. r"""Return the name of this Node."""
  69. return self._name
  70. @name.setter
  71. def name(self, new_name: str):
  72. graph = self.top_graph
  73. assert graph is not None, "The parent graph of this Node cannot be None."
  74. assert new_name not in graph._used_names, (
  75. "The name(%s) is already in use. Please try a different one again."
  76. % (new_name)
  77. )
  78. new_name = graph._create_unique_name(new_name)
  79. self._name = new_name
  80. self._orig_name = new_name
  81. @property
  82. def top_graph(self):
  83. r"""Get the parent graph of this Node."""
  84. if self._top_graph:
  85. return self._top_graph()
  86. return None
  87. @classmethod
  88. def _set_format_spec(cls, str):
  89. old_format_spec = cls._format_spec
  90. cls._format_spec = str
  91. return old_format_spec
  92. @classmethod
  93. def _get_next_id(cls):
  94. return cls.__total_id
  95. @classmethod
  96. def _set_next_id(cls, id: int = 0):
  97. assert isinstance(id, int)
  98. cls.__total_id = id
  99. class ModuleNode(Node):
  100. r"""``ModuleNode`` represents the Module objects."""
  101. module_type = Module # type: Type[Module]
  102. r"""The type of the Module correspending to the ModuleNode."""
  103. _owner = None # type: weakref.ReferenceType
  104. def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
  105. super().__init__(expr, name, orig_name)
  106. def __getstate__(self):
  107. return {
  108. "expr": self.expr,
  109. "users": self.users,
  110. "_id": self._id,
  111. "_name": self._name,
  112. "_orig_name": self._orig_name,
  113. "module_type": self.module_type,
  114. }
  115. @property
  116. def owner(self):
  117. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  118. Returns:
  119. An :calss:`~.Module`.
  120. """
  121. if self._owner:
  122. return self._owner()
  123. return None
  124. class TensorNode(Node):
  125. r"""``TensorNode`` represents the Tensor objects."""
  126. _shape = None # type: Tuple[int]
  127. _dtype = None # type: numpy.dtype
  128. _qparams = None
  129. _device = None
  130. _value = None # type: Tensor
  131. def __getstate__(self):
  132. return {
  133. "expr": self.expr,
  134. "users": self.users,
  135. "_id": self._id,
  136. "_qparams": self._qparams,
  137. "_shape": self._shape,
  138. "_dtype": self._dtype,
  139. "_device": self._device,
  140. "_name": self._name,
  141. "_orig_name": self._orig_name,
  142. }
  143. @property
  144. def shape(self):
  145. r"""Get the shape of this Node."""
  146. return self._shape
  147. @shape.setter
  148. def shape(self, shape):
  149. self._shape = shape
  150. @property
  151. def dtype(self):
  152. r"""Get the dtype of this Node."""
  153. return self._dtype
  154. @dtype.setter
  155. def dtype(self, dtype):
  156. self._dtype = dtype
  157. @property
  158. def device(self):
  159. r"""Get the device of this Node pointed Tensor."""
  160. return self._device
  161. @device.setter
  162. def device(self, device):
  163. self._device = device
  164. @property
  165. def qparams(self):
  166. r"""Get the :calss:`QParams` of this Node."""
  167. return self._qparams
  168. @qparams.setter
  169. def qparams(self, qparams):
  170. self._qparams = qparams
  171. @property
  172. def value(self):
  173. r"""Get the bound Tensor of this Node."""
  174. return self._value
  175. @value.setter
  176. def value(self, value):
  177. r"""Bind a Tensor to this Node.
  178. Args:
  179. value: A :class:`Tensor`.
  180. """
  181. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  182. setattr(value, "_NodeMixin__node", None)
  183. self._value = value
  184. class NodeMixin(abc.ABC):
  185. __node = None
  186. @abc.abstractmethod
  187. def _record_wrapped_nodes(self, node):
  188. # record the nodes which had been bound to this NodeMixin
  189. pass
  190. @classmethod
  191. def _record_tensornode_property(cls, node, value):
  192. assert isinstance(node, TensorNode)
  193. assert isinstance(value, RawTensor)
  194. if isinstance(value, RawTensor):
  195. node._dtype = value.dtype
  196. node._shape = (
  197. value._tuple_shape if isinstance(value, Tensor) else value.shape
  198. )
  199. node._device = value.device
  200. if hasattr(value, "_qparams") and value._qparams is not None:
  201. node._qparams = value.qparams
  202. @classmethod
  203. def wrap(cls, value, node):
  204. if isinstance(value, (NodeMixin, RawTensor)):
  205. if isinstance(node, Node):
  206. if isinstance(value, RawTensor):
  207. cls._record_tensornode_property(node, value)
  208. if isinstance(value, NodeMixin):
  209. value._record_wrapped_nodes(node)
  210. setattr(value, "_NodeMixin__node", node)
  211. else:
  212. assert callable(node)
  213. n = node()
  214. assert isinstance(n, Node)
  215. if isinstance(value, RawTensor):
  216. cls._record_tensornode_property(n, value)
  217. if isinstance(value, NodeMixin):
  218. value._record_wrapped_nodes(n)
  219. setattr(value, "_NodeMixin__node", n)
  220. @classmethod
  221. def wrap_safe(cls, value, node):
  222. assert isinstance(value, (NodeMixin, RawTensor))
  223. if isinstance(value, RawTensor):
  224. cls._record_tensornode_property(node, value)
  225. setattr(value, "_NodeMixin__node", node)
  226. if isinstance(value, NodeMixin):
  227. value._record_wrapped_nodes(node)
  228. @classmethod
  229. def get(cls, value, *default):
  230. return getattr(value, "_NodeMixin__node", *default)
  231. @classmethod
  232. def get_wrapped_type(cls, value):
  233. if isinstance(value, RawTensor):
  234. return TensorNode
  235. if isinstance(value, (Module, NodeMixin)):
  236. return ModuleNode
  237. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。