@@ -130,3 +130,4 @@ import megengine.optimizer
import megengine.quantization
import megengine.random
import megengine.utils
import megengine.experimental
@@ -6,4 +6,5 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import traced_module
from .weight_scaler import get_scaled_model
@@ -5,3 +5,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core._imperative_rt.core2 import set_cpp_apply_module_trace
from .traced_module import (
    TracedModule,
    _register_all_builtin_module,
    cpp_apply_module_trace,
    register_as_builtin,
    trace_module,
)

_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)
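
The exported ``register_as_builtin`` helper lets user code extend the opaque-module registry that ``_register_all_builtin_module`` populates above. A minimal usage sketch (the public import path ``megengine.experimental.traced_module`` is assumed from the package layout in this diff):

# Hypothetical sketch: mark a custom Module as builtin so the tracer records it
# as a single Call expr instead of tracing into its forward method.
import megengine.module as M
from megengine.experimental.traced_module import register_as_builtin


class FusedBlock(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3)

    def forward(self, x):
        return self.conv(x)


register_as_builtin(FusedBlock)
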
@@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
from typing import List

from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode


class Expr:
    """
    ``Expr`` represents the operations (i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``.
    """

    inputs = None  # type: List[Node]
    outputs = None  # type: List[Node]


# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
    name = None

    def __init__(self, name=None, type=None):
        self.inputs = []
        node_cls = type if type else Node
        self.outputs = [
            node_cls(self, name=name),
        ]
        self.name = name

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().add_input(expr.outputs[0])
        return expr.outputs[0]

    def __repr__(self):
        return "{} = Input({})".format(self.outputs[0], self.name)


# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    name = None

    def __init__(self, module, name, type=None):
        assert isinstance(module, ModuleNode)
        self.inputs = [
            module,
        ]
        self.name = name
        node_cls = type if type else Node
        self.outputs = [
            node_cls(self),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        expr.outputs[0]._name = expr.name
        return expr.outputs[0]

    def interpret(self, *inputs):
        return (getattr(inputs[0], self.name),)

    def __repr__(self):
        return '{} = GetAttr({}, "{}")'.format(
            self.outputs[0], self.inputs[0], self.name
        )


# expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr):
    def __init__(self, module):
        assert isinstance(module, ModuleNode)
        self.inputs = [
            module,
        ]

    def add_input(self, node):
        self.inputs.append(node)

    def add_outputs(self, references):
        self.outputs = []
        if not isinstance(references, collections.abc.Sequence):
            references = (references,)
        for i in references:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        mod = inputs[0]
        args = inputs[1:]
        outputs = mod(*args)
        if isinstance(outputs, RawTensor):
            outputs = (outputs,)
        return outputs

    def __repr__(self):
        return "{} = Call({})({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.inputs[0],
            ", ".join(str(i) for i in self.inputs[1:]),
        )


# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    opdef = None

    def __init__(self, opdef):
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    def add_input(self, node):
        self.inputs.append(node)

    def add_outputs(self, references):
        self.outputs = []
        if not isinstance(references, collections.abc.Sequence):
            references = (references,)
        for i in references:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "{} = {}({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        for i in inputs:
            node = NodeMixin.get(i, None)
            if node is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        apply_node = cls.make(opdef)
        for i in inputs:
            apply_node.add_input(NodeMixin.get(i))
        unset_module_tracing()
        outputs = apply(opdef, *inputs)
        set_module_tracing()
        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)


# expr: outputs = self.value
class Constant(Expr):
    value = None
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c):
        # TODO: type check, since not all types should be captured as constant
        self.value = c
        self.inputs = []
        node_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [
            node_cls(self),
        ]

    @classmethod
    def make(cls, *args, **kwargs):
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr.outputs[0]

    def interpret(self, *inputs):
        if isinstance(self.value, RawTensor):
            return Const(self.value.numpy())()
        return (self.value,)

    def __repr__(self):
        return "{} = Constant({})".format(self.outputs[0], self.value)

    def __getstate__(self):
        state = self.__dict__.copy()
        if isinstance(self.value, RawTensor):
            state["value"] = Tensor(self.value)
        return state
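
Each ``Expr`` both records an operation during tracing (via ``make``) and replays it later (via ``interpret``). The replay side can be exercised on its own; a small sketch, with the import paths assumed from the relative imports in this file:

# Sketch only: constructing a GetAttr expr directly (bypassing .make(), so no
# active tracer is required) and replaying it on a concrete module.
import megengine.module as M
from megengine.experimental.traced_module.expr import GetAttr
from megengine.experimental.traced_module.node import ModuleNode, TensorNode

conv = M.Conv2d(3, 8, 3)
self_node = ModuleNode(expr=None, name="self")   # stand-in input node
get_weight = GetAttr(self_node, "weight", type=TensorNode)

(weight,) = get_weight.interpret(conv)           # equivalent to conv.weight
print(get_weight)                                # e.g. %1(Tensor) = GetAttr(%self(Module), "weight")
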
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...module import Module

_active_module_tracer = None


def active_module_tracer():
    return _active_module_tracer


def set_active_module_tracer(tracer):
    global _active_module_tracer
    _active_module_tracer = tracer


class module_tracer:
    _opaque_types = set()

    _active_scopes = None

    def __init__(self):
        self._active_scopes = []

    @classmethod
    def register_as_builtin(cls, mod):
        assert issubclass(mod, Module)
        cls._opaque_types.add(mod)
        return mod

    @classmethod
    def is_builtin(cls, mod):
        return type(mod) in cls._opaque_types

    def push_scope(self, scope):
        self._active_scopes.append(scope)

    def pop_scope(self):
        self._active_scopes.pop()

    def current_scope(self):
        if self._active_scopes:
            return self._active_scopes[-1]
        return None
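
``module_tracer`` is little more than a scope stack plus the registry of opaque (builtin) types; a short sketch of that behaviour, with the import path assumed:

# Sketch: scope stack and builtin registry. Real tracing pushes InternalGraph
# instances as scopes; plain strings are used here only for illustration.
import megengine.module as M
from megengine.experimental.traced_module.module_tracer import module_tracer

tracer = module_tracer()
tracer.push_scope("outer")
tracer.push_scope("inner")
assert tracer.current_scope() == "inner"
tracer.pop_scope()
assert tracer.current_scope() == "outer"

# is_builtin checks the exact type, so subclasses are not automatically opaque.
module_tracer.register_as_builtin(M.Conv2d)
assert module_tracer.is_builtin(M.Conv2d(1, 1, 1))
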
@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Any, Dict, Tuple, Type

import numpy

from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor


class Node:
    """
    ``Node`` represents a variable (Tensor, Module, or other Python object) used in a Module's
    forward method. Nodes are the inputs/outputs of ``Expr`` (the operations on variables).

    :param expr: the Expr which produces the node
    :param name: the name of the node
    """

    expr = None
    __total_id = 0
    _id = None
    _name = None

    def __init__(self, expr: "Expr", name: str = None):
        self.expr = expr
        self._id = Node.__total_id
        Node.__total_id += 1
        self._name = name

    def __repr__(self):
        if self._name is None:
            return "%{}".format(self._id)
        else:
            return "%{}".format(self._name)


class ModuleNode(Node):
    """
    ``ModuleNode`` represents a Module object.

    Attributes:
        module_type: type of the Module corresponding to the ModuleNode
        graph: the InternalGraph which is interpreted when the Module's forward method is called
        attr_type_map: records the types of the Module's attributes
    """

    module_type = Module  # type: Type[Module]
    graph = None
    attr_type_map = None  # type: Dict[str, Type[Any]]

    def __repr__(self):
        if self._name is None:
            return "%{}({})".format(self._id, self.module_type.__name__)
        else:
            return "%{}({})".format(self._name, self.module_type.__name__)


class TensorNode(Node):
    """
    ``TensorNode`` represents a Tensor object.
    """

    shape = None  # type: Tuple[int]
    dtype = None  # type: numpy.dtype

    def __repr__(self):
        if self._name is None:
            return "%{}(Tensor)".format(self._id)
        else:
            return "%{}(Tensor)".format(self._name)


class NodeMixin:
    __node = None

    @classmethod
    def wrap(cls, value, node):
        if isinstance(value, (NodeMixin, RawTensor)):
            if isinstance(node, Node):
                if isinstance(value, RawTensor):
                    node.dtype = value.dtype
                    node.shape = (
                        value._tuple_shape if isinstance(value, Tensor) else value.shape
                    )
                setattr(value, "_NodeMixin__node", node)
            else:
                assert callable(node)
                n = node()
                if isinstance(value, RawTensor):
                    n.dtype = value.dtype
                    n.shape = (
                        value._tuple_shape if isinstance(value, Tensor) else value.shape
                    )
                setattr(value, "_NodeMixin__node", n)

    @classmethod
    def wrap_safe(cls, value, node):
        assert isinstance(value, (NodeMixin, RawTensor))
        if isinstance(value, RawTensor):
            node.dtype = value.dtype
            node.shape = (
                value._tuple_shape if isinstance(value, Tensor) else value.shape
            )
        setattr(value, "_NodeMixin__node", node)

    @classmethod
    def get(cls, value, *default):
        return getattr(value, "_NodeMixin__node", *default)

    @classmethod
    def get_wrapped_type(cls, value):
        if isinstance(value, RawTensor):
            return TensorNode
        if isinstance(value, (Module, NodeMixin)):
            return ModuleNode
        return Node
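
``NodeMixin`` attaches a ``Node`` to the live Python value it stands for by writing a private attribute onto that value. A short sketch (import path assumed; it relies on megengine Tensors accepting extra attributes, which is exactly what the tracing code above does):

# Sketch: attach a TensorNode to a concrete Tensor and read it back.
import megengine as mge
from megengine.experimental.traced_module.node import NodeMixin, TensorNode

x = mge.Tensor([1.0, 2.0, 3.0])
node = TensorNode(expr=None, name="x")

NodeMixin.wrap_safe(x, node)   # copies dtype/shape onto the node, then attaches it to x
assert NodeMixin.get(x) is node
assert node.shape == (3,)
assert NodeMixin.get_wrapped_type(x) is TensorNode
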
@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
from typing import List, Type

from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
from ...module import Module
from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode


class InternalGraph:
    """
    ``InternalGraph`` is a graph consisting of ``Node`` and ``Expr`` objects. It represents the
    execution procedure of a Module's forward method.

    Attributes:
        _exprs: list of Exprs in order of execution
        _inputs: input Nodes of the InternalGraph
        _outputs: output Nodes of the InternalGraph
    """

    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]

    def __init__(self):
        self._exprs = []
        self._inputs = []
        self._outputs = []

    def insert(self, expr):
        self._exprs.append(expr)

    def add_input(self, i):
        self._inputs.append(i)

    def add_output(self, o):
        self._outputs.append(o)

    def interpret(self, *inputs):
        # TODO: support kwargs ?
        # TODO: skip expressions which are independent and have no side effect
        node2value = {}
        for n, v in zip(self._inputs, inputs):
            node2value[n] = v
        for expr in self._exprs:
            values = expr.interpret(*list(node2value[i] for i in expr.inputs))
            for n, v in zip(expr.outputs, values):
                node2value[n] = v
        return list(node2value[i] for i in self._outputs)

    def __repr__(self):
        return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
            ", ".join(str(i) for i in self._inputs),
            "\n\t".join(str(i) for i in self._exprs),
            ", ".join(str(i) for i in self._outputs),
        )


class TracedModuleBuilder(NodeMixin):
    _mod = None  # type: Module
    _body = None  # type: InternalGraph
    _is_builtin = None  # type: bool

    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "_is_traced",
        "build",
    ]

    def __init__(self, mod):
        super(TracedModuleBuilder, self).__init__()
        self._mod = mod
        self._body = InternalGraph()
        self._is_traced = False
        self._is_builtin = module_tracer.is_builtin(mod)

    def build(self):
        if self._is_builtin:
            node = NodeMixin.get(self)
            node.module_type = type(self._mod)
            return self._mod
        else:
            node = NodeMixin.get(self)
            node.graph = self._body
            node.attr_type_map = {}
            traced_module = TracedModule(node)
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        v = v.build()
                    setattr(traced_module, k, v)
                    traced_module.m_node.attr_type_map[k] = type(v)
            return traced_module

    def __call__(self, *inputs, **kwargs):
        assert isinstance(self._mod, Module)

        # prepare args and kwargs for inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                NodeMixin.wrap(x, lambda: Constant.make(x))

        for i in inputs:
            mark_constant(i)
        for k, v in kwargs.items():
            mark_constant(v)
        callnode = Call.make(NodeMixin.get(self))

        def add_input(x):
            callnode.add_input(NodeMixin.get(x))

        for i in inputs:
            add_input(i)
        for k, v in kwargs.items():
            add_input(v)

        if self._is_builtin or self._is_traced:
            unset_module_tracing()
            outputs = self._mod(*inputs, **kwargs)
            set_module_tracing()
            if self._is_builtin:
                self._body = None
        else:
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            NodeMixin.wrap_safe(
                self, Input.make("self", NodeMixin.get_wrapped_type(self))
            )

            # prepare args and kwargs for inner graph
            def wrap(x):
                wrapped = copy.copy(x)  # FIXME
                NodeMixin.wrap(
                    wrapped,
                    lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
                )
                return wrapped

            args = []
            for i in inputs:
                args.append(wrap(i))
            for k, v in kwargs.items():
                kwargs[k] = wrap(v)
            outputs = type(self._mod).forward(self, *args, **kwargs)
            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))
            NodeMixin.wrap_safe(self, orig_self)
            self._is_traced = True
            active_module_tracer().pop_scope()

        # rebind output to outer graph
        callnode.add_outputs(outputs)
        for i, node in zip(
            outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
            callnode.outputs,
        ):
            NodeMixin.wrap_safe(i, node)
        return outputs

    def __getattr__(self, name):
        if name not in self._mod.__dict__:
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
                setattr(self, name, attr)
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        if name in TracedModuleBuilder.__builder_attributes__:
            return super().__getattribute__(name)
        else:
            wrapped = super().__getattribute__(name)
            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
                assert not self._is_builtin
                NodeMixin.wrap(
                    wrapped,
                    lambda: GetAttr.make(
                        NodeMixin.get(self),
                        name,
                        type=NodeMixin.get_wrapped_type(wrapped),
                    ),
                )
            return wrapped


class TracedModule(Module):
    """
    ``TracedModule`` is the Module created by tracing a normal Module. It owns a ModuleNode
    (``m_node``) and interprets ``m_node.graph`` when it is called.
    """

    m_node = None  # type: ModuleNode

    def __init__(self, node):
        super(TracedModule, self).__init__()
        self.m_node = node

    def forward(self, *inputs):
        rst = self.m_node.graph.interpret(self, *inputs)
        if len(rst) == 1:
            rst = rst[0]
        return rst

    def __getstate__(self):
        d = self.__dict__
        for k in Module.__dict__:
            d.pop(k, None)
        return d


def cpp_apply_module_trace(opdef, *args):
    return Apply.apply_module_trace_hook(opdef, *args)


def register_as_builtin(mod_cls: Type[Module]) -> None:
    """
    Registers class ``mod_cls`` (a subclass of megengine.module.Module) as a builtin module.

    :param mod_cls: the Module class which will be treated as a builtin module during tracing
    """
    module_tracer.register_as_builtin(mod_cls)


def _register_all_builtin_module():
    from inspect import getmembers, isclass

    for sub_mod in [M, M.qat, M.quantized]:
        for m in getmembers(sub_mod):
            if (
                isclass(m[1])
                and issubclass(m[1], M.Module)
                and m[1] is not M.Sequential
            ):
                module_tracer.register_as_builtin(m[1])


def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
    """
    Traces module ``mod`` and returns the corresponding TracedModule.

    :param mod: the module which will be converted to a TracedModule
    :param inputs: the positional arguments passed to the forward method of ``mod``
    :param kwargs: the keyword arguments passed to the forward method of ``mod``
    """
    assert active_module_tracer() is None
    try:
        set_module_tracing()
        set_active_module_tracer(module_tracer())
        global_scope = InternalGraph()
        active_module_tracer().push_scope(global_scope)
        builder = TracedModuleBuilder(mod)
        NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode))
        for idx, i in enumerate(inputs):
            NodeMixin.wrap_safe(i, Input.make("arg_{}".format(idx)))
        for k, v in kwargs.items():
            NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k)))
        builder(*inputs, **kwargs)
        active_module_tracer().pop_scope()
        return builder.build()
    finally:
        set_active_module_tracer(None)
        unset_module_tracing()
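
Putting the pieces together, tracing a small network and replaying the recorded graph might look like the following minimal sketch (the public import path is assumed from the package ``__init__`` in this diff):

# Sketch: trace a toy network, inspect the recorded InternalGraph, re-run it.
import numpy as np

import megengine.functional as F
import megengine.module as M
from megengine import Tensor
from megengine.experimental.traced_module import trace_module


class SimpleNet(M.Module):
    def __init__(self):
        super().__init__()
        self.conv = M.Conv2d(3, 8, 3)

    def forward(self, x):
        return F.relu(self.conv(x))


net = SimpleNet()
x = Tensor(np.random.random((1, 3, 16, 16)).astype("float32"))

traced = trace_module(net, x)   # records forward into an InternalGraph
print(traced.m_node.graph)      # InternalGraph (...) { ... return ... }
y = traced(x)                   # re-executes the recorded graph via interpret()
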