You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import weakref
  10. from typing import Any, Dict, List, Tuple, Type
  11. import numpy
  12. from .. import get_logger
  13. from ..core._imperative_rt.core2 import Tensor as RawTensor
  14. from ..module import Module
  15. from ..tensor import Tensor
# Module-level logger obtained from the package's ``get_logger`` helper,
# named after this module.
logger = get_logger(__name__)
  17. class Node:
  18. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  19. They are inputs/outputs of Expr (the operations on variables).
  20. """
  21. expr = None # type: Expr
  22. r"""The Expr which produces the Node."""
  23. __total_id = 0 # type: int
  24. _id = None # type: int
  25. _top_graph = None # type: weakref.ReferenceType
  26. _name = None # type: str
  27. _orig_name = None # type: str
  28. _format_spec = "" # type: str
  29. def __init__(self, expr, name: str, orig_name: str):
  30. self.expr = expr
  31. self.users = [] # List[Expr]
  32. self._id = Node.__total_id
  33. Node.__total_id += 1
  34. self._name = name
  35. self._orig_name = orig_name
  36. self.actual_node = [] # type: List[Node]
  37. def __repr__(self):
  38. format_spec = Node._format_spec
  39. return self.__format__(format_spec)
  40. def __format__(self, format_spec: str) -> str:
  41. if not format_spec:
  42. format_spec = Node._format_spec
  43. name = self._name
  44. if name is None:
  45. name = ""
  46. if format_spec in ["i", "p", "ip", "pi"]:
  47. if "p" in format_spec:
  48. graph = self.top_graph
  49. prefix_name = ""
  50. if graph is not None:
  51. prefix_name = graph._name
  52. if graph._prefix_name:
  53. prefix_name = "{}_{}".format(
  54. graph._prefix_name, prefix_name.lstrip("_")
  55. )
  56. if name:
  57. name = "_" + name.lstrip("_")
  58. name = "{}{}".format(prefix_name, name)
  59. if "i" in format_spec:
  60. if name:
  61. name = "_" + name.lstrip("_")
  62. name = "%{}{}".format(self._id, name)
  63. return name
  64. else:
  65. return name if name else ("%d" % self._id)
  66. @property
  67. def name(self):
  68. r"""Return the name of this Node."""
  69. return self._name
  70. @name.setter
  71. def name(self, new_name: str):
  72. graph = self.top_graph
  73. assert graph is not None, "The parent graph of this Node cannot be None."
  74. assert new_name not in graph._used_names, (
  75. "The name(%s) is already in use. Please try a different one again."
  76. % (new_name)
  77. )
  78. new_name = graph._create_unique_name(new_name)
  79. self._name = new_name
  80. self._orig_name = new_name
  81. @property
  82. def top_graph(self):
  83. r"""Get the parent graph of this Node."""
  84. if self._top_graph:
  85. return self._top_graph()
  86. return None
  87. @classmethod
  88. def _set_format_spec(cls, str):
  89. old_format_spec = cls._format_spec
  90. cls._format_spec = str
  91. return old_format_spec
  92. @classmethod
  93. def _get_next_id(cls):
  94. return cls.__total_id
  95. @classmethod
  96. def _set_next_id(cls, id: int = 0):
  97. assert isinstance(id, int)
  98. cls.__total_id = id
  99. class ModuleNode(Node):
  100. r"""``ModuleNode`` represents the Module objects."""
  101. module_type = Module # type: Type[Module]
  102. r"""The type of the Module correspending to the ModuleNode."""
  103. _owner = None # type: weakref.ReferenceType
  104. def __init__(self, expr, name: str = None, orig_name: str = None):
  105. super().__init__(expr, name, orig_name)
  106. def __getstate__(self):
  107. return {
  108. "expr": self.expr,
  109. "users": self.users,
  110. "_id": self._id,
  111. "_name": self._name,
  112. "_orig_name": self._orig_name,
  113. "module_type": self.module_type,
  114. }
  115. @property
  116. def owner(self):
  117. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  118. """
  119. if self._owner:
  120. return self._owner()
  121. return None
  122. class TensorNode(Node):
  123. r"""``TensorNode`` represents the Tensor objects."""
  124. _shape = None # type: Tuple[int]
  125. _dtype = None # type: numpy.dtype
  126. _qparams = None
  127. _device = None
  128. _value = None # type: Tensor
  129. def __getstate__(self):
  130. return {
  131. "expr": self.expr,
  132. "users": self.users,
  133. "_id": self._id,
  134. "_qparams": self._qparams,
  135. "_shape": self._shape,
  136. "_dtype": self._dtype,
  137. "_device": self._device,
  138. "_name": self._name,
  139. "_orig_name": self._orig_name,
  140. }
  141. @property
  142. def shape(self):
  143. r"""Get the shape of this Node."""
  144. return self._shape
  145. @shape.setter
  146. def shape(self, shape):
  147. self._shape = shape
  148. @property
  149. def dtype(self):
  150. r"""Get the dtype of this Node."""
  151. return self._dtype
  152. @dtype.setter
  153. def dtype(self, dtype):
  154. self._dtype = dtype
  155. @property
  156. def device(self):
  157. r"""Get the device of this Node pointed Tensor."""
  158. return self._device
  159. @device.setter
  160. def device(self, device):
  161. self._device = device
  162. @property
  163. def qparams(self):
  164. r"""Get the :class:`QParams` of this Node."""
  165. return self._qparams
  166. @qparams.setter
  167. def qparams(self, qparams):
  168. self._qparams = qparams
  169. @property
  170. def value(self):
  171. r"""Get the bound Tensor of this Node."""
  172. return self._value
  173. @value.setter
  174. def value(self, value):
  175. r"""Bind a :class:`Tensor` to this Node."""
  176. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  177. setattr(value, "_NodeMixin__node", None)
  178. self._value = value
  179. class NodeMixin(abc.ABC):
  180. __node = None
  181. @abc.abstractmethod
  182. def _record_wrapped_nodes(self, node):
  183. # record the nodes which had been bound to this NodeMixin
  184. pass
  185. @classmethod
  186. def _record_tensornode_property(cls, node, value):
  187. assert isinstance(node, TensorNode)
  188. assert isinstance(value, RawTensor)
  189. if isinstance(value, RawTensor):
  190. node._dtype = value.dtype
  191. node._shape = (
  192. value._tuple_shape if isinstance(value, Tensor) else value.shape
  193. )
  194. node._device = value.device
  195. if hasattr(value, "_qparams") and value._qparams is not None:
  196. node._qparams = value.qparams
  197. @classmethod
  198. def wrap(cls, value, node):
  199. if isinstance(value, (NodeMixin, RawTensor)):
  200. if isinstance(node, Node):
  201. if isinstance(value, RawTensor):
  202. cls._record_tensornode_property(node, value)
  203. if isinstance(value, NodeMixin):
  204. value._record_wrapped_nodes(node)
  205. setattr(value, "_NodeMixin__node", node)
  206. else:
  207. assert callable(node)
  208. n = node()
  209. assert isinstance(n, Node)
  210. if isinstance(value, RawTensor):
  211. cls._record_tensornode_property(n, value)
  212. if isinstance(value, NodeMixin):
  213. value._record_wrapped_nodes(n)
  214. setattr(value, "_NodeMixin__node", n)
  215. @classmethod
  216. def wrap_safe(cls, value, node):
  217. assert isinstance(value, (NodeMixin, RawTensor))
  218. if isinstance(value, RawTensor):
  219. cls._record_tensornode_property(node, value)
  220. setattr(value, "_NodeMixin__node", node)
  221. if isinstance(value, NodeMixin):
  222. value._record_wrapped_nodes(node)
  223. @classmethod
  224. def get(cls, value, *default):
  225. return getattr(value, "_NodeMixin__node", *default)
  226. @classmethod
  227. def get_wrapped_type(cls, value):
  228. if isinstance(value, RawTensor):
  229. return TensorNode
  230. if isinstance(value, (Module, NodeMixin)):
  231. return ModuleNode
  232. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。