GitOrigin-RevId: 49f70a5f46
release-1.6
--- a/imperative/python/megengine/traced_module/expr.py
+++ b/imperative/python/megengine/traced_module/expr.py
@@ -17,12 +17,14 @@ from typing import Callable, Dict, List
 from ...core._imperative_rt import OpDef
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
+from ...core.ops.builtin import FakeQuant
 from ...core.ops.special import Const
 from ...module import Module
 from ...tensor import Parameter, Tensor
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import ArgsIndex, TreeDef, tree_flatten
+from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
+from .serialization import get_opdef_state, load_opdef_from_state
 
 def rstrip(s: str, __chars: str):
@@ -76,6 +78,7 @@ class Expr:
                 node.users.append(self)
             else:
                 assert node is None
+                assert _is_leaf(val) and _is_const_leaf(val)
                 idx = len(self.inputs) + len(self.const_val)
                 self.const_val.append((idx, val))
@@ -154,6 +157,11 @@ class Expr:
             return self._top_graph()
         return None
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_top_graph", None)
+        return state
+
 
 # expr: None (i.e. fake expression which is used to mark input)
 class Input(Expr):
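
Note on the new `Expr.__getstate__`: `_top_graph` is stored as a weak reference (the `top_graph` property above calls `self._top_graph()`), and `weakref.ref` objects cannot be pickled, so the back-reference has to be dropped before serialization. A minimal standalone sketch of the failure mode, using stand-in classes rather than the real `Expr`:

    import pickle
    import weakref

    class Graph:
        """Stand-in for the graph object that owns an Expr."""

    class Expr:
        def __init__(self, graph):
            # Mirrors the traced-module attribute: a weakref, called as
            # self._top_graph() to recover the owning graph.
            self._top_graph = weakref.ref(graph)

        def __getstate__(self):
            # Same trick as the patch: drop the unpicklable weakref.
            state = self.__dict__.copy()
            state.pop("_top_graph", None)
            return state

    g = Graph()
    e = Expr(g)
    restored = pickle.loads(pickle.dumps(e))  # raises TypeError without __getstate__
    assert not hasattr(restored, "_top_graph")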
@@ -321,14 +329,36 @@ class Apply(Expr):
             ", ".join(str(i) for i in self.inputs),
         )
 
+    def __getstate__(self):
+        state = super().__getstate__()
+        state["opdef"] = get_opdef_state(state["opdef"])
+        return state
+
+    def __setstate__(self, state):
+        state["opdef"] = load_opdef_from_state(state["opdef"])
+        for k, v in state.items():
+            setattr(self, k, v)
+
     @classmethod
     def apply_module_trace_hook(cls, opdef, *inputs):
         for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
-        apply_node = cls.make(opdef)
-        apply_node.add_inputs(inputs)
+
+        if isinstance(opdef, FakeQuant):
+            inp_nodes = [NodeMixin.get(inputs[0])]
+            for i in inputs[1:]:
+                node = Constant.make(i)
+                inp_nodes.append(node)
+            apply_node = cls.make(opdef)
+            for n in inp_nodes:
+                n.users.append(apply_node)
+            apply_node.inputs = inp_nodes
+        else:
+            apply_node = cls.make(opdef)
+            apply_node.add_inputs(inputs)
         assert not apply_node.const_val
         unset_module_tracing()
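
Reading of the `FakeQuant` special case in `apply_module_trace_hook`: only `inputs[0]` is looked up as a traced tensor; the remaining inputs (the scale/zero-point tensors) are recorded as fresh `Constant` nodes, so the quantization parameters are frozen into the graph instead of appearing as free inputs. A toy re-enactment of that wiring with stand-in classes (none of these names are MegEngine APIs):

    class Node:
        def __init__(self, label):
            self.label = label
            self.users = []

    class ApplyExpr:
        def __init__(self, opdef):
            self.opdef = opdef
            self.inputs = []

    def trace_fake_quant(opdef, inputs, get_node):
        # inputs[0] is the activation; inputs[1:] become constants.
        inp_nodes = [get_node(inputs[0])]
        inp_nodes += [Node("const:%r" % i) for i in inputs[1:]]
        expr = ApplyExpr(opdef)
        for n in inp_nodes:
            n.users.append(expr)  # every input node records its user
        expr.inputs = inp_nodes
        return expr

    x = Node("x")
    expr = trace_fake_quant("FakeQuant", ["x", 0.125, 0], get_node=lambda _: x)
    assert [n.label for n in expr.inputs] == ["x", "const:0.125", "const:0"]
    assert expr in x.users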
@@ -387,7 +417,7 @@ class Constant(Expr):
         super().__init__()
         assert isinstance(c, (RawTensor, Module))
         if isinstance(c, Module):
-            assert module_tracer.is_builtin(c)
+            assert module_tracer.is_builtin(c) or c.is_qat
         self.value = c
         self.name = name
         self.inputs = []
@@ -395,6 +425,7 @@ class Constant(Expr):
         self.outputs = [
             node_cls(self, name=name),
         ]
+        self.outputs[0]._name = name if name else "const_" + str(self._id)
 
     @classmethod
     def make(cls, *args, **kwargs):
@@ -422,7 +453,7 @@ class Constant(Expr):
         )
 
     def __getstate__(self):
-        state = self.__dict__.copy()
+        state = super().__getstate__()
         if isinstance(self.value, RawTensor):
             state["value"] = Tensor(self.value)
         return state
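
Taken together, the three `__getstate__` overrides in this file form one pickling chain: `Expr.__getstate__` drops the weakref, `Apply.__getstate__` additionally converts the C++-backed `opdef` through `get_opdef_state`, and `Constant.__getstate__` (now routed through `super()`) rewraps its `RawTensor` value as a plain `Tensor`. A generic sketch of the convert-on-the-way-out, rebuild-on-the-way-in pattern that `Apply` uses, with a stand-in `Handle` instead of a real `OpDef`:

    import pickle

    class Handle:
        """Stand-in for a C++-backed OpDef that pickle cannot handle."""
        def __init__(self, value):
            self.value = value
        def __reduce__(self):
            raise TypeError("unpicklable native handle")

    def handle_to_state(h):        # plays the role of get_opdef_state
        return {"value": h.value}

    def handle_from_state(state):  # plays the role of load_opdef_from_state
        return Handle(state["value"])

    class Apply:
        def __init__(self, opdef):
            self.opdef = opdef

        def __getstate__(self):
            state = self.__dict__.copy()
            state["opdef"] = handle_to_state(state["opdef"])
            return state

        def __setstate__(self, state):
            state["opdef"] = handle_from_state(state["opdef"])
            for k, v in state.items():
                setattr(self, k, v)

    a = pickle.loads(pickle.dumps(Apply(Handle(42))))
    assert a.opdef.value == 42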
--- /dev/null
+++ b/imperative/python/megengine/traced_module/fake_quant.py
@@ -0,0 +1,48 @@
+from copy import deepcopy
+from typing import Union
+
+from ...core.tensor.dtype import QuantDtypeMeta
+from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
+from ...quantization.utils import QParams, QuantMode, fake_quant_tensor
+
+
+class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
+    def __init__(
+        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
+    ):
+        super().__init__(dtype, enable, **kwargs)
+        self.qparams = None
+
+    def fake_quant_forward(self, inp, qparams: QParams = None):
+        if qparams is None:
+            qparams = self.get_qparams()
+        assert (
+            qparams.dtype_meta is self.dtype
+        ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
+            qparams.dtype_meta, self.dtype
+        )
+        return fake_quant_tensor(inp, qparams)
+
+    def get_qparams(self):
+        return self.qparams
+
+    def set_qparams(self, qparams: QParams):
+        """
+        :param qparams: used to set initial scale.
+        """
+
+        if qparams.scale is None:
+            raise AssertionError("Can not get an initialized scale")
+        scale = qparams.scale
+        if qparams.dtype_meta is None:
+            qparams.dtype_meta = self.dtype
+        else:
+            assert (
+                qparams.dtype_meta is self.dtype
+            ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
+                qparams.dtype_meta, self.dtype
+            )
+        dtype_meta = qparams.dtype_meta
+        zero_point = qparams.zero_point
+        mode = qparams.mode
+        self.qparams = QParams(mode, dtype_meta, scale, zero_point)
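
A hypothetical usage sketch for the new class, assuming the patch is applied and the module lands at `megengine.traced_module.fake_quant` (the keyword names for `QParams` are inferred from the code above, and `QuantMode.SYMMERIC` is spelled as it is in `megengine.quantization.utils`):

    import megengine as mge
    from megengine.quantization.utils import QParams, QuantMode
    from megengine.traced_module.fake_quant import FakeQuantize

    fq = FakeQuantize(dtype="qint8", enable=True)
    # Leaving dtype_meta as None lets set_qparams fill in fq.dtype.
    fq.set_qparams(
        QParams(
            mode=QuantMode.SYMMERIC,
            dtype_meta=None,
            scale=mge.tensor(0.125),
            zero_point=None,
        )
    )
    out = fq.fake_quant_forward(mge.tensor([0.3, -1.2, 0.5]))

Unlike the observer-driven or trainable fake-quant classes, this one simply replays whatever `QParams` it was given, which is exactly what `_check_qat_module` below needs when it freezes an already-calibrated module.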
--- a/imperative/python/megengine/traced_module/module_tracer.py
+++ b/imperative/python/megengine/traced_module/module_tracer.py
@@ -12,6 +12,7 @@ from ... import Tensor
 from ... import functional as F
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
+from ...module.qat import QATModule
 
 
 _active_module_tracer = None
@@ -68,7 +69,7 @@ BUILTIN_ARRAY_METHOD = [
     "__iand__",
     "__ior__",
     "__ixor__",
-    "T",
+    "transpose",
     "astype",
     "reshape",
     "_broadcast",
@@ -180,6 +181,7 @@ class Patcher:
             self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
         self.patch_method(Tensor, "detach", self.wrap_fn)
         self.patch_method(Tensor, "__new__", self.wrap_fn)
+        self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
         for i, j in self._builtin_functions:
             if id(i) not in self.visited_frames_ids:
                 self.patch_function(i, j, self.wrap_fn)
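
The new `Patcher` entry means every call to `QATModule._apply_fakequant_with_observer` is routed through `wrap_fn` while tracing is active, so the fake-quant application shows up in the traced graph. A simplified sketch of the `patch_method` idea (illustrative only, not the actual `Patcher` implementation):

    def patch_method(cls, name, wrapper):
        original = getattr(cls, name)
        setattr(cls, name, wrapper(original))
        return original  # keep a handle so the patch can be undone

    def wrap_fn(fn):
        def wrapped(*args, **kwargs):
            print("traced call:", fn.__name__)  # the tracer records here
            return fn(*args, **kwargs)
        return wrapped

    class QATModuleStub:
        def _apply_fakequant_with_observer(self, target, fake_quant, observer):
            return target  # stand-in body

    orig = patch_method(QATModuleStub, "_apply_fakequant_with_observer", wrap_fn)
    QATModuleStub()._apply_fakequant_with_observer("x", None, None)
    setattr(QATModuleStub, "_apply_fakequant_with_observer", orig)  # restore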
--- a/imperative/python/megengine/traced_module/node.py
+++ b/imperative/python/megengine/traced_module/node.py
@@ -127,7 +127,7 @@ class TensorNode(Node):
     shape = None  # type: Tuple[int]
     dtype = None  # type: numpy.dtype
-    qparam = None
+    qparams = None
     device = None
 
     def __getstate__(self):
@@ -135,7 +135,7 @@ class TensorNode(Node):
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
-            "qparam": self.qparam,
+            "qparams": self.qparams,
             "shape": self.shape,
             "dtype": self.dtype,
             "device": self.device,
--- a/imperative/python/megengine/traced_module/pytree.py
+++ b/imperative/python/megengine/traced_module/pytree.py
@@ -155,10 +155,7 @@ def tree_flatten(
         assert is_leaf(values), values
         node = LeafDef(leaf_type(values))
         if is_const_leaf(values):
-            if isinstance(values, np.ndarray):
-                node.const_val = str(values)
-            else:
-                node.const_val = values
+            node.const_val = values
         return [values,], node
 
     rst = []
@@ -232,9 +229,13 @@ class LeafDef(TreeDef):
         return leaves[0]
 
     def __eq__(self, other):
+        if isinstance(self.const_val, np.ndarray):
+            return self.type == other.type and (self.const_val == other.const_val).all()
         return self.type == other.type and self.const_val == other.const_val
 
     def __hash__(self):
+        if isinstance(self.const_val, np.ndarray):
+            return hash(tuple([self.type, str(self.const_val)]))
         return hash(tuple([self.type, self.const_val]))
 
     def __repr__(self):
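
The `LeafDef` changes are needed because `const_val` can now hold a raw `numpy.ndarray` (the `str()` workaround in `tree_flatten` is gone): `==` on arrays is elementwise, so the plain comparison would return an array rather than a bool, and ndarrays are unhashable. A quick demonstration of both pitfalls:

    import numpy as np

    a, b = np.array([1, 2]), np.array([1, 2])

    elementwise = a == b          # array([ True,  True]), not a bool
    assert (a == b).all()         # the .all() reduction used in __eq__

    try:
        hash(a)                   # ndarrays are unhashable
    except TypeError:
        pass
    assert hash(str(a)) == hash(str(b))  # hash str(a) as in __hash__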
--- a/imperative/python/megengine/traced_module/traced_module.py
+++ b/imperative/python/megengine/traced_module/traced_module.py
@@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import (
 from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
-from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
+from ...module.qat import QATModule
+from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
+from ...quantization.observer import (
+    ExponentialMovingAverageObserver,
+    HistogramObserver,
+    MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
+)
 from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
+from .fake_quant import FakeQuantize as TM_FakeQuant
 from .module_tracer import (
     Patcher,
     active_module_tracer,
@@ -613,7 +619,8 @@ def _wrapped_function(orig_func):
             if isinstance(i, (RawTensor, NodeMixin)):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
-        if meth_name:
+        arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+        if meth_name and issubclass(arg_type, RawTensor):
             self = inputs[0]
             if meth_name == "__new__":
                 if all([not isinstance(i, RawTensor) for i in inputs]):
@@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin):
         self._mod = mod
         self._body = None
         self._is_top = is_top_module
-        self._is_builtin = module_tracer.is_builtin(mod)
+        self._is_builtin = (
+            True
+            if isinstance(mod, (Observer, _FakeQuantize))
+            else module_tracer.is_builtin(mod)
+        )
+        if isinstance(self._mod, QATModule):
+            unset_module_tracing()
+            self._check_qat_module(self._mod)
+            set_module_tracing()
         self._argdef_graph_map = {}
         self._argdef_outdef_map = {}
@@ -693,15 +708,65 @@ class TracedModuleBuilder(NodeMixin):
             dict(TracedModuleBuilder.__dict__),
         )
 
+    def _check_qat_module(self, qat_module):
+        def isbuiltin(m):
+            return m is None or module_tracer.is_builtin(m)
+
+        if qat_module.with_act:
+            act_observer = qat_module.act_observer
+            act_fakequant = qat_module.act_fake_quant
+            if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
+                qparams = (
+                    act_observer.get_qparams()
+                    if hasattr(act_observer, "get_qparams")
+                    else act_fakequant.get_qparams()
+                )
+                dtype = (
+                    act_observer.dtype
+                    if hasattr(act_observer, "dtype")
+                    else act_fakequant.dtype
+                )
+                qat_module.act_observer = None
+                qat_module.act_fake_quant = TM_FakeQuant(dtype)
+                qat_module.act_fake_quant.set_qparams(qparams)
+        if qat_module.with_weight:
+            weight_observer = qat_module.weight_observer
+            weight_fakequant = qat_module.weight_fake_quant
+            if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
+                qparams = (
+                    weight_observer.get_qparams()
+                    if hasattr(weight_observer, "get_qparams")
+                    else weight_fakequant.get_qparams()
+                )
+                dtype = (
+                    weight_observer.dtype
+                    if hasattr(weight_observer, "dtype")
+                    else weight_fakequant.dtype
+                )
+                qat_module.weight_observer = None
+                qat_module.weight_fake_quant = TM_FakeQuant(dtype)
+                qat_module.weight_fake_quant.set_qparams(qparams)
+
     def build(self):
         if self._is_builtin or isinstance(self._mod, TracedModule):
+            if module_tracer.is_builtin(self._mod) or isinstance(
+                self._mod, TracedModule
+            ):
+                mod_type = type(self._mod)
+            else:
+                assert isinstance(self._mod, (Observer, _FakeQuantize))
+                mod_type = (
+                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
+                )
             for node in self.nodes:
-                node.module_type = type(self._mod)
-                # node._owner = weakref.ref(self._mod)
+                node.module_type = mod_type
             return self._mod
         else:
+            is_qat = isinstance(self._mod, QATModule)
             traced_module = TracedModule(
-                self._is_top, self._argdef_graph_map, self._argdef_outdef_map
+                self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
             )
            for _, g in self._argdef_graph_map.items():
                g.compile()
@@ -712,6 +777,20 @@ class TracedModuleBuilder(NodeMixin):
                 v = v.build()
                 setattr(traced_module, k, v)
+
+        if isinstance(self._mod, QATModule):
+            unset_module_tracing()
+            traced_module.with_act = self._mod.with_act
+            traced_module.with_weight = self._mod.with_weight
+            if not hasattr(traced_module, "act_fake_quant"):
+                traced_module.act_fake_quant = None
+            if not hasattr(traced_module, "act_observer"):
+                traced_module.act_observer = None
+            if not hasattr(traced_module, "weight_fake_quant"):
+                traced_module.weight_fake_quant = None
+            if not hasattr(traced_module, "weight_observer"):
+                traced_module.weight_observer = None
+            set_module_tracing()
 
         return traced_module
 
     def _record_wrapped_nodes(self, node):
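
The intent of `_check_qat_module` plus this `build` epilogue: a QAT module whose observer or fake-quant is user-defined (not registered as builtin) cannot be traced into, so its already-learned `qparams` are extracted and frozen into a `TM_FakeQuant`, the observer slot is cleared, and the resulting `TracedModule` still carries the `with_act`/`with_weight`/observer attributes that downstream QAT code expects. The freezing step, re-enacted with stand-in classes (illustrative names only):

    class QParamsStub:
        def __init__(self, scale):
            self.scale = scale

    class ObserverStub:
        dtype = "qint8"
        def get_qparams(self):
            return QParamsStub(scale=0.125)  # learned during calibration

    class FixedFakeQuant:
        def __init__(self, dtype):
            self.dtype, self.qparams = dtype, None
        def set_qparams(self, qparams):
            self.qparams = qparams

    def freeze(observer):
        qparams = observer.get_qparams()     # keep the learned statistics
        fq = FixedFakeQuant(observer.dtype)  # same target dtype
        fq.set_qparams(qparams)              # frozen scale, observer dropped
        return fq

    fq = freeze(ObserverStub())
    assert fq.qparams.scale == 0.125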
@@ -846,7 +925,8 @@ class TracedModuleBuilder(NodeMixin):
         attr = getattr(self._mod, name)
         if isinstance(attr, Module):
             attr = TracedModuleBuilder(attr)
-        setattr(self, name, attr)
+        if isinstance(attr, (Module, RawTensor)):
+            setattr(self, name, attr)
         NodeMixin.wrap(
             attr,
             lambda: GetAttr.make(
@@ -1066,7 +1146,7 @@ class TracedModule(Module):
     argdef_graph_map = None
     argdef_outdef_map = None
 
-    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
+    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
         super(TracedModule, self).__init__()
         self.argdef_graph_map = argdef_graph_map
         self.argdef_outdef_map = argdef_outdef_map
@@ -1074,6 +1154,7 @@ class TracedModule(Module):
         self.watch_points = []
         self.watch_node_value = {}
         self.end_points = []
+        self.is_qat = is_qat
 
     def forward(self, *args, **kwargs):
         inputs, treedef = tree_flatten(((self, *args), kwargs))
@@ -1195,8 +1276,8 @@ class TracedModule(Module):
     ):
         if graph is not None and prefix_name and prefix_name[-1] != "_":
             prefix_name += "_"
-        if graph is None:
-            assert not isinstance(module, TracedModule)
+        if graph is None or module.is_qat:
+            assert not isinstance(module, TracedModule) or module.is_qat
             const = Constant(module, "self.%s" % module2name[id(module)])
             m_node = call.inputs[0]
             if m_node.top_graph != active_module_tracer().current_scope():
@@ -1326,9 +1407,23 @@ def _register_all_builtin_module():
             isclass(m[1])
             and issubclass(m[1], M.Module)
             and m[1] is not M.Sequential
+            and m[1] is not M.ModuleList
         ):
             module_tracer.register_as_builtin(m[1])
 
+    module_tracer.register_as_builtin(Observer)
+    module_tracer.register_as_builtin(MinMaxObserver)
+    module_tracer.register_as_builtin(SyncMinMaxObserver)
+    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
+    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
+    module_tracer.register_as_builtin(HistogramObserver)
+    module_tracer.register_as_builtin(PassiveObserver)
+    module_tracer.register_as_builtin(LSQ)
+    module_tracer.register_as_builtin(TQT)
+    module_tracer.register_as_builtin(FakeQuantize)
+    module_tracer.register_as_builtin(TM_FakeQuant)
+
 
 def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     """