GitOrigin-RevId: 49f70a5f46
release-1.6
@@ -17,12 +17,14 @@ from typing import Callable, Dict, List
 from ...core._imperative_rt import OpDef
 from ...core._imperative_rt.core2 import Tensor as RawTensor
 from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
+from ...core.ops.builtin import FakeQuant
 from ...core.ops.special import Const
 from ...module import Module
 from ...tensor import Parameter, Tensor
 from .module_tracer import active_module_tracer, module_tracer
 from .node import ModuleNode, Node, NodeMixin, TensorNode
-from .pytree import ArgsIndex, TreeDef, tree_flatten
+from .pytree import ArgsIndex, TreeDef, _is_const_leaf, _is_leaf, tree_flatten
+from .serialization import get_opdef_state, load_opdef_from_state


 def rstrip(s: str, __chars: str):
@@ -76,6 +78,7 @@ class Expr:
             node.users.append(self)
         else:
             assert node is None
+            assert _is_leaf(val) and _is_const_leaf(val)
             idx = len(self.inputs) + len(self.const_val)
             self.const_val.append((idx, val))
@@ -154,6 +157,11 @@ class Expr:
             return self._top_graph()
         return None

+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state.pop("_top_graph", None)
+        return state
+

 # expr: None (i.e. fake expression which is used to mark input)
 class Input(Expr):
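The new `Expr.__getstate__` drops the cached `_top_graph` entry before pickling. A minimal sketch of why, assuming (as the `self._top_graph()` call above suggests) that the attribute holds a weak reference, which `pickle` refuses to serialize:

```python
# Sketch only: the attribute name mirrors the diff; Graph is a stand-in.
import pickle
import weakref

class Graph:
    pass

class Expr:
    def __init__(self, graph):
        self._top_graph = weakref.ref(graph)  # cached weakref; unpicklable

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_top_graph", None)  # same pattern as the diff above
        return state

g = Graph()
restored = pickle.loads(pickle.dumps(Expr(g)))  # ok: weakref never reaches pickle
assert not hasattr(restored, "_top_graph")
```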
@@ -321,14 +329,36 @@ class Apply(Expr):
             ", ".join(str(i) for i in self.inputs),
         )

+    def __getstate__(self):
+        state = super().__getstate__()
+        state["opdef"] = get_opdef_state(state["opdef"])
+        return state
+
+    def __setstate__(self, state):
+        state["opdef"] = load_opdef_from_state(state["opdef"])
+        for k, v in state.items():
+            setattr(self, k, v)
+
     @classmethod
     def apply_module_trace_hook(cls, opdef, *inputs):
         for i in inputs:
             node = NodeMixin.get(i, None)
             if node is None:  # capture as constant
                 NodeMixin.wrap_safe(i, Constant.make(i))
-        apply_node = cls.make(opdef)
-        apply_node.add_inputs(inputs)
+
+        if isinstance(opdef, FakeQuant):
+            inp_nodes = [NodeMixin.get(inputs[0])]
+            for i in inputs[1:]:
+                node = Constant.make(i)
+                inp_nodes.append(node)
+            apply_node = cls.make(opdef)
+            for n in inp_nodes:
+                n.users.append(apply_node)
+            apply_node.inputs = inp_nodes
+        else:
+            apply_node = cls.make(opdef)
+            apply_node.add_inputs(inputs)
         assert not apply_node.const_val

         unset_module_tracing()
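Two changes land in this hunk: `Apply` becomes picklable by flattening its C++-backed `opdef` through `get_opdef_state` / `load_opdef_from_state`, and `FakeQuant` ops get their scale/zero-point inputs (everything past `inputs[0]`) captured as fresh `Constant` nodes rather than looked up from traced state. A hedged sketch of the serialization round trip — the stubs below only stand in for the real helpers in `.serialization`:

```python
import pickle

class OpDef:                      # stand-in for the C++-backed class
    def __init__(self, **attrs):
        self.attrs = attrs

def get_opdef_state(opdef):       # stub: real version inspects the OpDef
    return {"type": type(opdef).__name__, "attrs": opdef.attrs}

def load_opdef_from_state(state): # stub: real version re-instantiates it
    return OpDef(**state["attrs"])

class Apply:
    def __init__(self, opdef):
        self.opdef = opdef

    def __getstate__(self):
        state = self.__dict__.copy()
        state["opdef"] = get_opdef_state(state["opdef"])  # plain dict now
        return state

    def __setstate__(self, state):
        state["opdef"] = load_opdef_from_state(state["opdef"])
        for k, v in state.items():
            setattr(self, k, v)

a = pickle.loads(pickle.dumps(Apply(OpDef(scale=0.5))))
assert a.opdef.attrs == {"scale": 0.5}
```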
@@ -387,7 +417,7 @@ class Constant(Expr):
         super().__init__()
         assert isinstance(c, (RawTensor, Module))
         if isinstance(c, Module):
-            assert module_tracer.is_builtin(c)
+            assert module_tracer.is_builtin(c) or c.is_qat
         self.value = c
         self.name = name
         self.inputs = []
@@ -395,6 +425,7 @@ class Constant(Expr):
         self.outputs = [
             node_cls(self, name=name),
         ]
+        self.outputs[0]._name = name if name else "const_" + str(self._id)

     @classmethod
     def make(cls, *args, **kwargs):
@@ -422,7 +453,7 @@ class Constant(Expr):
         )

     def __getstate__(self):
-        state = self.__dict__.copy()
+        state = super().__getstate__()
         if isinstance(self.value, RawTensor):
             state["value"] = Tensor(self.value)
         return state
@@ -0,0 +1,48 @@
+from copy import deepcopy
+from typing import Union
+
+from ...core.tensor.dtype import QuantDtypeMeta
+from ...quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
+from ...quantization.utils import QParams, QuantMode, fake_quant_tensor
+
+
+class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
+    def __init__(
+        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
+    ):
+        super().__init__(dtype, enable, **kwargs)
+        self.qparams = None
+
+    def fake_quant_forward(self, inp, qparams: QParams = None):
+        if qparams is None:
+            qparams = self.get_qparams()
+        assert (
+            qparams.dtype_meta is self.dtype
+        ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
+            qparams.dtype_meta, self.dtype
+        )
+        return fake_quant_tensor(inp, qparams)
+
+    def get_qparams(self):
+        return self.qparams
+
+    def set_qparams(self, qparams: QParams):
+        """
+        :param qparams: used to set initial scale.
+        """
+        if qparams.scale is None:
+            raise AssertionError("Can not get an initialized scale")
+        scale = qparams.scale
+        if qparams.dtype_meta is None:
+            qparams.dtype_meta = self.dtype
+        else:
+            assert (
+                qparams.dtype_meta is self.dtype
+            ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
+                qparams.dtype_meta, self.dtype
+            )
+        dtype_meta = qparams.dtype_meta
+        zero_point = qparams.zero_point
+        mode = qparams.mode
+        self.qparams = QParams(mode, dtype_meta, scale, zero_point)
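This new module freezes previously-observed statistics into a stateless fake-quant op: `get_qparams` simply returns whatever `set_qparams` stored, so no observer has to survive tracing. A hedged usage sketch — import paths assumed from this diff, observer API as in MegEngine 1.6:

```python
import megengine
from megengine.quantization.observer import MinMaxObserver
from megengine.traced_module.fake_quant import FakeQuantize  # the file above

obs = MinMaxObserver(dtype="qint8")
obs(megengine.Tensor([-1.0, 0.5, 2.0]))   # accumulate min/max statistics

fq = FakeQuantize(dtype="qint8")
fq.set_qparams(obs.get_qparams())          # freeze scale/zero_point here
out = fq(megengine.Tensor([0.3, 1.7]))     # fake-quantized forward pass
```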
@@ -12,6 +12,7 @@ from ... import Tensor
 from ... import functional as F
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
+from ...module.qat import QATModule

 _active_module_tracer = None
@@ -68,7 +69,7 @@ BUILTIN_ARRAY_METHOD = [
     "__iand__",
     "__ior__",
     "__ixor__",
-    "T",
+    "transpose",
     "astype",
     "reshape",
     "_broadcast",
@@ -180,6 +181,7 @@ class Patcher:
             self.patch_method(ArrayMethodMixin, meth, self.wrap_fn)
         self.patch_method(Tensor, "detach", self.wrap_fn)
         self.patch_method(Tensor, "__new__", self.wrap_fn)
+        self.patch_method(QATModule, "_apply_fakequant_with_observer", self.wrap_fn)
         for i, j in self._builtin_functions:
             if id(i) not in self.visited_frames_ids:
                 self.patch_function(i, j, self.wrap_fn)
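With the added line, a QAT module's `_apply_fakequant_with_observer` call is intercepted exactly like a builtin tensor method. The patching idea itself, reduced to a sketch (the real `patch_method` also records originals so the `Patcher` can restore them on exit):

```python
# Simplified stand-in for Patcher.patch_method: swap a method for a
# wrapped version so every call can be recorded as an Expr while tracing.
def patch_method(cls, name, wrap_fn):
    orig = getattr(cls, name)
    setattr(cls, name, wrap_fn(orig))

def wrap_fn(orig):                       # stand-in for self.wrap_fn
    def wrapped(*args, **kwargs):
        print("traced:", orig.__name__)  # the tracer would emit an Expr here
        return orig(*args, **kwargs)
    return wrapped

class Demo:
    def method(self):
        return 42

patch_method(Demo, "method", wrap_fn)
assert Demo().method() == 42             # behavior unchanged, now observable
```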
@@ -127,7 +127,7 @@ class TensorNode(Node):
     shape = None  # type: Tuple[int]
     dtype = None  # type: numpy.dtype
-    qparam = None
+    qparams = None
     device = None

     def __getstate__(self):
@@ -135,7 +135,7 @@ class TensorNode(Node):
             "expr": self.expr,
             "users": self.users,
             "_id": self._id,
-            "qparam": self.qparam,
+            "qparams": self.qparams,
             "shape": self.shape,
             "dtype": self.dtype,
             "device": self.device,
@@ -155,10 +155,7 @@ def tree_flatten(
         assert is_leaf(values), values
         node = LeafDef(leaf_type(values))
         if is_const_leaf(values):
-            if isinstance(values, np.ndarray):
-                node.const_val = str(values)
-            else:
-                node.const_val = values
+            node.const_val = values
         return [values,], node

     rst = []
@@ -232,9 +229,13 @@ class LeafDef(TreeDef):
         return leaves[0]

     def __eq__(self, other):
+        if isinstance(self.const_val, np.ndarray):
+            return self.type == other.type and (self.const_val == other.const_val).all()
         return self.type == other.type and self.const_val == other.const_val

     def __hash__(self):
+        if isinstance(self.const_val, np.ndarray):
+            return hash(tuple([self.type, str(self.const_val)]))
         return hash(tuple([self.type, self.const_val]))

     def __repr__(self):
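These branches exist because `tree_flatten` (previous hunk) now stores `np.ndarray` constants as-is instead of stringifying them at flatten time: arrays are unhashable and compare elementwise, so `LeafDef` must special-case both dunders. The underlying numpy behavior:

```python
import numpy as np

a, b = np.array([1, 2]), np.array([1, 2])

# Elementwise __eq__ yields an array; reduce with .all() as the diff does.
assert (a == b).all()

# Arrays are unhashable, hence hashing the string form instead.
try:
    hash(a)                     # TypeError: unhashable type: 'numpy.ndarray'
except TypeError:
    pass
assert hash(str(a)) == hash(str(b))
```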
@@ -29,14 +29,20 @@ from ...core._imperative_rt.core2 import (
 from ...core._trace_option import set_symbolic_shape
 from ...core.tensor.array_method import ArrayMethodMixin
 from ...module import Module
-from ...quantization.fake_quant import LSQ, TQT, FakeQuantize
+from ...module.qat import QATModule
+from ...quantization.fake_quant import LSQ, TQT, FakeQuantize, _FakeQuantize
 from ...quantization.observer import (
     ExponentialMovingAverageObserver,
+    HistogramObserver,
     MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
     SyncMinMaxObserver,
 )
 from ...tensor import Tensor
 from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input
+from .fake_quant import FakeQuantize as TM_FakeQuant
 from .module_tracer import (
     Patcher,
     active_module_tracer,
@@ -613,7 +619,8 @@ def _wrapped_function(orig_func):
                 if isinstance(i, (RawTensor, NodeMixin)):
                     NodeMixin.wrap_safe(i, Constant.make(i))
             meth_name = _get_meth_name(args[0], wrapped_fn) if args else None
-            if meth_name:
+            arg_type = args[0] if isinstance(args[0], type) else type(args[0])
+            if meth_name and issubclass(arg_type, RawTensor):
                 self = inputs[0]
                 if meth_name == "__new__":
                     if all([not isinstance(i, RawTensor) for i in inputs]):
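The guard matters because the patched functions include `Tensor.__new__`, where `args[0]` is the class itself rather than an instance; normalizing both shapes to a type lets the `issubclass` check decide whether the call should be recorded as a tensor method. The idiom in isolation:

```python
# args[0] may be an instance (bound-method call) or a class (__new__);
# both collapse to the same type before the issubclass test.
class Foo:
    pass

for first in (Foo(), Foo):
    arg_type = first if isinstance(first, type) else type(first)
    assert arg_type is Foo
```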
@@ -680,7 +687,15 @@ class TracedModuleBuilder(NodeMixin):
         self._mod = mod
         self._body = None
         self._is_top = is_top_module
-        self._is_builtin = module_tracer.is_builtin(mod)
+        self._is_builtin = (
+            True
+            if isinstance(mod, (Observer, _FakeQuantize))
+            else module_tracer.is_builtin(mod)
+        )
+        if isinstance(self._mod, QATModule):
+            unset_module_tracing()
+            self._check_qat_module(self._mod)
+            set_module_tracing()
         self._argdef_graph_map = {}
         self._argdef_outdef_map = {}
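`_check_qat_module` mutates the module (swapping observers for frozen fake-quant ops), so tracing is switched off around it to keep those tensor operations out of the graph. A hypothetical context-manager spelling of the same bracket — not part of this diff, shown only to make the pairing explicit:

```python
from contextlib import contextmanager

from megengine.core._imperative_rt.core2 import (  # same imports as this file
    set_module_tracing,
    unset_module_tracing,
)

@contextmanager
def tracing_disabled():          # hypothetical helper
    unset_module_tracing()
    try:
        yield
    finally:
        set_module_tracing()     # restored even if the body raises

# Equivalent to the diff's bracket, but exception-safe:
# with tracing_disabled():
#     self._check_qat_module(self._mod)
```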
@@ -693,15 +708,65 @@ class TracedModuleBuilder(NodeMixin):
             dict(TracedModuleBuilder.__dict__),
         )

+    def _check_qat_module(self, qat_module):
+        def isbuiltin(m):
+            return m is None or module_tracer.is_builtin(m)
+
+        if qat_module.with_act:
+            act_observer = qat_module.act_observer
+            act_fakequant = qat_module.act_fake_quant
+            if not isbuiltin(act_observer) or not isbuiltin(act_fakequant):
+                qparams = (
+                    act_observer.get_qparams()
+                    if hasattr(act_observer, "get_qparams")
+                    else act_fakequant.get_qparams()
+                )
+                dtype = (
+                    act_observer.dtype
+                    if hasattr(act_observer, "dtype")
+                    else act_fakequant.dtype
+                )
+                qat_module.act_observer = None
+                qat_module.act_fake_quant = TM_FakeQuant(dtype)
+                qat_module.act_fake_quant.set_qparams(qparams)
+
+        if qat_module.with_weight:
+            weight_observer = qat_module.weight_observer
+            weight_fakequant = qat_module.weight_fake_quant
+            if not isbuiltin(weight_observer) or not isbuiltin(weight_fakequant):
+                qparams = (
+                    weight_observer.get_qparams()
+                    if hasattr(weight_observer, "get_qparams")
+                    else weight_fakequant.get_qparams()
+                )
+                dtype = (
+                    weight_observer.dtype
+                    if hasattr(weight_observer, "dtype")
+                    else weight_fakequant.dtype
+                )
+                qat_module.weight_observer = None
+                qat_module.weight_fake_quant = TM_FakeQuant(dtype)
+                qat_module.weight_fake_quant.set_qparams(qparams)
+
     def build(self):
         if self._is_builtin or isinstance(self._mod, TracedModule):
+            if module_tracer.is_builtin(self._mod) or isinstance(
+                self._mod, TracedModule
+            ):
+                mod_type = type(self._mod)
+            else:
+                assert isinstance(self._mod, (Observer, _FakeQuantize))
+                mod_type = (
+                    Observer if isinstance(self._mod, Observer) else _FakeQuantize
+                )
             for node in self.nodes:
-                node.module_type = type(self._mod)
-                # node._owner = weakref.ref(self._mod)
+                node.module_type = mod_type
             return self._mod
         else:
+            is_qat = isinstance(self._mod, QATModule)
             traced_module = TracedModule(
-                self._is_top, self._argdef_graph_map, self._argdef_outdef_map
+                self._is_top, self._argdef_graph_map, self._argdef_outdef_map, is_qat
             )
             for _, g in self._argdef_graph_map.items():
                 g.compile()
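Note how `_check_qat_module` picks the statistics to freeze: it prefers the observer when one exposes the attribute and falls back to the existing fake-quant module, while `isbuiltin` treats `None` as builtin so an absent observer never triggers the rewrite. The selection idiom, isolated with stand-in classes:

```python
# Stand-ins only; the real objects are MegEngine observers/fake-quant modules.
class Observerish:
    dtype = "qint8"
    def get_qparams(self):
        return {"scale": 0.1}

class FakeQuantish:
    dtype = "qint32"
    def get_qparams(self):
        return {"scale": 0.2}

def pick(obs, fq):
    # hasattr(None, ...) is False, so a missing observer falls through to fq
    qparams = obs.get_qparams() if hasattr(obs, "get_qparams") else fq.get_qparams()
    dtype = obs.dtype if hasattr(obs, "dtype") else fq.dtype
    return qparams, dtype

assert pick(Observerish(), FakeQuantish()) == ({"scale": 0.1}, "qint8")
assert pick(None, FakeQuantish()) == ({"scale": 0.2}, "qint32")
```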
@@ -712,6 +777,20 @@ class TracedModuleBuilder(NodeMixin):
                     v = v.build()
                 setattr(traced_module, k, v)
+
+            if isinstance(self._mod, QATModule):
+                unset_module_tracing()
+                traced_module.with_act = self._mod.with_act
+                traced_module.with_weight = self._mod.with_weight
+                if not hasattr(traced_module, "act_fake_quant"):
+                    traced_module.act_fake_quant = None
+                if not hasattr(traced_module, "act_observer"):
+                    traced_module.act_observer = None
+                if not hasattr(traced_module, "weight_fake_quant"):
+                    traced_module.weight_fake_quant = None
+                if not hasattr(traced_module, "weight_observer"):
+                    traced_module.weight_observer = None
+                set_module_tracing()
             return traced_module

     def _record_wrapped_nodes(self, node):
@@ -846,7 +925,8 @@ class TracedModuleBuilder(NodeMixin):
             attr = getattr(self._mod, name)
             if isinstance(attr, Module):
                 attr = TracedModuleBuilder(attr)
-            setattr(self, name, attr)
+            if isinstance(attr, (Module, RawTensor)):
+                setattr(self, name, attr)
             NodeMixin.wrap(
                 attr,
                 lambda: GetAttr.make(
@@ -1066,7 +1146,7 @@ class TracedModule(Module):
     argdef_graph_map = None
     argdef_outdef_map = None

-    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map):
+    def __init__(self, is_top, argdef_graph_map, argdef_outdef_map, is_qat=False):
         super(TracedModule, self).__init__()
         self.argdef_graph_map = argdef_graph_map
         self.argdef_outdef_map = argdef_outdef_map
@@ -1074,6 +1154,7 @@ class TracedModule(Module):
         self.watch_points = []
         self.watch_node_value = {}
         self.end_points = []
+        self.is_qat = is_qat

     def forward(self, *args, **kwargs):
         inputs, treedef = tree_flatten(((self, *args), kwargs))
@@ -1195,8 +1276,8 @@ class TracedModule(Module):
             ):
                 if graph is not None and prefix_name and prefix_name[-1] != "_":
                     prefix_name += "_"
-                if graph is None:
-                    assert not isinstance(module, TracedModule)
+                if graph is None or module.is_qat:
+                    assert not isinstance(module, TracedModule) or module.is_qat
                     const = Constant(module, "self.%s" % module2name[id(module)])
                     m_node = call.inputs[0]
                     if m_node.top_graph != active_module_tracer().current_scope():
@@ -1326,9 +1407,23 @@ def _register_all_builtin_module():
             isclass(m[1])
             and issubclass(m[1], M.Module)
             and m[1] is not M.Sequential
+            and m[1] is not M.ModuleList
         ):
             module_tracer.register_as_builtin(m[1])

+    module_tracer.register_as_builtin(Observer)
+    module_tracer.register_as_builtin(MinMaxObserver)
+    module_tracer.register_as_builtin(SyncMinMaxObserver)
+    module_tracer.register_as_builtin(ExponentialMovingAverageObserver)
+    module_tracer.register_as_builtin(SyncExponentialMovingAverageObserver)
+    module_tracer.register_as_builtin(HistogramObserver)
+    module_tracer.register_as_builtin(PassiveObserver)
+    module_tracer.register_as_builtin(LSQ)
+    module_tracer.register_as_builtin(TQT)
+    module_tracer.register_as_builtin(FakeQuantize)
+    module_tracer.register_as_builtin(TM_FakeQuant)

 def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
     """