You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import weakref
  10. from typing import Any, Dict, List, Tuple, Type
  11. import numpy
  12. from ..core._imperative_rt.core2 import Tensor as RawTensor
  13. from ..module import Module
  14. from ..tensor import Tensor
  class Node:
      """
      ``Node`` stands for a variable (a Tensor, a Module or any other Python
      object) used inside a Module's forward method. Nodes are the inputs and
      outputs of ``Expr`` objects, i.e. of the operations applied to variables.

      :param expr: the ``Expr`` that produces this node.
      :param name: the name of this node, or ``None``.
      :param orig_name: the original name of this node, or ``None``.
      """

      expr = None
      # Class-wide counter (name-mangled to ``_Node__total_id``) that hands out
      # a fresh ``_id`` to every newly constructed node.
      __total_id = 0
      _id = None
      _top_graph = None  # type: weakref.ReferenceType
      _name = None
      _orig_name = None
      # Default spec used by ``__repr__`` and by ``__format__`` with an empty spec.
      _format_spec = ""

      def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
          self.expr = expr
          self.users = []  # List[Expr]
          # Claim the next id and advance the shared counter in one step.
          self._id, Node.__total_id = Node.__total_id, Node.__total_id + 1
          self._name = name
          self._orig_name = orig_name
          self.actual_node = []  # type: List[Node]

      def __setstate__(self, d):
          """Restore pickled state ``d`` and keep the id counter ahead of it."""
          self.__dict__ = d
          # Move the counter past the restored id so nodes created after
          # unpickling never receive the same id.
          restored = self._id
          Node.__total_id = 1 + max(Node.__total_id, restored)

      def __repr__(self):
          return self.__format__(Node._format_spec)

      def __format__(self, format_spec: str) -> str:
          """
          Render the node according to ``format_spec``.

          * an empty spec falls back to the class-wide ``Node._format_spec``;
          * ``"p"`` prefixes the name with the owning graph's name (itself
            prefixed by the graph's ``_prefix_name`` when that is set);
          * ``"i"`` prefixes the result with ``"%<id>"``;
          * ``"ip"`` / ``"pi"`` apply the graph prefix first, then the id;
          * any other spec yields the bare name, or the id if the name is empty.
          """
          spec = format_spec if format_spec else Node._format_spec
          label = "" if self._name is None else self._name

          if spec not in ["i", "p", "ip", "pi"]:
              return label if label else ("%d" % self._id)

          if "p" in spec:
              owner = self.top_graph
              prefix = ""
              if owner is not None:
                  prefix = owner._name
                  if owner._prefix_name:
                      prefix = "{}_{}".format(owner._prefix_name, prefix.lstrip("_"))
              # Normalise the name to exactly one leading underscore before joining.
              if label:
                  label = "_" + label.lstrip("_")
              label = "{}{}".format(prefix, label)

          if "i" in spec:
              if label:
                  label = "_" + label.lstrip("_")
              label = "%{}{}".format(self._id, label)

          return label

      @property
      def top_graph(self):
          """
          The graph this node belongs to, obtained by calling ``_top_graph``
          (a weak reference per its type comment); ``None`` when unset.
          """
          ref = self._top_graph
          return ref() if ref else None

      @classmethod
      def set_format_spec(cls, str):
          """
          Replace the class-wide default format spec and return the previous one.

          The parameter is named ``str`` (shadowing the builtin) for
          backward compatibility with keyword callers.
          """
          previous, cls._format_spec = cls._format_spec, str
          return previous
  class ModuleNode(Node):
      """
      ``ModuleNode`` represents a Module object.

      Attributes:
          module_type: the type of the Module corresponding to this node.
          graph: the InternalGraph that is interpreted when the Module's
              forward method is called.
          attr_type_map: records the types of the Module's attributes.

      NOTE(review): ``graph`` and ``attr_type_map`` are not defined in this
      class; presumably they are assigned elsewhere — confirm.
      """

      module_type = Module  # type: Type[Module]
      _owner = None  # type: weakref.ReferenceType

      def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
          super().__init__(expr, name, orig_name)

      def __getstate__(self):
          # Only these fields are pickled; ``_owner``, ``_top_graph`` and
          # ``actual_node`` are not part of the state.
          return dict(
              expr=self.expr,
              users=self.users,
              _id=self._id,
              _name=self._name,
              _orig_name=self._orig_name,
              module_type=self.module_type,
          )

      @property
      def owner(self):
          """
          The object obtained by calling ``_owner`` (a weak reference per its
          type comment), or ``None`` when no owner has been set.
          """
          ref = self._owner
          return ref() if ref else None
  class TensorNode(Node):
      """
      ``TensorNode`` represents a Tensor object.

      It carries the tensor's metadata (shape, dtype, device, quantization
      params) and can optionally hold a concrete ``value``.
      """

      _shape = None  # type: Tuple[int]
      _dtype = None  # type: numpy.dtype
      _qparams = None
      _device = None
      _value = None  # type: Tensor

      def __getstate__(self):
          # ``_value`` is not part of the pickled state, so an unpickled node
          # falls back to the class default ``None``.
          fields = (
              "expr",
              "users",
              "_id",
              "_qparams",
              "_shape",
              "_dtype",
              "_device",
              "_name",
              "_orig_name",
          )
          return {field: getattr(self, field) for field in fields}

      @property
      def shape(self):
          """The recorded shape of the tensor (``None`` until set)."""
          return self._shape

      @shape.setter
      def shape(self, shape):
          self._shape = shape

      @property
      def dtype(self):
          """The recorded dtype of the tensor (``None`` until set)."""
          return self._dtype

      @dtype.setter
      def dtype(self, dtype):
          self._dtype = dtype

      @property
      def device(self):
          """The recorded device of the tensor (``None`` until set)."""
          return self._device

      @device.setter
      def device(self, device):
          self._device = device

      @property
      def qparams(self):
          """The recorded quantization parameters (``None`` until set)."""
          return self._qparams

      @qparams.setter
      def qparams(self, qparams):
          self._qparams = qparams

      @property
      def value(self):
          """The concrete value attached to this node, if any."""
          return self._value

      @value.setter
      def value(self, value):
          # A raw tensor that is currently bound to a node through NodeMixin is
          # unbound before being stored as this node's value.
          if isinstance(value, RawTensor):
              if NodeMixin.get(value, None) is not None:
                  setattr(value, "_NodeMixin__node", None)
          self._value = value
  class NodeMixin(abc.ABC):
      """
      Helper base for binding objects (``NodeMixin`` instances or raw tensors)
      to the ``Node`` that describes them.

      The binding is stored on the object itself in the name-mangled
      attribute ``_NodeMixin__node``.
      """

      __node = None

      @abc.abstractmethod
      def _record_wrapped_nodes(self, node):
          """Record ``node`` as one of the nodes that have been bound to this object."""

      @classmethod
      def _record_tensornode_property(cls, node, value):
          """
          Copy dtype, shape and device of ``value`` onto ``node``, plus its
          quantization params when ``value`` has non-``None`` ``_qparams``.

          :param node: the ``TensorNode`` to update.
          :param value: the raw tensor to read from.
          """
          assert isinstance(node, TensorNode)
          assert isinstance(value, RawTensor)
          # Asserts are stripped under ``python -O``; keep an explicit guard so
          # non-tensors are still ignored in that mode.
          if not isinstance(value, RawTensor):
              return
          node._dtype = value.dtype
          # ``Tensor`` exposes ``_tuple_shape``; presumably a plain-tuple form of
          # its shape — other raw tensors fall back to ``.shape``.
          node._shape = value._tuple_shape if isinstance(value, Tensor) else value.shape
          node._device = value.device
          if hasattr(value, "_qparams") and value._qparams is not None:
              node._qparams = value.qparams

      @classmethod
      def wrap(cls, value, node):
          """
          Bind ``value`` to ``node`` when ``value`` is a ``NodeMixin`` or a raw
          tensor; any other ``value`` is left untouched.

          :param value: the object to bind.
          :param node: a ``Node``, or a zero-argument callable returning one. The
              callable is only invoked when ``value`` is actually bindable.
          """
          if not isinstance(value, (NodeMixin, RawTensor)):
              return
          if not isinstance(node, Node):
              # Materialise the node lazily, now that we know it is needed.
              assert callable(node)
              node = node()
              assert isinstance(node, Node)
          if isinstance(value, RawTensor):
              cls._record_tensornode_property(node, value)
          if isinstance(value, NodeMixin):
              value._record_wrapped_nodes(node)
          setattr(value, "_NodeMixin__node", node)

      @classmethod
      def wrap_safe(cls, value, node):
          """
          Bind ``value`` to the already-constructed ``node``; unlike ``wrap``,
          ``value`` is required to be a ``NodeMixin`` or raw tensor.
          """
          assert isinstance(value, (NodeMixin, RawTensor))
          if isinstance(value, RawTensor):
              cls._record_tensornode_property(node, value)
          # NOTE(review): here the binding is set *before* _record_wrapped_nodes
          # runs, whereas ``wrap`` sets it after — confirm whether order matters.
          setattr(value, "_NodeMixin__node", node)
          if isinstance(value, NodeMixin):
              value._record_wrapped_nodes(node)

      @classmethod
      def get(cls, value, *default):
          """
          Return the node bound to ``value``. With no ``default`` this raises
          ``AttributeError`` for objects that have never been bound (and are not
          ``NodeMixin`` instances, which inherit the ``None`` class default).
          """
          return getattr(value, "_NodeMixin__node", *default)

      @classmethod
      def get_wrapped_type(cls, value):
          """Return the ``Node`` subclass used to describe ``value``."""
          if isinstance(value, RawTensor):
              return TensorNode
          if isinstance(value, (Module, NodeMixin)):
              return ModuleNode
          return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。