Browse Source

feat(traced_module): add _exclude_from_trace

GitOrigin-RevId: 615b769a02
tags/v1.8.0
Megvii Engine Team 3 years ago
parent
commit
97c90d9137
2 changed files with 45 additions and 48 deletions
  1. +35
    -40
      imperative/python/megengine/traced_module/traced_module.py
  2. +10
    -8
      imperative/python/megengine/traced_module/utils.py

+ 35
- 40
imperative/python/megengine/traced_module/traced_module.py View File

@@ -145,9 +145,8 @@ def _node_to_tensor(*args, **kwargs):
value = n.value value = n.value
if value is None: if value is None:
flag = _set_graph_surgery_mode(False) flag = _set_graph_surgery_mode(False)
unset_module_tracing()
value = F.zeros(shape=n._shape, dtype=n._dtype)
set_module_tracing()
with _exclude_from_trace():
value = F.zeros(shape=n._shape, dtype=n._dtype)
_set_graph_surgery_mode(flag) _set_graph_surgery_mode(flag)
orig_n = NodeMixin.get(value, None) orig_n = NodeMixin.get(value, None)
if orig_n is None or "setitem" not in orig_n._name: if orig_n is None or "setitem" not in orig_n._name:
@@ -1274,8 +1273,10 @@ def _wrapped_function(orig_func):
@functools.wraps(orig_func) @functools.wraps(orig_func)
def wrapped_fn(*args, **kwargs): def wrapped_fn(*args, **kwargs):
method_func = kwargs.pop("method_func", wrapped_fn) method_func = kwargs.pop("method_func", wrapped_fn)
if is_tracing_module():
unset_module_tracing()
if not is_tracing_module():
return orig_func(*args, **kwargs)

with _exclude_from_trace():
inputs, tree_def = tree_flatten((args, kwargs)) inputs, tree_def = tree_flatten((args, kwargs))
for i in inputs: for i in inputs:
if not NodeMixin.get(i, None): if not NodeMixin.get(i, None):
@@ -1290,7 +1291,6 @@ def _wrapped_function(orig_func):
if meth_name == "__new__": if meth_name == "__new__":
if all([not isinstance(i, RawTensor) for i in inputs]): if all([not isinstance(i, RawTensor) for i in inputs]):
# only trace Tensor.__new__() when there are tensors in args # only trace Tensor.__new__() when there are tensors in args
set_module_tracing()
return orig_func(*args, **kwargs) return orig_func(*args, **kwargs)
if isinstance(args[1], RawTensor): if isinstance(args[1], RawTensor):
node = NodeMixin.get(inputs[1]) node = NodeMixin.get(inputs[1])
@@ -1327,9 +1327,7 @@ def _wrapped_function(orig_func):
call_node, outputs call_node, outputs
) )


set_module_tracing()
return rst return rst
return orig_func(*args, **kwargs)


return wrapped_fn return wrapped_fn


@@ -1339,8 +1337,8 @@ class TracedModuleBuilder(NodeMixin):
_mod = None # type: Module _mod = None # type: Module
_body = None # type: InternalGraph _body = None # type: InternalGraph
_is_builtin = None # type: bool _is_builtin = None # type: bool
_argdef_graph_map = None # type: Dict[Treedef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[Treedef, Treedef]
_argdef_graph_map = None # type: Dict[TreeDef, "InternalGraph"]
_argdef_outdef_map = None # type: Dict[TreeDef, TreeDef]
nodes = None nodes = None


__builder_attributes__ = [ __builder_attributes__ = [
@@ -1371,9 +1369,8 @@ class TracedModuleBuilder(NodeMixin):
else module_tracer.is_builtin(mod) else module_tracer.is_builtin(mod)
) )
if isinstance(self._mod, QATModule): if isinstance(self._mod, QATModule):
unset_module_tracing()
self._check_qat_module(self._mod)
set_module_tracing()
with _exclude_from_trace():
self._check_qat_module(self._mod)
self._argdef_graph_map = {} self._argdef_graph_map = {}
self._argdef_outdef_map = {} self._argdef_outdef_map = {}


@@ -1458,18 +1455,17 @@ class TracedModuleBuilder(NodeMixin):
setattr(traced_module, k, v) setattr(traced_module, k, v)


if isinstance(self._mod, QATModule): if isinstance(self._mod, QATModule):
unset_module_tracing()
traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None
set_module_tracing()
with _exclude_from_trace():
traced_module.with_act = self._mod.with_act
traced_module.with_weight = self._mod.with_weight
if not hasattr(traced_module, "act_fake_quant"):
traced_module.act_fake_quant = None
if not hasattr(traced_module, "act_observer"):
traced_module.act_observer = None
if not hasattr(traced_module, "weight_fake_quant"):
traced_module.weight_fake_quant = None
if not hasattr(traced_module, "weight_observer"):
traced_module.weight_observer = None


if self._is_top: if self._is_top:
traced_module._update_ref() traced_module._update_ref()
@@ -1505,16 +1501,14 @@ class TracedModuleBuilder(NodeMixin):
callnode.arg_def = tree_def callnode.arg_def = tree_def


if self._is_builtin or tree_def in self._argdef_graph_map: if self._is_builtin or tree_def in self._argdef_graph_map:
unset_module_tracing()
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
if _get_expr_checker():
with _exclude_from_trace():
with _exclude_from_trace():
rst = self._mod(*args, **kwargs)
outputs, out_def = tree_flatten(rst, is_leaf=_is_leaf)
if _get_expr_checker():
tmp = self.build() tmp = self.build()
active_module_tracer().checker.check_builtin_module( active_module_tracer().checker.check_builtin_module(
tmp, callnode, outputs tmp, callnode, outputs
) )
set_module_tracing()
if self._is_builtin: if self._is_builtin:
self._body = None self._body = None
elif tree_def in self._argdef_graph_map: elif tree_def in self._argdef_graph_map:
@@ -1640,16 +1634,17 @@ class TracedModuleBuilder(NodeMixin):


if isinstance(attr, (List, Dict)): if isinstance(attr, (List, Dict)):
flag = _set_graph_surgery_mode(False) flag = _set_graph_surgery_mode(False)
unset_module_tracing()
has_module, m_container = replace_container_with_module_container(attr)
if m_container:
attr = m_container
if has_module and not m_container:
raise ValueError(
"Can not trace the module that uses the same container to store"
" Module and Non-Module objects."
with _exclude_from_trace():
has_module, m_container = replace_container_with_module_container(
attr
) )
set_module_tracing()
if m_container:
attr = m_container
if has_module and not m_container:
raise ValueError(
"Can not trace the module that uses the same container to store"
" Module and Non-Module objects."
)
_set_graph_surgery_mode(flag) _set_graph_surgery_mode(flag)


if isinstance(attr, Module): if isinstance(attr, Module):


+ 10
- 8
imperative/python/megengine/traced_module/utils.py View File

@@ -5,8 +5,8 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy import copy
import contextlib
import inspect import inspect
from collections.abc import MutableMapping, MutableSequence from collections.abc import MutableMapping, MutableSequence
from inspect import FullArgSpec from inspect import FullArgSpec
@@ -71,7 +71,9 @@ def _convert_kwargs_to_args(
arg_specs_args = arg_specs.args arg_specs_args = arg_specs.args
arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else [] arg_specs_defaults = arg_specs.defaults if arg_specs.defaults else []
arg_specs_kwonlyargs = arg_specs.kwonlyargs arg_specs_kwonlyargs = arg_specs.kwonlyargs
arg_specs_kwonlydefaults = arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict()
arg_specs_kwonlydefaults = (
arg_specs.kwonlydefaults if arg_specs.kwonlydefaults else dict()
)
if is_bounded: if is_bounded:
arg_specs_args = arg_specs.args[1:] arg_specs_args = arg_specs.args[1:]
new_args = [] new_args = []
@@ -104,17 +106,17 @@ def _convert_kwargs_to_args(
new_kwargs[kwarg_name] = kwargs[kwarg_name] new_kwargs[kwarg_name] = kwargs[kwarg_name]
else: else:
if kwarg_name not in arg_specs_kwonlydefaults: if kwarg_name not in arg_specs_kwonlydefaults:
raise TypeError("{} missing required keyword-only argument: {}".format(
func_name, kwarg_name
))
raise TypeError(
"{} missing required keyword-only argument: {}".format(
func_name, kwarg_name
)
)
new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name] new_kwargs[kwarg_name] = arg_specs_kwonlydefaults[kwarg_name]
for k, v in kwargs.items(): for k, v in kwargs.items():
if k not in arg_specs_args and k not in arg_specs_kwonlyargs: if k not in arg_specs_args and k not in arg_specs_kwonlyargs:
if arg_specs.varkw is None: if arg_specs.varkw is None:
raise TypeError( raise TypeError(
"{} got an unexpected keyword argument {}".format(
func_name, k
)
"{} got an unexpected keyword argument {}".format(func_name, k)
) )
new_kwargs[k] = v new_kwargs[k] = v
return tuple(new_args), new_kwargs return tuple(new_args), new_kwargs


Loading…
Cancel
Save