You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

node.py 13 kB


  1. import abc
  2. import copy
  3. import weakref
  4. from importlib import import_module
  5. from typing import Any, Dict, List, Tuple, Type
  6. import numpy
  7. from .. import get_logger
  8. from ..core._imperative_rt.core2 import Tensor as RawTensor
  9. from ..module import Module
  10. from ..quantization.utils import QParams
  11. from ..tensor import Tensor
  12. from .module_tracer import active_module_tracer
  13. from .tm_config import _get_expr_checker
  14. from .utils import _check_obj_attr
# Module-level logger for this file, obtained from the package's ``get_logger`` helper.
logger = get_logger(__name__)
  16. class Node:
  17. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  18. They are inputs/outputs of Expr (the operations on variables).
  19. """
  20. expr = None # type: Expr
  21. r"""The Expr which produces the Node."""
  22. __total_id = 0 # type: int
  23. _id = None # type: int
  24. _top_graph = None # type: weakref.ReferenceType
  25. _format_spec = "" # type: str
  26. def __init__(self, expr, name: str, qualname: str):
  27. self.expr = expr
  28. self.users = [] # List[Expr]
  29. self._id = Node.__total_id
  30. Node.__total_id += 1
  31. self._name = name
  32. self._qualname = qualname
  33. self.actual_node = [] # type: List[Node]
  34. def __repr__(self):
  35. format_spec = Node._format_spec
  36. return self.__format__(format_spec)
  37. def __format__(self, format_spec: str) -> str:
  38. if not format_spec:
  39. format_spec = Node._format_spec
  40. name = self._name
  41. if name is None:
  42. name = ""
  43. if format_spec in ["i", "p", "ip", "pi"]:
  44. if "p" in format_spec:
  45. prefix_name = self.top_graph._name
  46. name = "{}_{}".format(prefix_name, name)
  47. if "i" in format_spec:
  48. name = "%{}_{}".format(self._id, name)
  49. return name
  50. else:
  51. return name if name else ("%d" % self._id)
  52. @property
  53. def name(self):
  54. r"""Return the name of this Node."""
  55. return self._name
  56. @name.setter
  57. def name(self, new_name: str):
  58. r"""Set a new name to this Node."""
  59. graph = self.top_graph
  60. assert graph is not None, "The parent graph of this Node cannot be None."
  61. assert graph._namespace.used_names.get(new_name, None) is None, (
  62. "The name(%s) is already in use. Please try a different one again."
  63. % (new_name)
  64. )
  65. graph._namespace.unassociate_name_with_obj(self)
  66. self._name = graph._namespace.create_unique_name(new_name, self)
  67. @property
  68. def qualname(self):
  69. r"""Get the `qualname` of this Node. The `qualname` can be used to get the
  70. submodule from the traced Module or Module.
  71. Example:
  72. .. code-block::
  73. import megengine.module as M
  74. import megengine.functional as F
  75. import megengine.traced_module as tm
  76. import megengine as mge
  77. class block(M.Module):
  78. def __init__(self):
  79. super().__init__()
  80. self.param = mge.Tensor([1.])
  81. self.relu = M.ReLU()
  82. def forward(self, x):
  83. x = x + self.param
  84. return self.relu(F.relu(x))
  85. class module(M.Module):
  86. def __init__(self):
  87. super().__init__()
  88. self.block = block()
  89. def forward(self, x):
  90. x = self.block(x)
  91. return x
  92. net = module()
  93. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  94. traced_net = traced_net.flatten()
  95. out_node = traced_net.graph.outputs[0]
  96. # qualname : "module.block.relu.[out]"
  97. qualname = out_node.qualname
  98. # qualname : "block.relu"
  99. qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]
  100. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  101. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  102. """
  103. return self._qualname
  104. @property
  105. def top_graph(self):
  106. r"""Get the parent graph of this Node."""
  107. if self._top_graph:
  108. return self._top_graph()
  109. return None
  110. @classmethod
  111. def _set_format_spec(cls, str):
  112. old_format_spec = cls._format_spec
  113. cls._format_spec = str
  114. return old_format_spec
  115. @classmethod
  116. def _get_next_id(cls):
  117. return cls.__total_id
  118. @classmethod
  119. def _set_next_id(cls, id: int = 0):
  120. assert isinstance(id, int)
  121. cls.__total_id = id
  122. def __copy__(self):
  123. cls = self.__class__
  124. result = cls.__new__(cls)
  125. result.__dict__.update(self.__dict__)
  126. return result
  127. def __deepcopy__(self, memo):
  128. cls = self.__class__
  129. result = cls.__new__(cls)
  130. state = {}
  131. memo[id(self)] = result
  132. for k, v in self.__dict__.items():
  133. if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
  134. state[k] = copy.deepcopy(v, memo)
  135. result.__dict__.update(state)
  136. return result
  137. class ModuleNode(Node):
  138. r"""``ModuleNode`` represents the Module objects."""
  139. module_type = Module # type: Type[Module]
  140. r"""The type of the Module correspending to the ModuleNode."""
  141. _owner = None # type: weakref.ReferenceType
  142. def __init__(self, expr, name: str = None, qualname: str = None):
  143. super().__init__(expr, name, qualname)
  144. def __getstate__(self):
  145. state = {
  146. "expr": self.expr,
  147. "users": self.users,
  148. "_id": self._id,
  149. "_name": self._name,
  150. "_qualname": self._qualname,
  151. "module_type": (self.module_type.__module__, self.module_type.__qualname__),
  152. }
  153. _check_obj_attr(state)
  154. return state
  155. def __setstate__(self, state):
  156. if "_orig_name" in state:
  157. state["_qualname"] = state.pop("_orig_name")
  158. self.__dict__.update(state)
  159. try:
  160. if isinstance(self.module_type, tuple):
  161. mname, classname = self.module_type
  162. mtype = getattr(import_module(mname), classname)
  163. self.module_type = mtype
  164. except Exception:
  165. pass
  166. @property
  167. def owner(self):
  168. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  169. """
  170. if self._owner:
  171. return self._owner()
  172. return None
  173. class TensorNode(Node):
  174. r"""``TensorNode`` represents the Tensor objects."""
  175. _shape = None # type: Tuple[int]
  176. _dtype = None # type: numpy.dtype
  177. _qparams = None # type: QParams
  178. _device = None
  179. _value = None # type: Tensor
  180. def __init__(
  181. self,
  182. expr,
  183. name: str = None,
  184. qualname: str = None,
  185. shape: Tuple[int] = None,
  186. dtype: numpy.dtype = None,
  187. qparams: QParams = None,
  188. ):
  189. super().__init__(expr, name, qualname)
  190. self._shape = shape
  191. self._dtype = dtype
  192. self._qparams = qparams
  193. def __getstate__(self):
  194. state = {
  195. "expr": self.expr,
  196. "users": self.users,
  197. "_id": self._id,
  198. "_qparams": self._qparams,
  199. "_shape": self._shape,
  200. "_dtype": self._dtype,
  201. "_device": self._device,
  202. "_name": self._name,
  203. "_qualname": self._qualname,
  204. }
  205. _check_obj_attr(state)
  206. return state
  207. def __setstate__(self, state):
  208. if "_orig_name" in state:
  209. qualname = state.pop("_orig_name")
  210. modulepath, comma, qualname = qualname.rpartition(".")
  211. expr_name = state["expr"].__class__.__name__
  212. if expr_name not in ["GetAttr"]:
  213. qualname = "[{}]".format(qualname)
  214. if comma:
  215. qualname = "{}.{}".format(modulepath, qualname)
  216. state["_qualname"] = qualname
  217. self.__dict__.update(state)
  218. @property
  219. def shape(self):
  220. r"""Get the shape of this Node."""
  221. return self._shape
  222. @shape.setter
  223. def shape(self, shape):
  224. self._shape = shape
  225. @property
  226. def dtype(self):
  227. r"""Get the dtype of this Node."""
  228. return self._dtype
  229. @dtype.setter
  230. def dtype(self, dtype):
  231. self._dtype = dtype
  232. @property
  233. def device(self):
  234. r"""Get the device of this Node pointed Tensor."""
  235. return self._device
  236. @device.setter
  237. def device(self, device):
  238. self._device = device
  239. @property
  240. def qparams(self):
  241. r"""Get the :class:`QParams` of this Node."""
  242. return self._qparams
  243. @qparams.setter
  244. def qparams(self, qparams):
  245. self._qparams = qparams
  246. @property
  247. def value(self):
  248. r"""Get the bound Tensor of this Node."""
  249. return self._value
  250. @value.setter
  251. def value(self, value):
  252. r"""Bind a :class:`Tensor` to this Node."""
  253. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  254. setattr(value, "_NodeMixin__node", None)
  255. self._value = value
  256. class NodeMixin(abc.ABC):
  257. __node = None
  258. @abc.abstractmethod
  259. def _record_wrapped_nodes(self, node):
  260. # record the nodes which had been bound to this NodeMixin
  261. pass
  262. @classmethod
  263. def _record_tensornode_property(cls, node, value):
  264. assert isinstance(node, TensorNode)
  265. assert isinstance(value, RawTensor)
  266. if isinstance(value, RawTensor):
  267. try:
  268. node._dtype = value.dtype
  269. except RuntimeError:
  270. node._dtype = None
  271. node._shape = (
  272. value._tuple_shape if isinstance(value, Tensor) else value.shape
  273. )
  274. node._device = value.device
  275. if hasattr(value, "_qparams") and value._qparams is not None:
  276. node._qparams = value.qparams
  277. @classmethod
  278. def wrap(cls, value, node):
  279. if isinstance(value, (NodeMixin, RawTensor)):
  280. if isinstance(node, Node):
  281. if isinstance(value, RawTensor):
  282. cls._record_tensornode_property(node, value)
  283. if isinstance(value, NodeMixin):
  284. value._record_wrapped_nodes(node)
  285. setattr(value, "_NodeMixin__node", node)
  286. if _get_expr_checker():
  287. if isinstance(value, RawTensor):
  288. active_module_tracer().checker.record_node2value(node, value)
  289. if isinstance(value, NodeMixin):
  290. active_module_tracer().checker.record_nodemixin(node, value)
  291. else:
  292. assert callable(node)
  293. n = node()
  294. assert isinstance(n, Node)
  295. if isinstance(value, RawTensor):
  296. cls._record_tensornode_property(n, value)
  297. if isinstance(value, NodeMixin):
  298. value._record_wrapped_nodes(n)
  299. setattr(value, "_NodeMixin__node", n)
  300. if _get_expr_checker():
  301. if isinstance(value, RawTensor):
  302. active_module_tracer().checker.record_node2value(n, value)
  303. if isinstance(value, NodeMixin):
  304. active_module_tracer().checker.record_nodemixin(n, value)
  305. @classmethod
  306. def wrap_safe(cls, value, node):
  307. assert isinstance(value, (NodeMixin, RawTensor))
  308. if isinstance(value, RawTensor):
  309. cls._record_tensornode_property(node, value)
  310. setattr(value, "_NodeMixin__node", node)
  311. if _get_expr_checker():
  312. if isinstance(value, RawTensor):
  313. active_module_tracer().checker.record_node2value(node, value)
  314. if isinstance(value, NodeMixin):
  315. active_module_tracer().checker.record_nodemixin(node, value)
  316. if isinstance(value, NodeMixin):
  317. value._record_wrapped_nodes(node)
  318. @classmethod
  319. def clear_node(cls, value):
  320. if hasattr(value, "_NodeMixin__node"):
  321. delattr(value, "_NodeMixin__node")
  322. @classmethod
  323. def get(cls, value, *default):
  324. return getattr(value, "_NodeMixin__node", *default)
  325. @classmethod
  326. def get_wrapped_type(cls, value):
  327. if isinstance(value, RawTensor):
  328. return TensorNode
  329. if isinstance(value, (Module, NodeMixin)):
  330. return ModuleNode
  331. return Node