diff --git a/imperative/python/megengine/experimental/traced_module/expr.py b/imperative/python/megengine/experimental/traced_module/expr.py index 5b6b1ed6..24bbee79 100644 --- a/imperative/python/megengine/experimental/traced_module/expr.py +++ b/imperative/python/megengine/experimental/traced_module/expr.py @@ -17,12 +17,14 @@ from typing import Callable, Dict, List from ...core._imperative_rt import OpDef from ...core._imperative_rt.core2 import Tensor as RawTensor from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing +from ...core.ops.builtin import FakeQuant from ...core.ops.special import Const from ...module import Module from ...tensor import Parameter, Tensor from .module_tracer import active_module_tracer, module_tracer from .node import ModuleNode, Node, NodeMixin, TensorNode -from .pytree import ArgsIndex, TreeDef, tree_flatten +from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten +from .serialization import get_opdef_state, load_opdef_from_state def rstrip(s: str, __chars: str): @@ -76,6 +78,7 @@ class Expr: node.users.append(self) else: assert node is None + assert _is_leaf(val) and _is_const_leaf(val) idx = len(self.inputs) + len(self.const_val) self.const_val.append((idx, val)) @@ -154,6 +157,11 @@ class Expr: return self._top_graph() return None + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_top_graph", None) + return state + # expr: None (i.e. 
fake expression which is used to mark input) class Input(Expr): @@ -321,14 +329,36 @@ class Apply(Expr): ", ".join(str(i) for i in self.inputs), ) + def __getstate__(self): + state = super().__getstate__() + state["opdef"] = get_opdef_state(state["opdef"]) + return state + + def __setstate__(self, state): + state["opdef"] = load_opdef_from_state(state["opdef"]) + for k, v in state.items(): + setattr(self, k, v) + @classmethod def apply_module_trace_hook(cls, opdef, *inputs): for i in inputs: node = NodeMixin.get(i, None) if node is None: # capture as constant NodeMixin.wrap_safe(i, Constant.make(i)) - apply_node = cls.make(opdef) - apply_node.add_inputs(inputs) + + if isinstance(opdef, FakeQuant): + inp_nodes = [NodeMixin.get(inputs[0])] + for i in inputs[1:]: + node = Constant.make(i) + inp_nodes.append(node) + apply_node = cls.make(opdef) + for n in inp_nodes: + n.users.append(apply_node) + apply_node.inputs = inp_nodes + else: + apply_node = cls.make(opdef) + apply_node.add_inputs(inputs) + assert not apply_node.const_val unset_module_tracing() @@ -387,7 +417,7 @@ class Constant(Expr): super().__init__() assert isinstance(c, (RawTensor, Module)) if isinstance(c, Module): - assert module_tracer.is_builtin(c) + assert module_tracer.is_builtin(c) or c.is_qat self.value = c self.name = name self.inputs = [] @@ -395,6 +425,7 @@ class Constant(Expr): self.outputs = [ node_cls(self, name=name), ] + self.outputs[0]._name = name if name else "const_" + str(self._id) @classmethod def make(cls, *args, **kwargs): @@ -422,7 +453,7 @@ class Constant(Expr): ) def __getstate__(self): - state = self.__dict__.copy() + state = super().__getstate__() if isinstance(self.value, RawTensor): state["value"] = Tensor(self.value) return state diff --git a/imperative/python/megengine/experimental/traced_module/fake_quant.py b/imperative/python/megengine/experimental/traced_module/fake_quant.py new file mode 100644 index 00000000..d81d3c94 --- /dev/null +++ 
b/imperative/python/megengine/experimental/traced_module/fake_quant.py @@ -0,0 +1,48 @@ +from copy import deepcopy +from typing import Union + +from ...core.tensor.dtype import QuantDtypeMeta +from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize +from ...quantization.utils import QParams, QuantMode, fake_quant_tensor + + +class FakeQuantize(_FakeQuantize, QParamsModuleMixin): + def __init__( + self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs + ): + super().__init__(dtype, enable, **kwargs) + self.qparams = None + + def fake_quant_forward(self, inp, qparams: QParams = None): + if qparams is None: + qparams = self.get_qparams() + assert ( + qparams.dtype_meta is self.dtype + ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( + qparams.dtype_meta, self.dtype + ) + return fake_quant_tensor(inp, qparams) + + def get_qparams(self): + return self.qparams + + def set_qparams(self, qparams: QParams): + """ + :param qparams: used to set initial scale. + """ + if qparams.scale is None: + raise AssertionError("Can not get an initialized scale") + scale = qparams.scale + if qparams.dtype_meta is None: + qparams.dtype_meta = self.dtype + else: + assert ( + qparams.dtype_meta is self.dtype + ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( + qparams.dtype_meta, self.dtype + ) + dtype_meta = qparams.dtype_meta + zero_point = qparams.zero_point + mode = qparams.mode + + self.qparams = QParams(mode, dtype_meta, scale, zero_point) diff --git a/imperative/python/megengine/experimental/traced_module/module_tracer.py b/imperative/python/megengine/experimental/traced_module/module_tracer.py index 221310ae..7ef693b5 100644 --- a/imperative/python/megengine/experimental/traced_module/module_tracer.py +++ b/imperative/python/megengine/experimental/traced_module/module_tracer.py @@ -12,6 +12,7 @@ from ... import Tensor from ... 
import functional as F from ...core.tensor.array_method import ArrayMethodMixin from ...module import Module +from ...module.qat import QATModule _active_module_tracer = None @@ -68,7 +69,7 @@ BUILTIN_ARRAY_METHOD = [ "__iand__", "__ior__", "__ixor__", - "T", + "transpose", "astype", "reshape", "_broadcast", @@ -180,6 +181,7 @@ class Patcher: self.patch_method(ArrayMethodMixin, meth, self.wrap_fn) self.patch_method(Tensor, "detach", self.wrap_fn) self.patch_method(Tensor, "__new__", self.wrap_fn) + self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn) for i, j in self._builtin_functions: if id(i) not in self.visited_frames_ids: self.patch_function(i, j, self.wrap_fn) diff --git a/imperative/python/megengine/experimental/traced_module/node.py b/imperative/python/megengine/experimental/traced_module/node.py index 2c43c895..fb64a8bf 100644 --- a/imperative/python/megengine/experimental/traced_module/node.py +++ b/imperative/python/megengine/experimental/traced_module/node.py @@ -127,7 +127,7 @@ class TensorNode(Node): shape = None # type: Tuple[int] dtype = None # type: numpy.dtype - qparam = None + qparams = None device = None def __getstate__(self): @@ -135,7 +135,7 @@ class TensorNode(Node): "expr": self.expr, "users": self.users, "_id": self._id, - "qparam": self.qparam, + "qparams": self.qparams, "shape": self.shape, "dtype": self.dtype, "device": self.device, diff --git a/imperative/python/megengine/experimental/traced_module/pytree.py b/imperative/python/megengine/experimental/traced_module/pytree.py index 8a526a9a..8382adc8 100644 --- a/imperative/python/megengine/experimental/traced_module/pytree.py +++ b/imperative/python/megengine/experimental/traced_module/pytree.py @@ -155,10 +155,7 @@ def tree_flatten( assert is_leaf(values), values node = LeafDef(leaf_type(values)) if is_const_leaf(values): - if isinstance(values, np.ndarray): - node.const_val = str(values) - else: - node.const_val = values + node.const_val = values return 
[values,], node rst = [] @@ -232,9 +229,13 @@ class LeafDef(TreeDef): return leaves[0] def __eq__(self, other): + if isinstance(self.const_val, np.ndarray) or isinstance(other.const_val, np.ndarray): + return self.type == other.type and np.array_equal(self.const_val, other.const_val) return self.type == other.type and self.const_val == other.const_val def __hash__(self): + if isinstance(self.const_val, np.ndarray): + return hash(tuple([self.type, str(self.const_val)])) return hash(tuple([self.type, self.const_val])) def __repr__(self): diff --git a/imperative/python/megengine/experimental/traced_module/traced_module.py b/imperative/python/megengine/experimental/traced_module/traced_module.py index bb55c001..bb5cce04 100644 --- a/imperative/python/megengine/experimental/traced_module/traced_module.py +++ b/imperative/python/megengine/experimental/traced_module/traced_module.py @@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import ( from ...core._trace_option import set_symbolic_shape from ...core.tensor.array_method import ArrayMethodMixin from ...module import Module -from ...quantization.fake_quant import LSQ, TQT, FakeQuantize +from ...module.qat import QATModule +from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize from ...quantization.observer import ( ExponentialMovingAverageObserver, + HistogramObserver, MinMaxObserver, + Observer, + PassiveObserver, + SyncExponentialMovingAverageObserver, SyncMinMaxObserver, ) from ...tensor import Tensor from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input +from .fake_quant import FakeQuantize as TM_FakeQuant from .module_tracer import ( Patcher, active_module_tracer, @@ -613,7 +619,8 @@ def _wrapped_function(orig_func): if isinstance(i, (RawTensor, NodeMixin)): NodeMixin.wrap_safe(i, Constant.make(i)) meth_name = _get_meth_name(args[0], wrapped_fn) if args else None - if meth_name: + arg_type = (args[0] if isinstance(args[0], type) else type(args[0])) if args else None + if meth_name and issubclass(arg_type, RawTensor): self = 
inputs[0] if meth_name == "__new__": if all([not isinstance(i, RawTensor) for i in inputs]): @@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin): self._mod = mod self._body = None self._is_top = is_top_module - self._is_builtin = module_tracer.is_builtin(mod) + self._is_builtin = ( + True + if isinstance(mod, (Observer, _FakeQuantize)) + else module_tracer.is_builtin(mod) + ) + if isinstance(self._mod, QATModule): + unset_module_tracing() + self._check_qat_module(self._mod) + set_module_tracing() self._argdef_graph_map = {} self._argdef_outdef_map = {} @@ -693,15 +708,65 @@ class TracedModuleBuilder(NodeMixin): dict(TracedModuleBuilder.__dict__), ) + def _check_qat_module(self, qat_module): + def isbuiltin(m): + return m is None or module_tracer.is_builtin(m) + + if qat_module.with_act: + act_observer = qat_module.act_observer + act_fakequant = qat_module.act_fake_quant + if not isbuiltin(act_observer) or not isbuiltin(act_fakequant): + qparams = ( + act_observer.get_qparams() + if hasattr(act_observer, "get_qparams") + else act_fakequant.get_qparams() + ) + dtype = ( + act_observer.dtype + if hasattr(act_observer, "dtype") + else act_fakequant.dtype + ) + qat_module.act_observer = None + qat_module.act_fake_quant = TM_FakeQuant(dtype) + qat_module.act_fake_quant.set_qparams(qparams) + + if qat_module.with_weight: + weight_observer = qat_module.weight_observer + weight_fakequant = qat_module.weight_fake_quant + if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant): + qparams = ( + weight_observer.get_qparams() + if hasattr(weight_observer, "get_qparams") + else weight_fakequant.get_qparams() + ) + dtype = ( + weight_observer.dtype + if hasattr(weight_observer, "dtype") + else weight_fakequant.dtype + ) + qat_module.weight_observer = None + qat_module.weight_fake_quant = TM_FakeQuant(dtype) + qat_module.weight_fake_quant.set_qparams(qparams) + def build(self): if self._is_builtin or isinstance(self._mod, TracedModule): + if 
module_tracer.is_builtin(self._mod) or isinstance( + self._mod, TracedModule + ): + mod_type = type(self._mod) + else: + assert isinstance(self._mod, (Observer, _FakeQuantize)) + mod_type = ( + Observer if isinstance(self._mod, Observer) else _FakeQuantize + ) for node in self.nodes: - node.module_type = type(self._mod) - # node._owner = weakref.ref(self._mod) + node.module_type = mod_type + return self._mod else: + is_qat = isinstance(self._mod, QATModule) traced_module = TracedModule( - self._is_top, self._argdef_graph_map, self._argdef_outdef_map + self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat ) for _, g in self._argdef_graph_map.items(): g.compile() @@ -712,6 +777,20 @@ class TracedModuleBuilder(NodeMixin): v = v.build() setattr(traced_module, k, v) + if isinstance(self._mod, QATModule): + unset_module_tracing() + traced_module.with_act = self._mod.with_act + traced_module.with_weight = self._mod.with_weight + if not hasattr(traced_module, "act_fake_quant"): + traced_module.act_fake_quant = None + if not hasattr(traced_module, "act_observer"): + traced_module.act_observer = None + if not hasattr(traced_module, "weight_fake_quant"): + traced_module.weight_fake_quant = None + if not hasattr(traced_module, "weight_observer"): + traced_module.weight_observer = None + set_module_tracing() + return traced_module def _record_wrapped_nodes(self, node): @@ -846,7 +925,8 @@ class TracedModuleBuilder(NodeMixin): attr = getattr(self._mod, name) if isinstance(attr, Module): attr = TracedModuleBuilder(attr) - setattr(self, name, attr) + if isinstance(attr, (Module, RawTensor)): + setattr(self, name, attr) NodeMixin.wrap( attr, lambda: GetAttr.make( @@ -1066,7 +1146,7 @@ class TracedModule(Module): argdef_graph_map = None argdef_outdef_map = None - def __init__(self, is_top, argdef_graph_map, argdef_outdef_map): + def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False): super(TracedModule, self).__init__() self.argdef_graph_map = 
argdef_graph_map self.argdef_outdef_map = argdef_outdef_map @@ -1074,6 +1154,7 @@ class TracedModule(Module): self.watch_points = [] self.watch_node_value = {} self.end_points = [] + self.is_qat = is_qat def forward(self, *args, **kwargs): inputs, treedef = tree_flatten(((self, *args), kwargs)) @@ -1195,8 +1276,8 @@ class TracedModule(Module): ): if graph is not None and prefix_name and prefix_name[-1] != "_": prefix_name += "_" - if graph is None: - assert not isinstance(module, TracedModule) + if graph is None or module.is_qat: + assert not isinstance(module, TracedModule) or module.is_qat const = Constant(module, "self.%s" % module2name[id(module)]) m_node = call.inputs[0] if m_node.top_graph != active_module_tracer().current_scope(): @@ -1326,9 +1407,23 @@ def _register_all_builtin_module(): isclass(m[1]) and issubclass(m[1], M.Module) and m[1] is not M.Sequential + and m[1] is not M.ModuleList ): module_tracer.register_as_builtin(m[1]) + module_tracer.register_as_builtin(Observer) + module_tracer.register_as_builtin(MinMaxObserver) + module_tracer.register_as_builtin(SyncMinMaxObserver) + module_tracer.register_as_builtin(ExponentialMovingAverageObserver) + module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver) + module_tracer.register_as_builtin(HistogramObserver) + module_tracer.register_as_builtin(PassiveObserver) + + module_tracer.register_as_builtin(LSQ) + module_tracer.register_as_builtin(TQT) + module_tracer.register_as_builtin(FakeQuantize) + module_tracer.register_as_builtin(TM_FakeQuant) + def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule: """