You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 7.5 kB


  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import abc
  10. import weakref
  11. from typing import Any, Dict, List, Tuple, Type
  12. import numpy
  13. from ...core._imperative_rt.core2 import Tensor as RawTensor
  14. from ...module import Module
  15. from ...tensor import Tensor
  16. class Node:
  17. """
  18. ``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method. They are inputs/outputs of Expr(the operations on variables).
  19. param expr: the Expr which produces the node
  20. param name: the name of the node
  21. """
  22. expr = None
  23. __total_id = 0
  24. _id = None
  25. _top_graph = None # type: weakref.ReferenceType
  26. _name = None
  27. _orig_name = None
  28. _format_spec = ""
  29. def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
  30. self.expr = expr
  31. self.users = [] # List[Expr]
  32. self._id = Node.__total_id
  33. Node.__total_id += 1
  34. self._name = name
  35. self._orig_name = orig_name
  36. self.actual_node = [] # type: List[Node]
  37. def __setstate__(self, d):
  38. self.__dict__ = d
  39. Node.__total_id = max(Node.__total_id, self._id) + 1
  40. def __repr__(self):
  41. format_spec = Node._format_spec
  42. return self.__format__(format_spec)
  43. def __format__(self, format_spec: str) -> str:
  44. if not format_spec:
  45. format_spec = Node._format_spec
  46. name = self._name
  47. if name is None:
  48. name = ""
  49. if format_spec in ["i", "p", "ip", "pi"]:
  50. if "p" in format_spec:
  51. graph = self.top_graph
  52. prefix_name = ""
  53. if graph is not None:
  54. prefix_name = graph._name
  55. if graph._prefix_name:
  56. prefix_name = "{}_{}".format(
  57. graph._prefix_name, prefix_name.lstrip("_")
  58. )
  59. if name:
  60. name = "_" + name.lstrip("_")
  61. name = "{}{}".format(prefix_name, name)
  62. if "i" in format_spec:
  63. if name:
  64. name = "_" + name.lstrip("_")
  65. name = "%{}{}".format(self._id, name)
  66. return name
  67. else:
  68. return name if name else ("%d" % self._id)
  69. @property
  70. def top_graph(self):
  71. if self._top_graph:
  72. return self._top_graph()
  73. return None
  74. @classmethod
  75. def set_format_spec(cls, str):
  76. old_format_spec = cls._format_spec
  77. cls._format_spec = str
  78. return old_format_spec
  79. class ModuleNode(Node):
  80. """
  81. ``ModuleNode`` represents the Module objects.
  82. Attributes:
  83. module_type: type of the Module correspending to the ModuleNode
  84. graph: the InternalGraph which will be interpreted when call Module's forward method
  85. attr_type_map: record the type of Module's attributes
  86. """
  87. module_type = Module # type: Type[Module]
  88. _owner = None # type: weakref.ReferenceType
  89. def __init__(self, expr: "Expr", name: str = None, orig_name: str = None):
  90. super().__init__(expr, name, orig_name)
  91. def __getstate__(self):
  92. return {
  93. "expr": self.expr,
  94. "users": self.users,
  95. "_id": self._id,
  96. "_name": self._name,
  97. "_orig_name": self._orig_name,
  98. "module_type": self.module_type,
  99. }
  100. @property
  101. def owner(self):
  102. if self._owner:
  103. return self._owner()
  104. return None
  105. class TensorNode(Node):
  106. """
  107. ``TensorNode`` represents the Tensor objects.
  108. """
  109. _shape = None # type: Tuple[int]
  110. _dtype = None # type: numpy.dtype
  111. _qparams = None
  112. _device = None
  113. _value = None # type: Tensor
  114. def __getstate__(self):
  115. return {
  116. "expr": self.expr,
  117. "users": self.users,
  118. "_id": self._id,
  119. "_qparams": self._qparams,
  120. "_shape": self._shape,
  121. "_dtype": self._dtype,
  122. "_device": self._device,
  123. "_name": self._name,
  124. "_orig_name": self._orig_name,
  125. }
  126. @property
  127. def shape(self):
  128. return self._shape
  129. @shape.setter
  130. def shape(self, shape):
  131. self._shape = shape
  132. @property
  133. def dtype(self):
  134. return self._dtype
  135. @dtype.setter
  136. def dtype(self, dtype):
  137. self._dtype = dtype
  138. @property
  139. def device(self):
  140. return self._device
  141. @device.setter
  142. def device(self, device):
  143. self._device = device
  144. @property
  145. def qparams(self):
  146. return self._qparams
  147. @qparams.setter
  148. def qparams(self, qparams):
  149. self._qparams = qparams
  150. @property
  151. def value(self):
  152. return self._value
  153. @value.setter
  154. def value(self, value):
  155. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  156. setattr(value, "_NodeMixin__node", None)
  157. self._value = value
  158. class NodeMixin(abc.ABC):
  159. __node = None
  160. @abc.abstractmethod
  161. def _record_wrapped_nodes(self, node):
  162. # record the nodes which had been bound to this NodeMixin
  163. pass
  164. @classmethod
  165. def _record_tensornode_property(cls, node, value):
  166. assert isinstance(node, TensorNode)
  167. assert isinstance(value, RawTensor)
  168. if isinstance(value, RawTensor):
  169. node._dtype = value.dtype
  170. node._shape = (
  171. value._tuple_shape if isinstance(value, Tensor) else value.shape
  172. )
  173. node._device = value.device
  174. if hasattr(value, "_qparams") and value._qparams is not None:
  175. node._qparams = value.qparams
  176. @classmethod
  177. def wrap(cls, value, node):
  178. if isinstance(value, (NodeMixin, RawTensor)):
  179. if isinstance(node, Node):
  180. if isinstance(value, RawTensor):
  181. cls._record_tensornode_property(node, value)
  182. if isinstance(value, NodeMixin):
  183. value._record_wrapped_nodes(node)
  184. setattr(value, "_NodeMixin__node", node)
  185. else:
  186. assert callable(node)
  187. n = node()
  188. assert isinstance(n, Node)
  189. if isinstance(value, RawTensor):
  190. cls._record_tensornode_property(n, value)
  191. if isinstance(value, NodeMixin):
  192. value._record_wrapped_nodes(n)
  193. setattr(value, "_NodeMixin__node", n)
  194. @classmethod
  195. def wrap_safe(cls, value, node):
  196. assert isinstance(value, (NodeMixin, RawTensor))
  197. if isinstance(value, RawTensor):
  198. cls._record_tensornode_property(node, value)
  199. setattr(value, "_NodeMixin__node", node)
  200. if isinstance(value, NodeMixin):
  201. value._record_wrapped_nodes(node)
  202. @classmethod
  203. def get(cls, value, *default):
  204. return getattr(value, "_NodeMixin__node", *default)
  205. @classmethod
  206. def get_wrapped_type(cls, value):
  207. if isinstance(value, RawTensor):
  208. return TensorNode
  209. if isinstance(value, (Module, NodeMixin)):
  210. return ModuleNode
  211. return Node

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台