
feat(traced_module): support trace custom qat module

GitOrigin-RevId: 49f70a5f46
release-1.6
Megvii Engine Team 3 years ago
parent commit b3d0affab1
6 changed files with 199 additions and 22 deletions
  1. imperative/python/megengine/experimental/traced_module/expr.py (+36, -5)
  2. imperative/python/megengine/experimental/traced_module/fake_quant.py (+48, -0)
  3. imperative/python/megengine/experimental/traced_module/module_tracer.py (+3, -1)
  4. imperative/python/megengine/experimental/traced_module/node.py (+2, -2)
  5. imperative/python/megengine/experimental/traced_module/pytree.py (+5, -4)
  6. imperative/python/megengine/experimental/traced_module/traced_module.py (+105, -10)

imperative/python/megengine/experimental/traced_module/expr.py (+36, -5)

@@ -17,12 +17,14 @@ from typing import Callable, Dict, List
 from ...core._imperative_rt import OpDef
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
+from ...core.ops.builtin import FakeQuant
 from ...core.ops.special import Const
 from ...module import Module
 from ...tensor import Parameter, Tensor
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import ArgsIndex, TreeDef, tree_flatten
+from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
+from .serialization import get_opdef_state, load_opdef_from_state


 def rstrip(s: str, __chars: str):

@@ -76,6 +78,7 @@ class Expr:
                 node.users.append(self)
             else:
                 assert node is None
+                assert _is_leaf(val) and _is_const_leaf(val)
                 idx = len(self.inputs) + len(self.const_val)
                 self.const_val.append((idx, val))

@@ -154,6 +157,11 @@ class Expr:
             return self._top_graph()
         return None

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_top_graph", None)
+        return state
+

 # expr: None (i.e. fake expression which is used to mark input)
 class Input(Expr):

@@ -321,14 +329,36 @@ class Apply(Expr):
             ", ".join(str(i) for i in self.inputs),
         )

+    def __getstate__(self):
+        state = super().__getstate__()
+        state["opdef"] = get_opdef_state(state["opdef"])
+        return state
+
+    def __setstate__(self, state):
+        state["opdef"] = load_opdef_from_state(state["opdef"])
+        for k, v in state.items():
+            setattr(self, k, v)
+
     @classmethod
     def apply_module_trace_hook(cls, opdef, *inputs):
         for i in inputs:
             node = NodeMixin.get(i, None)
             if node is None:  # capture as constant
                 NodeMixin.wrap_safe(i, Constant.make(i))
-        apply_node = cls.make(opdef)
-        apply_node.add_inputs(inputs)
+
+        if isinstance(opdef, FakeQuant):
+            inp_nodes = [NodeMixin.get(inputs[0])]
+            for i in inputs[1:]:
+                node = Constant.make(i)
+                inp_nodes.append(node)
+            apply_node = cls.make(opdef)
+            for n in inp_nodes:
+                n.users.append(apply_node)
+            apply_node.inputs = inp_nodes
+        else:
+            apply_node = cls.make(opdef)
+            apply_node.add_inputs(inputs)
+
         assert not apply_node.const_val

         unset_module_tracing()

@@ -387,7 +417,7 @@ class Constant(Expr):
         super().__init__()
         assert isinstance(c, (RawTensor, Module))
         if isinstance(c, Module):
-            assert module_tracer.is_builtin(c)
+            assert module_tracer.is_builtin(c) or c.is_qat
         self.value = c
         self.name = name
         self.inputs = []

@@ -395,6 +425,7 @@ class Constant(Expr):
         self.outputs = [
             node_cls(self, name=name),
         ]
+        self.outputs[0]._name = name if name else "const_" + str(self._id)

     @classmethod
     def make(cls, *args, **kwargs):

@@ -422,7 +453,7 @@
         )

     def __getstate__(self):
-        state = self.__dict__.copy()
+        state = super().__getstate__()
         if isinstance(self.value, RawTensor):
             state["value"] = Tensor(self.value)
         return state
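Note: the `__getstate__`/`__setstate__` hooks above exist so that traced graphs (including the newly recorded `FakeQuant` applies) survive pickling: the cached `_top_graph` back-reference is dropped and `OpDef`s round-trip through `get_opdef_state`/`load_opdef_from_state`. A minimal sketch of the intended use; the `trace_module` import path is assumed from the files in this commit, not guaranteed by it:

# Hedged sketch: pickle round-trip of a traced module.
import pickle

import numpy as np
import megengine as mge
import megengine.module as M
from megengine.experimental.traced_module import trace_module  # path assumed


class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


inp = mge.tensor(np.random.randn(1, 3, 16, 16).astype("float32"))
tm = trace_module(Net(), inp)
buf = pickle.dumps(tm)    # Expr.__getstate__ drops the cached _top_graph back-reference;
                          # Apply exprs (e.g. traced FakeQuant ops) serialize their opdef
                          # via get_opdef_state and restore it with load_opdef_from_state
tm2 = pickle.loads(buf)
np.testing.assert_allclose(tm(inp).numpy(), tm2(inp).numpy())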

imperative/python/megengine/experimental/traced_module/fake_quant.py (+48, -0)

@@ -0,0 +1,48 @@
+from copy import deepcopy
+from typing import Union
+
+from ...core.tensor.dtype import QuantDtypeMeta
+from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
+from ...quantization.utils import QParams, QuantMode, fake_quant_tensor
+
+
+class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
+    def __init__(
+        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
+    ):
+        super().__init__(dtype, enable, **kwargs)
+        self.qparams = None
+
+    def fake_quant_forward(self, inp, qparams: QParams = None):
+        if qparams is None:
+            qparams = self.get_qparams()
+        assert (
+            qparams.dtype_meta is self.dtype
+        ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
+            qparams.dtype_meta, self.dtype
+        )
+        return fake_quant_tensor(inp, qparams)
+
+    def get_qparams(self):
+        return self.qparams
+
+    def set_qparams(self, qparams: QParams):
+        """
+        :param qparams: used to set initial scale.
+        """
+        if qparams.scale is None:
+            raise AssertionError("Can not get an initialized scale")
+        scale = qparams.scale
+        if qparams.dtype_meta is None:
+            qparams.dtype_meta = self.dtype
+        else:
+            assert (
+                qparams.dtype_meta is self.dtype
+            ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
+                qparams.dtype_meta, self.dtype
+            )
+        dtype_meta = qparams.dtype_meta
+        zero_point = qparams.zero_point
+        mode = qparams.mode
+
+        self.qparams = QParams(mode, dtype_meta, scale, zero_point)
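For reference, a hedged usage sketch of this wrapper: unlike the observer-driven fake-quant modules, it simply replays a fixed `QParams` installed via `set_qparams`. The import path, the `"qint8"` dtype string, and the `QuantMode.SYMMERT` member name are assumptions outside this diff:

import megengine as mge
from megengine.quantization.utils import QParams, QuantMode
from megengine.experimental.traced_module.fake_quant import FakeQuantize  # path assumed

fq = FakeQuantize(dtype="qint8")
# symmetric qparams with a fixed scale; zero_point stays None in symmetric mode (assumption)
fq.set_qparams(QParams(QuantMode.SYMMERT, fq.dtype, mge.tensor(0.1), None))
x = mge.tensor([0.04, 0.26, -0.93])
y = fq(x)                       # fake_quant_forward replays the stored qparams
print(fq.get_qparams().scale)   # -> 0.1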

imperative/python/megengine/experimental/traced_module/module_tracer.py (+3, -1)

@@ -12,6 +12,7 @@ from ... import Tensor
 from ... import functional as F
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
+from ...module.qat import QATModule

 _active_module_tracer = None


@@ -68,7 +69,7 @@ BUILTIN_ARRAY_METHOD = [
"__iand__", "__iand__",
"__ior__", "__ior__",
"__ixor__", "__ixor__",
"T",
"transpose",
"astype", "astype",
"reshape", "reshape",
"_broadcast", "_broadcast",
@@ -180,6 +181,7 @@ class Patcher:
             self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
         self.patch_method(Tensor, "detach", self.wrap_fn)
         self.patch_method(Tensor, "__new__", self.wrap_fn)
+        self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
         for i, j in self._builtin_functions:
             if id(i) not in self.visited_frames_ids:
                 self.patch_function(i, j, self.wrap_fn)


imperative/python/megengine/experimental/traced_module/node.py (+2, -2)

@@ -127,7 +127,7 @@ class TensorNode(Node):


     shape = None  # type: Tuple[int]
     dtype = None  # type: numpy.dtype
-    qparam = None
+    qparams = None
     device = None


     def __getstate__(self):
@@ -135,7 +135,7 @@ class TensorNode(Node):
"expr": self.expr, "expr": self.expr,
"users": self.users, "users": self.users,
"_id": self._id, "_id": self._id,
"qparam": self.qparam,
"qparams": self.qparams,
"shape": self.shape, "shape": self.shape,
"dtype": self.dtype, "dtype": self.dtype,
"device": self.device, "device": self.device,


imperative/python/megengine/experimental/traced_module/pytree.py (+5, -4)

@@ -155,10 +155,7 @@ def tree_flatten(
         assert is_leaf(values), values
         node = LeafDef(leaf_type(values))
         if is_const_leaf(values):
-            if isinstance(values, np.ndarray):
-                node.const_val = str(values)
-            else:
-                node.const_val = values
+            node.const_val = values
         return [values,], node


     rst = []
@@ -232,9 +229,13 @@ class LeafDef(TreeDef):
         return leaves[0]

     def __eq__(self, other):
+        if isinstance(self.const_val, np.ndarray):
+            return self.type == other.type and (self.const_val == other.const_val).all()
         return self.type == other.type and self.const_val == other.const_val

     def __hash__(self):
+        if isinstance(self.const_val, np.ndarray):
+            return hash(tuple([self.type, str(self.const_val)]))
         return hash(tuple([self.type, self.const_val]))

     def __repr__(self):
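The ndarray special cases in `LeafDef.__eq__`/`__hash__` are needed because numpy arrays neither compare down to a single bool nor hash. A minimal, MegEngine-free illustration of the reasoning:

import numpy as np

a = np.array([1.0, 2.0])
b = np.array([1.0, 2.0])
print(a == b)           # [ True  True] -- elementwise, not a single bool, so
                        # `self.const_val == other.const_val` alone cannot drive __eq__
print((a == b).all())   # True -- what the patched __eq__ evaluates for ndarrays
# hash(a)               # would raise TypeError: unhashable type: 'numpy.ndarray',
                        # hence __hash__ falls back to hashing str(const_val)
print(hash((np.ndarray, str(a))))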


imperative/python/megengine/experimental/traced_module/traced_module.py (+105, -10)

@@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import (
 from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
-from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
+from ...module.qat import QATModule
+from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
 from ...quantization.observer import (
     ExponentialMovingAverageObserver,
+    HistogramObserver,
     MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
     SyncMinMaxObserver,
 )
 from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
+from .fake_quant import FakeQuantize as TM_FakeQuant
 from .module_tracer import (
     Patcher,
     active_module_tracer,
@@ -613,7 +619,8 @@ def _wrapped_function(orig_func):
             if isinstance(i, (RawTensor, NodeMixin)):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
-        if meth_name:
+        arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+        if meth_name and issubclass(arg_type, RawTensor):
             self = inputs[0]
             if meth_name == "__new__":
                 if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin):
         self._mod = mod
         self._body = None
         self._is_top = is_top_module
-        self._is_builtin = module_tracer.is_builtin(mod)
+        self._is_builtin = (
+            True
+            if isinstance(mod, (Observer, _FakeQuantize))
+            else module_tracer.is_builtin(mod)
+        )
+        if isinstance(self._mod, QATModule):
+            unset_module_tracing()
+            self._check_qat_module(self._mod)
+            set_module_tracing()
         self._argdef_graph_map = {}
         self._argdef_outdef_map = {}


@@ -693,15 +708,65 @@ class TracedModuleBuilder(NodeMixin):
             dict(TracedModuleBuilder.__dict__),
         )


+    def _check_qat_module(self, qat_module):
+        def isbuiltin(m):
+            return m is None or module_tracer.is_builtin(m)
+
+        if qat_module.with_act:
+            act_observer = qat_module.act_observer
+            act_fakequant = qat_module.act_fake_quant
+            if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
+                qparams = (
+                    act_observer.get_qparams()
+                    if hasattr(act_observer, "get_qparams")
+                    else act_fakequant.get_qparams()
+                )
+                dtype = (
+                    act_observer.dtype
+                    if hasattr(act_observer, "dtype")
+                    else act_fakequant.dtype
+                )
+                qat_module.act_observer = None
+                qat_module.act_fake_quant = TM_FakeQuant(dtype)
+                qat_module.act_fake_quant.set_qparams(qparams)
+
+        if qat_module.with_weight:
+            weight_observer = qat_module.weight_observer
+            weight_fakequant = qat_module.weight_fake_quant
+            if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
+                qparams = (
+                    weight_observer.get_qparams()
+                    if hasattr(weight_observer, "get_qparams")
+                    else weight_fakequant.get_qparams()
+                )
+                dtype = (
+                    weight_observer.dtype
+                    if hasattr(weight_observer, "dtype")
+                    else weight_fakequant.dtype
+                )
+                qat_module.weight_observer = None
+                qat_module.weight_fake_quant = TM_FakeQuant(dtype)
+                qat_module.weight_fake_quant.set_qparams(qparams)
+
     def build(self):
         if self._is_builtin or isinstance(self._mod, TracedModule):
+            if module_tracer.is_builtin(self._mod) or isinstance(
+                self._mod, TracedModule
+            ):
+                mod_type = type(self._mod)
+            else:
+                assert isinstance(self._mod, (Observer, _FakeQuantize))
+                mod_type = (
+                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
+                )
             for node in self.nodes:
-                node.module_type = type(self._mod)
+                # node._owner = weakref.ref(self._mod)
+                node.module_type = mod_type
             return self._mod
         else:
+            is_qat = isinstance(self._mod, QATModule)
             traced_module = TracedModule(
-                self._is_top, self._argdef_graph_map, self._argdef_outdef_map
+                self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
             )
             for _, g in self._argdef_graph_map.items():
                 g.compile()
@@ -712,6 +777,20 @@ class TracedModuleBuilder(NodeMixin):
                     v = v.build()
                 setattr(traced_module, k, v)

+        if isinstance(self._mod, QATModule):
+            unset_module_tracing()
+            traced_module.with_act = self._mod.with_act
+            traced_module.with_weight = self._mod.with_weight
+            if not hasattr(traced_module, "act_fake_quant"):
+                traced_module.act_fakequant = None
+            if not hasattr(traced_module, "act_observer"):
+                traced_module.act_observer = None
+            if not hasattr(traced_module, "weight_fake_quant"):
+                traced_module.weight_fakequant = None
+            if not hasattr(traced_module, "weight_observer"):
+                traced_module.weight_observer = None
+            set_module_tracing()
+
         return traced_module


     def _record_wrapped_nodes(self, node):
@@ -846,7 +925,8 @@ class TracedModuleBuilder(NodeMixin):
         attr = getattr(self._mod, name)
         if isinstance(attr, Module):
             attr = TracedModuleBuilder(attr)
-        setattr(self, name, attr)
+        if isinstance(attr, (Module, RawTensor)):
+            setattr(self, name, attr)
         NodeMixin.wrap(
             attr,
             lambda: GetAttr.make(
@@ -1066,7 +1146,7 @@ class TracedModule(Module):
     argdef_graph_map = None
     argdef_outdef_map = None

-    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
+    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
         super(TracedModule, self).__init__()
         self.argdef_graph_map = argdef_graph_map
         self.argdef_outdef_map = argdef_outdef_map
@@ -1074,6 +1154,7 @@ class TracedModule(Module):
         self.watch_points = []
         self.watch_node_value = {}
         self.end_points = []
+        self.is_qat = is_qat

     def forward(self, *args, **kwargs):
         inputs, treedef = tree_flatten(((self, *args), kwargs))
@@ -1195,8 +1276,8 @@ class TracedModule(Module):
     ):
         if graph is not None and prefix_name and prefix_name[-1] != "_":
             prefix_name += "_"
-        if graph is None:
-            assert not isinstance(module, TracedModule)
+        if graph is None or module.is_qat:
+            assert not isinstance(module, TracedModule) or module.is_qat
             const = Constant(module, "self.%s" % module2name[id(module)])
             m_node = call.inputs[0]
             if m_node.top_graph != active_module_tracer().current_scope():
@@ -1326,9 +1407,23 @@ def _register_all_builtin_module():
                 isclass(m[1])
                 and issubclass(m[1], M.Module)
                 and m[1] is not M.Sequential
+                and m[1] is not M.ModuleList
             ):
                 module_tracer.register_as_builtin(m[1])

+    module_tracer.register_as_builtin(Observer)
+    module_tracer.register_as_builtin(MinMaxObserver)
+    module_tracer.register_as_builtin(SyncMinMaxObserver)
+    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
+    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
+    module_tracer.register_as_builtin(HistogramObserver)
+    module_tracer.register_as_builtin(PassiveObserver)
+
+    module_tracer.register_as_builtin(LSQ)
+    module_tracer.register_as_builtin(TQT)
+    module_tracer.register_as_builtin(FakeQuantize)
+    module_tracer.register_as_builtin(TM_FakeQuant)



 def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     """

