|
|
@@ -145,9 +145,8 @@ def _node_to_tensor(*args, **kwargs): |
|
|
|
value = n.value |
|
|
|
if value is None: |
|
|
|
flag = _set_graph_surgery_mode(False) |
|
|
|
unset_module_tracing() |
|
|
|
value = F.zeros(shape=n._shape, dtype=n._dtype) |
|
|
|
set_module_tracing() |
|
|
|
with _exclude_from_trace(): |
|
|
|
value = F.zeros(shape=n._shape, dtype=n._dtype) |
|
|
|
_set_graph_surgery_mode(flag) |
|
|
|
orig_n = NodeMixin.get(value, None) |
|
|
|
if orig_n is None or "setitem" not in orig_n._name: |
|
|
@@ -1274,8 +1273,10 @@ def _wrapped_function(orig_func): |
|
|
|
@functools.wraps(orig_func) |
|
|
|
def wrapped_fn(*args, **kwargs): |
|
|
|
method_func = kwargs.pop("method_func", wrapped_fn) |
|
|
|
if is_tracing_module(): |
|
|
|
unset_module_tracing() |
|
|
|
if not is_tracing_module(): |
|
|
|
return orig_func(*args, **kwargs) |
|
|
|
|
|
|
|
with _exclude_from_trace(): |
|
|
|
inputs, tree_def = tree_flatten((args, kwargs)) |
|
|
|
for i in inputs: |
|
|
|
if not NodeMixin.get(i, None): |
|
|
@@ -1290,7 +1291,6 @@ def _wrapped_function(orig_func): |
|
|
|
if meth_name == "__new__": |
|
|
|
if all([not isinstance(i, RawTensor) for i in inputs]): |
|
|
|
# only trace Tensor.__new__() when there are tensors in args |
|
|
|
set_module_tracing() |
|
|
|
return orig_func(*args, **kwargs) |
|
|
|
if isinstance(args[1], RawTensor): |
|
|
|
node = NodeMixin.get(inputs[1]) |
|
|
@@ -1327,9 +1327,7 @@ def _wrapped_function(orig_func): |
|
|
|
call_node, outputs |
|
|
|
) |
|
|
|
|
|
|
|
set_module_tracing() |
|
|
|
return rst |
|
|
|
return orig_func(*args, **kwargs) |
|
|
|
|
|
|
|
return wrapped_fn |
|
|
|
|
|
|
@@ -1339,8 +1337,8 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
_mod = None # type: Module |
|
|
|
_body = None # type: InternalGraph |
|
|
|
_is_builtin = None # type: bool |
|
|
|
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"] |
|
|
|
_argdef_outdef_map = None # type: Dict[Treedef, Treedef] |
|
|
|
_argdef_graph_map = None # type: Dict[TreeDef, "InternalGraph"] |
|
|
|
_argdef_outdef_map = None # type: Dict[TreeDef, TreeDef] |
|
|
|
nodes = None |
|
|
|
|
|
|
|
__builder_attributes__ = [ |
|
|
@@ -1371,9 +1369,8 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
else module_tracer.is_builtin(mod) |
|
|
|
) |
|
|
|
if isinstance(self._mod, QATModule): |
|
|
|
unset_module_tracing() |
|
|
|
self._check_qat_module(self._mod) |
|
|
|
set_module_tracing() |
|
|
|
with _exclude_from_trace(): |
|
|
|
self._check_qat_module(self._mod) |
|
|
|
self._argdef_graph_map = {} |
|
|
|
self._argdef_outdef_map = {} |
|
|
|
|
|
|
@@ -1458,18 +1455,17 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
setattr(traced_module, k, v) |
|
|
|
|
|
|
|
if isinstance(self._mod, QATModule): |
|
|
|
unset_module_tracing() |
|
|
|
traced_module.with_act = self._mod.with_act |
|
|
|
traced_module.with_weight = self._mod.with_weight |
|
|
|
if not hasattr(traced_module, "act_fake_quant"): |
|
|
|
traced_module.act_fake_quant = None |
|
|
|
if not hasattr(traced_module, "act_observer"): |
|
|
|
traced_module.act_observer = None |
|
|
|
if not hasattr(traced_module, "weight_fake_quant"): |
|
|
|
traced_module.weight_fake_quant = None |
|
|
|
if not hasattr(traced_module, "weight_observer"): |
|
|
|
traced_module.weight_observer = None |
|
|
|
set_module_tracing() |
|
|
|
with _exclude_from_trace(): |
|
|
|
traced_module.with_act = self._mod.with_act |
|
|
|
traced_module.with_weight = self._mod.with_weight |
|
|
|
if not hasattr(traced_module, "act_fake_quant"): |
|
|
|
traced_module.act_fake_quant = None |
|
|
|
if not hasattr(traced_module, "act_observer"): |
|
|
|
traced_module.act_observer = None |
|
|
|
if not hasattr(traced_module, "weight_fake_quant"): |
|
|
|
traced_module.weight_fake_quant = None |
|
|
|
if not hasattr(traced_module, "weight_observer"): |
|
|
|
traced_module.weight_observer = None |
|
|
|
|
|
|
|
if self._is_top: |
|
|
|
traced_module._update_ref() |
|
|
@@ -1505,16 +1501,14 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
callnode.arg_def = tree_def |
|
|
|
|
|
|
|
if self._is_builtin or tree_def in self._argdef_graph_map: |
|
|
|
unset_module_tracing() |
|
|
|
rst = self._mod(*args, **kwargs) |
|
|
|
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) |
|
|
|
if _get_expr_checker(): |
|
|
|
with _exclude_from_trace(): |
|
|
|
with _exclude_from_trace(): |
|
|
|
rst = self._mod(*args, **kwargs) |
|
|
|
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf) |
|
|
|
if _get_expr_checker(): |
|
|
|
tmp = self.build() |
|
|
|
active_module_tracer().checker.check_builtin_module( |
|
|
|
tmp, callnode, outputs |
|
|
|
) |
|
|
|
set_module_tracing() |
|
|
|
if self._is_builtin: |
|
|
|
self._body = None |
|
|
|
elif tree_def in self._argdef_graph_map: |
|
|
@@ -1640,16 +1634,17 @@ class TracedModuleBuilder(NodeMixin): |
|
|
|
|
|
|
|
if isinstance(attr, (List, Dict)): |
|
|
|
flag = _set_graph_surgery_mode(False) |
|
|
|
unset_module_tracing() |
|
|
|
has_module, m_container = replace_container_with_module_container(attr) |
|
|
|
if m_container: |
|
|
|
attr = m_container |
|
|
|
if has_module and not m_container: |
|
|
|
raise ValueError( |
|
|
|
"Can not trace the module that uses the same container to store" |
|
|
|
" Module and Non-Module objects." |
|
|
|
with _exclude_from_trace(): |
|
|
|
has_module, m_container = replace_container_with_module_container( |
|
|
|
attr |
|
|
|
) |
|
|
|
set_module_tracing() |
|
|
|
if m_container: |
|
|
|
attr = m_container |
|
|
|
if has_module and not m_container: |
|
|
|
raise ValueError( |
|
|
|
"Can not trace the module that uses the same container to store" |
|
|
|
" Module and Non-Module objects." |
|
|
|
) |
|
|
|
_set_graph_surgery_mode(flag) |
|
|
|
|
|
|
|
if isinstance(attr, Module): |
|
|
|