You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

node.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import abc
  9. import copy
  10. import weakref
  11. from importlib import import_module
  12. from typing import Any, Dict, List, Tuple, Type
  13. import numpy
  14. from .. import get_logger
  15. from ..core._imperative_rt.core2 import Tensor as RawTensor
  16. from ..module import Module
  17. from ..quantization.utils import QParams
  18. from ..tensor import Tensor
  19. from .module_tracer import active_module_tracer
  20. from .tm_config import _get_expr_checker
  21. from .utils import _check_obj_attr
# Module-level logger, named after this module, obtained from the package's
# ``get_logger`` helper.
logger = get_logger(__name__)
  23. class Node:
  24. r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
  25. They are inputs/outputs of Expr (the operations on variables).
  26. """
  27. expr = None # type: Expr
  28. r"""The Expr which produces the Node."""
  29. __total_id = 0 # type: int
  30. _id = None # type: int
  31. _top_graph = None # type: weakref.ReferenceType
  32. _format_spec = "" # type: str
  33. def __init__(self, expr, name: str, qualname: str):
  34. self.expr = expr
  35. self.users = [] # List[Expr]
  36. self._id = Node.__total_id
  37. Node.__total_id += 1
  38. self._name = name
  39. self._qualname = qualname
  40. self.actual_node = [] # type: List[Node]
  41. def __repr__(self):
  42. format_spec = Node._format_spec
  43. return self.__format__(format_spec)
  44. def __format__(self, format_spec: str) -> str:
  45. if not format_spec:
  46. format_spec = Node._format_spec
  47. name = self._name
  48. if name is None:
  49. name = ""
  50. if format_spec in ["i", "p", "ip", "pi"]:
  51. if "p" in format_spec:
  52. prefix_name = self.top_graph._name
  53. name = "{}_{}".format(prefix_name, name)
  54. if "i" in format_spec:
  55. name = "%{}_{}".format(self._id, name)
  56. return name
  57. else:
  58. return name if name else ("%d" % self._id)
  59. @property
  60. def name(self):
  61. r"""Return the name of this Node."""
  62. return self._name
  63. @name.setter
  64. def name(self, new_name: str):
  65. r"""Set a new name to this Node."""
  66. graph = self.top_graph
  67. assert graph is not None, "The parent graph of this Node cannot be None."
  68. assert graph._namespace.used_names.get(new_name, None) is None, (
  69. "The name(%s) is already in use. Please try a different one again."
  70. % (new_name)
  71. )
  72. graph._namespace.unassociate_name_with_obj(self)
  73. self._name = graph._namespace.create_unique_name(new_name, self)
  74. @property
  75. def qualname(self):
  76. r"""Get the `qualname` of this Node. The `qualname` can be used to get the
  77. submodule from the traced Module or Module.
  78. Example:
  79. .. code-block::
  80. import megengine.module as M
  81. import megengine.functional as F
  82. import megengine.traced_module as tm
  83. import megengine as mge
  84. class block(M.Module):
  85. def __init__(self):
  86. super().__init__()
  87. self.param = mge.Tensor([1.])
  88. self.relu = M.ReLU()
  89. def forward(self, x):
  90. x = x + self.param
  91. return self.relu(F.relu(x))
  92. class module(M.Module):
  93. def __init__(self):
  94. super().__init__()
  95. self.block = block()
  96. def forward(self, x):
  97. x = self.block(x)
  98. return x
  99. net = module()
  100. traced_net = tm.trace_module(net, mge.Tensor([0.]))
  101. traced_net = traced_net.flatten()
  102. out_node = traced_net.graph.outputs[0]
  103. # qualname : "module.block.relu.[out]"
  104. qualname = out_node.qualname
  105. # qualname : "block.relu"
  106. qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]
  107. assert qualname in list(map(lambda x: x[0], net.named_modules()))
  108. assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
  109. """
  110. return self._qualname
  111. @property
  112. def top_graph(self):
  113. r"""Get the parent graph of this Node."""
  114. if self._top_graph:
  115. return self._top_graph()
  116. return None
  117. @classmethod
  118. def _set_format_spec(cls, str):
  119. old_format_spec = cls._format_spec
  120. cls._format_spec = str
  121. return old_format_spec
  122. @classmethod
  123. def _get_next_id(cls):
  124. return cls.__total_id
  125. @classmethod
  126. def _set_next_id(cls, id: int = 0):
  127. assert isinstance(id, int)
  128. cls.__total_id = id
  129. def __copy__(self):
  130. cls = self.__class__
  131. result = cls.__new__(cls)
  132. result.__dict__.update(self.__dict__)
  133. return result
  134. def __deepcopy__(self, memo):
  135. cls = self.__class__
  136. result = cls.__new__(cls)
  137. state = {}
  138. memo[id(self)] = result
  139. for k, v in self.__dict__.items():
  140. if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
  141. state[k] = copy.deepcopy(v, memo)
  142. result.__dict__.update(state)
  143. return result
  144. class ModuleNode(Node):
  145. r"""``ModuleNode`` represents the Module objects."""
  146. module_type = Module # type: Type[Module]
  147. r"""The type of the Module correspending to the ModuleNode."""
  148. _owner = None # type: weakref.ReferenceType
  149. def __init__(self, expr, name: str = None, qualname: str = None):
  150. super().__init__(expr, name, qualname)
  151. def __getstate__(self):
  152. state = {
  153. "expr": self.expr,
  154. "users": self.users,
  155. "_id": self._id,
  156. "_name": self._name,
  157. "_qualname": self._qualname,
  158. "module_type": (self.module_type.__module__, self.module_type.__qualname__),
  159. }
  160. _check_obj_attr(state)
  161. return state
  162. def __setstate__(self, state):
  163. if "_orig_name" in state:
  164. state["_qualname"] = state.pop("_orig_name")
  165. self.__dict__.update(state)
  166. try:
  167. if isinstance(self.module_type, tuple):
  168. mname, classname = self.module_type
  169. mtype = getattr(import_module(mname), classname)
  170. self.module_type = mtype
  171. except Exception:
  172. pass
  173. @property
  174. def owner(self):
  175. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
  176. """
  177. if self._owner:
  178. return self._owner()
  179. return None
  180. class TensorNode(Node):
  181. r"""``TensorNode`` represents the Tensor objects."""
  182. _shape = None # type: Tuple[int]
  183. _dtype = None # type: numpy.dtype
  184. _qparams = None # type: QParams
  185. _device = None
  186. _value = None # type: Tensor
  187. def __init__(
  188. self,
  189. expr,
  190. name: str = None,
  191. qualname: str = None,
  192. shape: Tuple[int] = None,
  193. dtype: numpy.dtype = None,
  194. qparams: QParams = None,
  195. ):
  196. super().__init__(expr, name, qualname)
  197. self._shape = shape
  198. self._dtype = dtype
  199. self._qparams = qparams
  200. def __getstate__(self):
  201. state = {
  202. "expr": self.expr,
  203. "users": self.users,
  204. "_id": self._id,
  205. "_qparams": self._qparams,
  206. "_shape": self._shape,
  207. "_dtype": self._dtype,
  208. "_device": self._device,
  209. "_name": self._name,
  210. "_qualname": self._qualname,
  211. }
  212. _check_obj_attr(state)
  213. return state
  214. def __setstate__(self, state):
  215. if "_orig_name" in state:
  216. qualname = state.pop("_orig_name")
  217. modulepath, comma, qualname = qualname.rpartition(".")
  218. expr_name = state["expr"].__class__.__name__
  219. if expr_name not in ["GetAttr"]:
  220. qualname = "[{}]".format(qualname)
  221. if comma:
  222. qualname = "{}.{}".format(modulepath, qualname)
  223. state["_qualname"] = qualname
  224. self.__dict__.update(state)
  225. @property
  226. def shape(self):
  227. r"""Get the shape of this Node."""
  228. return self._shape
  229. @shape.setter
  230. def shape(self, shape):
  231. self._shape = shape
  232. @property
  233. def dtype(self):
  234. r"""Get the dtype of this Node."""
  235. return self._dtype
  236. @dtype.setter
  237. def dtype(self, dtype):
  238. self._dtype = dtype
  239. @property
  240. def device(self):
  241. r"""Get the device of this Node pointed Tensor."""
  242. return self._device
  243. @device.setter
  244. def device(self, device):
  245. self._device = device
  246. @property
  247. def qparams(self):
  248. r"""Get the :class:`QParams` of this Node."""
  249. return self._qparams
  250. @qparams.setter
  251. def qparams(self, qparams):
  252. self._qparams = qparams
  253. @property
  254. def value(self):
  255. r"""Get the bound Tensor of this Node."""
  256. return self._value
  257. @value.setter
  258. def value(self, value):
  259. r"""Bind a :class:`Tensor` to this Node."""
  260. if isinstance(value, RawTensor) and NodeMixin.get(value, None) is not None:
  261. setattr(value, "_NodeMixin__node", None)
  262. self._value = value
  263. class NodeMixin(abc.ABC):
  264. __node = None
  265. @abc.abstractmethod
  266. def _record_wrapped_nodes(self, node):
  267. # record the nodes which had been bound to this NodeMixin
  268. pass
  269. @classmethod
  270. def _record_tensornode_property(cls, node, value):
  271. assert isinstance(node, TensorNode)
  272. assert isinstance(value, RawTensor)
  273. if isinstance(value, RawTensor):
  274. try:
  275. node._dtype = value.dtype
  276. except RuntimeError:
  277. node._dtype = None
  278. node._shape = (
  279. value._tuple_shape if isinstance(value, Tensor) else value.shape
  280. )
  281. node._device = value.device
  282. if hasattr(value, "_qparams") and value._qparams is not None:
  283. node._qparams = value.qparams
  284. @classmethod
  285. def wrap(cls, value, node):
  286. if isinstance(value, (NodeMixin, RawTensor)):
  287. if isinstance(node, Node):
  288. if isinstance(value, RawTensor):
  289. cls._record_tensornode_property(node, value)
  290. if isinstance(value, NodeMixin):
  291. value._record_wrapped_nodes(node)
  292. setattr(value, "_NodeMixin__node", node)
  293. if _get_expr_checker():
  294. if isinstance(value, RawTensor):
  295. active_module_tracer().checker.record_node2value(node, value)
  296. if isinstance(value, NodeMixin):
  297. active_module_tracer().checker.record_nodemixin(node, value)
  298. else:
  299. assert callable(node)
  300. n = node()
  301. assert isinstance(n, Node)
  302. if isinstance(value, RawTensor):
  303. cls._record_tensornode_property(n, value)
  304. if isinstance(value, NodeMixin):
  305. value._record_wrapped_nodes(n)
  306. setattr(value, "_NodeMixin__node", n)
  307. if _get_expr_checker():
  308. if isinstance(value, RawTensor):
  309. active_module_tracer().checker.record_node2value(n, value)
  310. if isinstance(value, NodeMixin):
  311. active_module_tracer().checker.record_nodemixin(n, value)
  312. @classmethod
  313. def wrap_safe(cls, value, node):
  314. assert isinstance(value, (NodeMixin, RawTensor))
  315. if isinstance(value, RawTensor):
  316. cls._record_tensornode_property(node, value)
  317. setattr(value, "_NodeMixin__node", node)
  318. if _get_expr_checker():
  319. if isinstance(value, RawTensor):
  320. active_module_tracer().checker.record_node2value(node, value)
  321. if isinstance(value, NodeMixin):
  322. active_module_tracer().checker.record_nodemixin(node, value)
  323. if isinstance(value, NodeMixin):
  324. value._record_wrapped_nodes(node)
  325. @classmethod
  326. def clear_node(cls, value):
  327. if hasattr(value, "_NodeMixin__node"):
  328. delattr(value, "_NodeMixin__node")
  329. @classmethod
  330. def get(cls, value, *default):
  331. return getattr(value, "_NodeMixin__node", *default)
  332. @classmethod
  333. def get_wrapped_type(cls, value):
  334. if isinstance(value, RawTensor):
  335. return TensorNode
  336. if isinstance(value, (Module, NodeMixin)):
  337. return ModuleNode
  338. return Node