|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import abc
- import copy
- import weakref
- from importlib import import_module
- from typing import Any, Dict, List, Tuple, Type
-
- import numpy
-
- from .. import get_logger
- from ..core._imperative_rt.core2 import Tensor as RawTensor
- from ..module import Module
- from ..quantization.utils import QParams
- from ..tensor import Tensor
- from .utils import _check_obj_attr
-
- logger = get_logger(__name__)
-
-
class Node:
    r"""``Node`` represents the variables (``Tensor``, ``Module``) used in Module's forward method.
    They are inputs/outputs of Expr (the operations on variables).
    """

    expr = None  # type: Expr
    r"""The Expr which produces the Node."""
    # Process-wide counter used to hand out ``_id``; always accessed through
    # ``Node`` itself so that subclasses can never shadow it.
    __total_id = 0  # type: int
    _id = None  # type: int
    _top_graph = None  # type: weakref.ReferenceType
    # Default format spec used by ``__repr__`` / ``__format__``.
    _format_spec = ""  # type: str

    def __init__(self, expr, name: str, qualname: str):
        r"""Create a Node and assign it the next free id.

        Args:
            expr: the Expr which produces this Node.
            name: the name of this Node.
            qualname: the qualified name of this Node.
        """
        self.expr = expr
        self.users = []  # List[Expr]
        self._id = Node.__total_id
        Node.__total_id += 1
        self._name = name
        self._qualname = qualname
        self.actual_node = []  # type: List[Node]

    def __repr__(self):
        # repr always uses the class-level default format spec.
        return self.__format__(Node._format_spec)

    def __format__(self, format_spec: str) -> str:
        r"""Format this Node as a string.

        The specs ``"i"``, ``"p"``, ``"ip"`` and ``"pi"`` are recognized:
        ``"p"`` prefixes the name with the ``_name`` of the top graph (so the
        Node must belong to a graph), and ``"i"`` then prefixes the result with
        ``%<id>_``. For any other spec the plain name is returned, or the id
        when the name is empty. An empty spec falls back to the class-level
        default installed by :meth:`_set_format_spec`.
        """
        if not format_spec:
            format_spec = Node._format_spec
        name = self._name
        if name is None:
            name = ""
        if format_spec in ["i", "p", "ip", "pi"]:
            if "p" in format_spec:
                prefix_name = self.top_graph._name
                name = "{}_{}".format(prefix_name, name)
            if "i" in format_spec:
                name = "%{}_{}".format(self._id, name)
            return name
        else:
            return name if name else ("%d" % self._id)

    @property
    def name(self):
        r"""Return the name of this Node."""
        return self._name

    @name.setter
    def name(self, new_name: str):
        r"""Set a new name to this Node.

        The Node must belong to a graph and ``new_name`` must not already be
        registered in that graph's namespace; both conditions are checked with
        ``assert``.
        """
        graph = self.top_graph
        assert graph is not None, "The parent graph of this Node cannot be None."
        assert graph._namespace.used_names.get(new_name, None) is None, (
            "The name(%s) is already in use. Please try a different one again."
            % (new_name)
        )
        # Release the old name before registering the new one.
        graph._namespace.unassociate_name_with_obj(self)
        self._name = graph._namespace.create_unique_name(new_name, self)

    @property
    def qualname(self):
        r"""Get the `qualname` of this Node. The `qualname` can be used to get the
        submodule from the traced Module or Module.

        Example:
            .. code-block::

                import megengine.module as M
                import megengine.functional as F
                import megengine.traced_module as tm
                import megengine as mge

                class block(M.Module):
                    def __init__(self):
                        super().__init__()
                        self.param = mge.Tensor([1.])
                        self.relu = M.ReLU()

                    def forward(self, x):
                        x = x + self.param
                        return self.relu(F.relu(x))

                class module(M.Module):
                    def __init__(self):
                        super().__init__()
                        self.block = block()

                    def forward(self, x):
                        x = self.block(x)
                        return x

                net = module()
                traced_net = tm.trace_module(net, mge.Tensor([0.]))
                traced_net = traced_net.flatten()
                out_node = traced_net.graph.outputs[0]

                # qualname : "module.block.relu.[out]"
                qualname = out_node.qualname
                # qualname : "block.relu"
                qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

                assert qualname in list(map(lambda x: x[0], net.named_modules()))
                assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
        """
        return self._qualname

    @property
    def top_graph(self):
        r"""Get the parent graph of this Node, or ``None`` if it has none."""
        if self._top_graph:
            # ``_top_graph`` is a weak reference; call it to dereference.
            return self._top_graph()
        return None

    @classmethod
    def _set_format_spec(cls, str):
        r"""Install a new default format spec and return the previous one.

        The spec is stored on ``Node`` itself (not on ``cls``) because
        ``__repr__`` and ``__format__`` only ever read ``Node._format_spec``;
        writing it on a subclass would silently have no effect.
        """
        old_format_spec = Node._format_spec
        Node._format_spec = str
        return old_format_spec

    @classmethod
    def _get_next_id(cls):
        r"""Return the id that the next created Node will receive."""
        return Node.__total_id

    @classmethod
    def _set_next_id(cls, id: int = 0):
        r"""Set the id that the next created Node will receive.

        The counter is stored on ``Node`` itself (not on ``cls``) because
        ``__init__`` reads ``Node.__total_id``; writing it on a subclass would
        only shadow the counter without changing the ids actually handed out.
        """
        assert isinstance(id, int)
        Node.__total_id = id

    def __copy__(self):
        r"""Return a shallow copy that shares all attribute values."""
        cls = self.__class__
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        r"""Return a deep copy of this Node.

        Attributes holding weak references and the ``actual_node`` attribute
        are not copied; they are simply absent from the copy's instance dict.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        state = {}
        # Register the copy first so reference cycles resolve to it.
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if not isinstance(v, weakref.ReferenceType) and k != "actual_node":
                state[k] = copy.deepcopy(v, memo)
        result.__dict__.update(state)
        return result
-
-
class ModuleNode(Node):
    r"""``ModuleNode`` represents the Module objects."""

    module_type = Module  # type: Type[Module]
    r"""The type of the Module corresponding to the ModuleNode."""
    _owner = None  # type: weakref.ReferenceType

    def __init__(self, expr, name: str = None, qualname: str = None):
        r"""Create a ModuleNode.

        Args:
            expr: the Expr which produces this Node.
            name: the name of this Node.
            qualname: the qualified name of this Node.
        """
        super().__init__(expr, name, qualname)

    def __getstate__(self):
        r"""Return the picklable state of this Node.

        ``module_type`` is stored as a ``(module name, qualified class name)``
        tuple rather than as the class object itself.
        """
        state = {
            "expr": self.expr,
            "users": self.users,
            "_id": self._id,
            "_name": self._name,
            "_qualname": self._qualname,
            "module_type": (self.module_type.__module__, self.module_type.__qualname__),
        }
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        r"""Restore this Node from ``state``.

        A legacy ``_orig_name`` entry is accepted and stored as ``_qualname``.
        If ``module_type`` was serialized as a ``(module name, qualified class
        name)`` tuple, it is resolved back to the class on a best-effort basis;
        when that fails the tuple is kept unchanged.
        """
        if "_orig_name" in state:
            state["_qualname"] = state.pop("_orig_name")
        self.__dict__.update(state)
        if not isinstance(self.module_type, tuple):
            return
        try:
            mname, classname = self.module_type
            mtype = import_module(mname)
            # ``__qualname__`` is a dotted path for nested classes
            # (e.g. "Outer.Inner"), so resolve it one attribute at a time.
            for attr in classname.split("."):
                mtype = getattr(mtype, attr)
        except Exception:
            # Deliberately best-effort: a failure to import or resolve the
            # class must not abort unpickling, so keep the tuple as-is.
            logger.debug(
                "Cannot restore module_type from %r, keep it unresolved.",
                self.module_type,
            )
        else:
            self.module_type = mtype

    @property
    def owner(self):
        r"""Get the ``Module`` corresponding to this ``ModuleNode``, or ``None``
        if no owner is set.
        """
        if self._owner:
            # ``_owner`` is a weak reference; call it to dereference.
            return self._owner()
        return None
-
-
class TensorNode(Node):
    r"""``TensorNode`` represents the Tensor objects."""

    _shape = None  # type: Tuple[int]
    _dtype = None  # type: numpy.dtype
    _qparams = None  # type: QParams
    _device = None
    _value = None  # type: Tensor

    def __init__(
        self,
        expr,
        name: str = None,
        qualname: str = None,
        shape: Tuple[int] = None,
        dtype: numpy.dtype = None,
        qparams: QParams = None,
    ):
        r"""Create a TensorNode.

        Args:
            expr: the Expr which produces this Node.
            name: the name of this Node.
            qualname: the qualified name of this Node.
            shape: the shape recorded for this Node.
            dtype: the dtype recorded for this Node.
            qparams: the :class:`QParams` recorded for this Node.
        """
        super().__init__(expr, name, qualname)
        self._shape, self._dtype, self._qparams = shape, dtype, qparams

    def __getstate__(self):
        r"""Return the picklable state of this Node (the bound value is not included)."""
        serialized_attrs = (
            "expr",
            "users",
            "_id",
            "_qparams",
            "_shape",
            "_dtype",
            "_device",
            "_name",
            "_qualname",
        )
        state = {attr: getattr(self, attr) for attr in serialized_attrs}
        _check_obj_attr(state)
        return state

    def __setstate__(self, state):
        r"""Restore this Node from ``state``.

        A legacy ``_orig_name`` entry is converted to ``_qualname``: its last
        dotted component is wrapped in square brackets unless the producing
        expr is a ``GetAttr``, and the leading module path (if any) is kept.
        """
        if "_orig_name" in state:
            legacy_name = state.pop("_orig_name")
            module_path, separator, leaf = legacy_name.rpartition(".")
            produced_by = state["expr"].__class__.__name__
            if produced_by != "GetAttr":
                leaf = "[{}]".format(leaf)
            if separator:
                leaf = "{}.{}".format(module_path, leaf)
            state["_qualname"] = leaf
        self.__dict__.update(state)

    @property
    def shape(self):
        r"""Get the shape of this Node."""
        return self._shape

    @shape.setter
    def shape(self, shape):
        r"""Set the shape of this Node."""
        self._shape = shape

    @property
    def dtype(self):
        r"""Get the dtype of this Node."""
        return self._dtype

    @dtype.setter
    def dtype(self, dtype):
        r"""Set the dtype of this Node."""
        self._dtype = dtype

    @property
    def device(self):
        r"""Get the device of this Node pointed Tensor."""
        return self._device

    @device.setter
    def device(self, device):
        r"""Set the device of this Node."""
        self._device = device

    @property
    def qparams(self):
        r"""Get the :class:`QParams` of this Node."""
        return self._qparams

    @qparams.setter
    def qparams(self, qparams):
        r"""Set the :class:`QParams` of this Node."""
        self._qparams = qparams

    @property
    def value(self):
        r"""Get the bound Tensor of this Node."""
        return self._value

    @value.setter
    def value(self, value):
        r"""Bind a :class:`Tensor` to this Node.

        If ``value`` is a raw tensor that already carries a node, that link is
        reset to ``None`` before the tensor is stored.
        """
        if isinstance(value, RawTensor):
            if NodeMixin.get(value, None) is not None:
                setattr(value, "_NodeMixin__node", None)
        self._value = value
-
-
class NodeMixin(abc.ABC):
    r"""Mixin that lets an object carry a reference to its ``Node``.

    The node is stored on the wrapped object under the attribute
    ``_NodeMixin__node``.
    """

    __node = None

    @abc.abstractmethod
    def _record_wrapped_nodes(self, node):
        r"""Record the nodes which had been bound to this NodeMixin."""

    @classmethod
    def _record_tensornode_property(cls, node, value):
        r"""Copy dtype, shape, device and (if present) qparams from ``value``
        onto the TensorNode ``node``.
        """
        assert isinstance(node, TensorNode)
        assert isinstance(value, RawTensor)
        if not isinstance(value, RawTensor):
            return
        try:
            dtype = value.dtype
        except RuntimeError:
            # The dtype cannot be queried; record it as unknown.
            dtype = None
        node._dtype = dtype
        if isinstance(value, Tensor):
            node._shape = value._tuple_shape
        else:
            node._shape = value.shape
        node._device = value.device
        if hasattr(value, "_qparams") and value._qparams is not None:
            node._qparams = value.qparams

    @classmethod
    def wrap(cls, value, node):
        r"""Bind ``node`` to ``value``.

        ``node`` is either a ``Node`` or a callable returning one; the callable
        is only invoked when ``value`` is a ``NodeMixin`` or a raw tensor.
        Values of any other type are ignored.
        """
        if not isinstance(value, (NodeMixin, RawTensor)):
            return
        if not isinstance(node, Node):
            # Lazily create the node through the given factory.
            assert callable(node)
            node = node()
            assert isinstance(node, Node)
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)
        setattr(value, "_NodeMixin__node", node)

    @classmethod
    def wrap_safe(cls, value, node):
        r"""Bind ``node`` to ``value``, asserting that ``value`` is a
        ``NodeMixin`` or a raw tensor.
        """
        assert isinstance(value, (NodeMixin, RawTensor))
        if isinstance(value, RawTensor):
            cls._record_tensornode_property(node, value)
        setattr(value, "_NodeMixin__node", node)
        if isinstance(value, NodeMixin):
            value._record_wrapped_nodes(node)

    @classmethod
    def get(cls, value, *default):
        r"""Return the node bound to ``value``.

        An optional single ``default`` is returned when no node attribute
        exists; without it, a missing attribute raises ``AttributeError``.
        """
        return getattr(value, "_NodeMixin__node", *default)

    @classmethod
    def get_wrapped_type(cls, value):
        r"""Return the Node class used to represent ``value``."""
        if isinstance(value, RawTensor):
            return TensorNode
        return ModuleNode if isinstance(value, (Module, NodeMixin)) else Node
|