GitOrigin-RevId: ad2cdc1b61
release-1.6
@@ -9,12 +9,13 @@ | |||
import collections | |||
from typing import List | |||
from typing import Callable, List | |||
from ...core._imperative_rt import OpDef | |||
from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||
from ...core.ops.special import Const | |||
from ...module import Module | |||
from ...tensor import Tensor | |||
from .module_tracer import active_module_tracer | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
class Expr:
    """
    ``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``.
    """

    inputs = None  # type: List[Node]
    outputs = None  # type: List[Node]

    def add_input(self, node):
        """
        Append ``node`` to the inputs of this expression.

        :param node: the node to record as an input.
        """
        self.inputs.append(node)

    def add_outputs(self, outputs):
        """
        Create one output node per value in ``outputs`` and bind each value to its node.

        A value that is not a sequence is treated as a single output.

        :param outputs: a single value or a sequence of values produced by this expression.
        """
        self.outputs = []
        # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in ``collections.abc``.
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)

        for i in outputs:
            self.outputs.append(NodeMixin.get_wrapped_type(i)(self))

        for i, node in zip(outputs, self.outputs):
            NodeMixin.wrap_safe(i, node)

    @classmethod
    def get_args_node(cls, arg):
        """
        Create nodes by ``arg``, which may be a container.
        Return the same structure with arg.

        If ``arg`` was not Tensor or Module, it will be stored as const.

        :param arg: tensor, module or const.
        """
        if isinstance(arg, (str, bytes)):
            # Text/bytes are sequences of themselves: recursing into a ``str`` never
            # terminates, so they are always kept as plain constants.
            return arg
        if isinstance(arg, (RawTensor, Module)):
            if not NodeMixin.get(arg, None):
                NodeMixin.wrap_safe(arg, Constant.make(arg))
            return NodeMixin.get(arg)
        if isinstance(arg, collections.abc.Sequence):
            seq_cls = type(arg)
            return seq_cls([cls.get_args_node(a) for a in arg])
        # TODO: assert arg type
        return arg  # as const

    @classmethod
    def get_arg_value(cls, inp_node, node2value):
        """
        Get values from node2value by inp_node, which may be a container.
        Return the same structure with inp_node.

        If ``inp_node`` was not in node2value, it is a const.

        :param inp_node: nodes.
        :param node2value: dict from node to tensor and module.
        """
        try:
            return node2value[inp_node]
        except (KeyError, TypeError):
            # KeyError: not a recorded node. TypeError: unhashable container
            # (e.g. a list of nodes), which can never be a key; inspect it below.
            pass
        if isinstance(inp_node, (str, bytes)):
            # Constant text/bytes: returned as-is instead of being walked element-wise.
            return inp_node
        if isinstance(inp_node, collections.abc.Sequence):
            seq_cls = type(inp_node)
            return seq_cls([cls.get_arg_value(i, node2value) for i in inp_node])
        return inp_node
# expr: None (i.e. fake expression which is used to mark input) | |||
class Input(Expr): | |||
@@ -83,23 +138,22 @@ class GetAttr(Expr): | |||
# expr: outputs = inputs[0].__call__(*inputs[1:]) | |||
class Call(Expr): | |||
def __init__(self, module): | |||
assert isinstance(module, ModuleNode) | |||
class CallMethod(Expr): | |||
def __init__(self, module, method="__call__"): | |||
assert isinstance(module, (TensorNode, ModuleNode)) | |||
self.inputs = [ | |||
module, | |||
] | |||
self.method = method | |||
self.arg_names = [] | |||
self.kwargs = {} # const kwargs | |||
def add_input(self, node): | |||
def add_input(self, node, arg_name=None): | |||
if arg_name == "self": # FIXME: <XP> | |||
return | |||
self.inputs.append(node) | |||
def add_outputs(self, references): | |||
self.outputs = [] | |||
if not isinstance(references, collections.Sequence): | |||
references = (references,) | |||
for i in references: | |||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
if arg_name is not None: | |||
self.arg_names.append(arg_name) | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
@@ -110,15 +164,16 @@ class Call(Expr): | |||
def interpret(self, *inputs): | |||
mod = inputs[0] | |||
args = inputs[1:] | |||
outputs = mod(*args) | |||
outputs = getattr(mod, self.method)(*args, **self.kwargs) | |||
if isinstance(outputs, RawTensor): | |||
outputs = (outputs,) | |||
return outputs | |||
def __repr__(self): | |||
return "{} = Call({})({})".format( | |||
return "{} = CallMethod({}, {})({})".format( | |||
", ".join(str(i) for i in self.outputs), | |||
self.inputs[0], | |||
self.method, | |||
", ".join(str(i) for i in self.inputs[1:]), | |||
) | |||
@@ -132,17 +187,6 @@ class Apply(Expr): | |||
self.opdef = opdef | |||
self.inputs = [] | |||
def add_input(self, node): | |||
self.inputs.append(node) | |||
def add_outputs(self, references): | |||
self.outputs = [] | |||
if not isinstance(references, collections.Sequence): | |||
references = (references,) | |||
for i in references: | |||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
expr = cls(*args, **kwargs) | |||
@@ -179,6 +223,40 @@ class Apply(Expr): | |||
return list(outputs) | |||
class CallFunction(Expr):
    """
    Expression recording a call to a plain function.

    Node inputs are stored together with the argument name they were passed
    under; constant arguments are kept in ``kwargs``.
    """

    def __init__(self, func):
        assert isinstance(func, Callable)
        self.func = func
        self.inputs = []
        self.arg_names = []
        self.kwargs = {}  # const kwargs

    def add_input(self, node, arg_name):
        """
        Record ``node`` as an input passed under the argument name ``arg_name``.

        :param node: the input node.
        :param arg_name: name of the function argument the node is bound to.
        """
        self.arg_names.append(arg_name)
        self.inputs.append(node)

    @classmethod
    def make(cls, *args, **kwargs):
        """
        Build the expression and insert it into the current scope of the active module tracer.

        :return: the newly created expression.
        """
        created = cls(*args, **kwargs)
        scope = active_module_tracer().current_scope()
        scope.insert(created)
        return created

    def interpret(self, *inputs):
        """
        Call ``self.func`` with ``inputs`` bound by name plus the stored constant kwargs.

        :param inputs: values for the recorded inputs, in recording order.
        :return: the result if it is a sequence, otherwise the result wrapped in a 1-tuple.
        """
        named_inputs = dict(zip(self.arg_names, inputs))
        result = self.func(**named_inputs, **self.kwargs)
        if isinstance(result, collections.abc.Sequence):
            return result
        return (result,)

    def __repr__(self):
        out_names = ", ".join(str(i) for i in self.outputs)
        func_path = self.func.__module__ + "." + self.func.__name__
        in_names = ", ".join(str(i) for i in self.inputs)
        return "{} = {}({})".format(out_names, func_path, in_names)
# expr outputs = self.value | |||
class Constant(Expr): | |||
value = None | |||
@@ -6,7 +6,11 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from ... import Tensor | |||
from ... import functional as F | |||
from ...core.tensor.array_method import ArrayMethodMixin | |||
from ...module import Module | |||
_active_module_tracer = None | |||
@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer): | |||
class module_tracer: | |||
# builtin types | |||
_opaque_types = set() | |||
_active_scopes = None | |||
def __init__(self): | |||
def __init__(self, wrap_fn): | |||
self._active_scopes = [] | |||
self.patcher = Patcher(wrap_fn) | |||
@classmethod | |||
def register_as_builtin(cls, mod): | |||
@@ -50,3 +56,105 @@ class module_tracer: | |||
if self._active_scopes: | |||
return self._active_scopes[-1] | |||
return None | |||
class PatchedFn:
    """
    Remember where a function lives so it can be replaced and later restored.

    ``frame_dict`` is either a mapping (looked up by key) or any other object
    (looked up by attribute); ``origin_fn`` is the value found under ``name``
    at construction time.
    """

    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        if isinstance(frame_dict, collections.abc.Mapping):
            self.origin_fn = self.frame_dict[name]
        else:
            self.origin_fn = getattr(frame_dict, name)

    def set_func(self, func):
        """
        Store ``func`` under ``self.name`` in the owning mapping or object.

        :param func: the callable to install.
        """
        owner = self.frame_dict
        if not isinstance(owner, collections.abc.Mapping):
            setattr(owner, self.name, func)
            return
        owner[self.name] = func
class Patcher:
    """
    Context manager that replaces functions with ``wrap_fn(original)``.

    On construction it patches the public functions of the builtin functional
    modules and of the builtin tensor classes. ``auto_patch`` additionally
    patches names in an arbitrary namespace that refer to an already-patched
    original. Leaving the ``with`` block restores every original.
    """

    # ids of all originals patched so far; class-level, hence shared by every instance
    patched_fn_ids = set()
    _builtin_functions = []
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        self.patched_fn = []
        self.visited_frames_ids = set()
        self.wrap_fn = wrap_fn
        for builtin_module in self._builtin_modules:
            self.patch_module(builtin_module)
        for builtin_cls in self._builtin_methods:
            self.patch_cls(builtin_cls)
        for owner, fn_name in self._builtin_functions:
            if id(owner) in self.visited_frames_ids:
                continue
            self.patch_function(owner, fn_name, self.wrap_fn)

    def patch_function(self, frame_dict, fn, wrap_fn):
        """
        Replace the entry ``fn`` of ``frame_dict`` with ``wrap_fn(original)``.

        :param frame_dict: mapping or object holding the function.
        :param fn: name of the function inside ``frame_dict``.
        :param wrap_fn: callable producing the replacement from the original.
        """
        record = PatchedFn(frame_dict, fn)
        original = record.origin_fn
        self.patched_fn_ids.add(id(original))
        record.set_func(wrap_fn(original))
        self.patched_fn.append(record)

    def patch_method(self, cls, name, wrap_fn):
        """Patch the attribute ``name`` of ``cls``; delegates to ``patch_function``."""
        self.patch_function(cls, name, wrap_fn)

    def patch_cls(self, cls):
        """Patch every function in ``cls.__dict__`` whose name has no leading underscore."""
        import inspect

        if id(cls) in self.visited_frames_ids:
            return
        for attr_name, attr in cls.__dict__.items():
            if not attr_name.startswith("_") and inspect.isfunction(attr):
                self.patch_function(cls, attr_name, self.wrap_fn)
        self.visited_frames_ids.add(id(cls))

    def patch_module(self, module):
        """Patch every function in ``module.__dict__`` whose name has no leading underscore."""
        import inspect

        if id(module.__dict__) in self.visited_frames_ids:
            return
        for attr_name, attr in module.__dict__.items():
            if not attr_name.startswith("_") and inspect.isfunction(attr):
                self.patch_function(module.__dict__, attr_name, self.wrap_fn)
        self.visited_frames_ids.add(id(module.__dict__))

    def auto_patch(self, frame_dict):
        """
        Patch the entries of ``frame_dict`` that refer to an already-patched original.

        :param frame_dict: a namespace mapping, visited at most once until ``__exit__``.
        """
        if id(frame_dict) in self.visited_frames_ids:
            return
        for key, value in frame_dict.items():
            if id(value) in self.patched_fn_ids:
                self.patch_function(frame_dict, key, self.wrap_fn)
        self.visited_frames_ids.add(id(frame_dict))

    def __enter__(self):
        return self

    def __exit__(self, type, vlaue, trace):
        # undo in reverse order of patching, then forget which namespaces were visited
        while self.patched_fn:
            record = self.patched_fn.pop()
            record.set_func(record.origin_fn)
        self.visited_frames_ids.clear()
@@ -34,6 +34,10 @@ class Node: | |||
Node.__total_id += 1 | |||
self._name = name | |||
def __setstate__(self, d):
    """
    Restore the instance ``__dict__`` from ``d`` and move ``Node.__total_id``
    to one past the larger of its current value and this node's ``_id``.
    """
    self.__dict__ = d
    highest_seen = max(Node.__total_id, self._id)
    Node.__total_id = highest_seen + 1
def __repr__(self): | |||
if self._name is None: | |||
return "%{}".format(self._id) | |||
@@ -8,14 +8,25 @@ | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import copy | |||
import functools | |||
from typing import List, Type | |||
from ... import module as M | |||
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing | |||
from ...core._imperative_rt.core2 import ( | |||
is_tracing_module, | |||
set_module_tracing, | |||
unset_module_tracing, | |||
) | |||
from ...core.tensor.array_method import ArrayMethodMixin | |||
from ...module import Module | |||
from ...tensor import Tensor | |||
from .expr import Apply, Call, Constant, Expr, GetAttr, Input | |||
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer | |||
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input | |||
from .module_tracer import ( | |||
Patcher, | |||
active_module_tracer, | |||
module_tracer, | |||
set_active_module_tracer, | |||
) | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
@@ -54,7 +65,9 @@ class InternalGraph: | |||
for n, v in zip(self._inputs, inputs): | |||
node2value[n] = v | |||
for expr in self._exprs: | |||
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||
values = expr.interpret( | |||
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||
) | |||
for n, v in zip(expr.outputs, values): | |||
node2value[n] = v | |||
return list(node2value[i] for i in self._outputs) | |||
@@ -67,6 +80,41 @@ class InternalGraph: | |||
) | |||
def _wrapped_function(orig_func):
    """
    Wrap ``orig_func`` so that calls made while module tracing is active are
    recorded as a ``CallMethod`` or ``CallFunction`` expression.

    When tracing is not active the wrapper simply forwards to ``orig_func``.

    :param orig_func: the function or method to wrap.
    :return: the wrapping function, carrying ``orig_func``'s metadata.
    """

    @functools.wraps(orig_func)
    def wrapped_fn(*inputs, **kwargs):
        if not is_tracing_module():
            return orig_func(*inputs, **kwargs)

        # Tracing is switched off while the call is recorded and executed; the
        # ``finally`` guarantees it is switched back on even if recording or
        # ``orig_func`` raises, so one failing call cannot silently disable
        # tracing for the rest of the trace.
        unset_module_tracing()
        try:
            const_kwargs = {}
            # NOTE(review): ``co_varnames`` also lists locals and gives no special
            # treatment to ``*args``; positional names may be off for such
            # signatures -- confirm against the patched functions.
            arg_names = orig_func.__code__.co_varnames
            if "." in orig_func.__qualname__:
                # FIXME: a robust way to distinguish method and function. <XP>
                self = inputs[0]
                call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__)
            else:
                call_node = CallFunction.make(orig_func)

            def add_input(inp, varname=None):
                node = Expr.get_args_node(inp)
                if node is not None:
                    call_node.add_input(node, varname)
                else:
                    const_kwargs[varname] = inp

            for ind, inp in enumerate(inputs):
                add_input(inp, arg_names[ind])
            for k, v in kwargs.items():
                add_input(v, k)
            call_node.kwargs = const_kwargs

            outputs = orig_func(*inputs, **kwargs)
            call_node.add_outputs(outputs)
            return outputs
        finally:
            set_module_tracing()

    return wrapped_fn
class TracedModuleBuilder(NodeMixin): | |||
_mod = None # type: Module | |||
@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin): | |||
mark_constant(i) | |||
for k, v in kwargs.items(): | |||
mark_constant(v) | |||
callnode = Call.make(NodeMixin.get(self)) | |||
callnode = CallMethod.make(NodeMixin.get(self)) | |||
def add_input(x): | |||
callnode.add_input(NodeMixin.get(x)) | |||
@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin): | |||
) | |||
# prepare args and kwargs for inner graph | |||
def wrap(x): | |||
wrapped = copy.copy(x) # FIXME | |||
# wrapped = copy.copy(x) # FIXME | |||
wrapped = x # FIXME: <XP> | |||
NodeMixin.wrap( | |||
wrapped, | |||
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||
@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin): | |||
args.append(wrap(i)) | |||
for k, v in kwargs.items(): | |||
kwargs[k] = wrap(v) | |||
active_module_tracer().patcher.auto_patch( | |||
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||
) | |||
outputs = type(self._mod).forward(self, *args, **kwargs) | |||
for i in ( | |||
@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin): | |||
# rebind output to outer graph | |||
callnode.add_outputs(outputs) | |||
for i, node in zip( | |||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), | |||
callnode.outputs, | |||
): | |||
NodeMixin.wrap_safe(i, node) | |||
return outputs | |||
def __getattr__(self, name): | |||
@@ -229,6 +275,55 @@ class TracedModule(Module): | |||
rst = rst[0] | |||
return rst | |||
@property
def all_exprs(self):
    """
    Visit all ``Expr``s in the graph recursively.

    Calls into sub-modules that are themselves ``TracedModule`` are replaced
    by the exprs of the sub-module's graph. Note that the exprs of those
    graphs are modified in place (their ``inputs``/``outputs`` entries are
    rebound to the caller's nodes) while they are collected.

    :return: List[Expr]
    """

    # exprs that produced the graph inputs, leaving out any entry that is ``self``
    in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self]

    def _flatten_submodule(module, call=None):
        # A module that is not a TracedModule cannot be expanded: the calling
        # expr is kept, with its first input replaced by the module object.
        if not isinstance(module, TracedModule):
            call.inputs[0] = module
            return (call,)

        exprs = []

        graph = module.m_node.graph
        for expr in graph._exprs:

            # replace inputs for submodule's expr
            # NOTE(review): ``call.inputs`` is indexed with the position inside
            # ``expr.inputs``, not the position inside ``graph._inputs`` -- confirm
            # this is intended.
            for idx, inp in enumerate(expr.inputs):
                if call and inp in graph._inputs:
                    expr.inputs[idx] = call.inputs[idx]
            # replace outputs for submodule's expr
            for idx, outp in enumerate(expr.outputs):
                if call and outp in graph._outputs:
                    expr.outputs[idx] = call.outputs[idx]

            if isinstance(expr, GetAttr):
                # replace GetAttr with Constant
                # (a GetAttr whose first output is not a TensorNode is not emitted at all)
                if isinstance(expr.outputs[0], TensorNode):
                    const = Constant(getattr(module, expr.name))
                    const.outputs = expr.outputs
                    exprs.append(const)
            elif isinstance(expr, CallMethod):
                obj_node = expr.inputs[0]
                if isinstance(obj_node, ModuleNode):
                    # obtain the called object by interpreting the expr that
                    # produced the module node, then expand it recursively
                    (obj,) = expr.inputs[0].expr.interpret(module)
                    exprs.extend(_flatten_submodule(obj, expr))
                else:
                    exprs.append(expr)
            else:
                exprs.append(expr)

        return exprs

    return in_nodes + _flatten_submodule(self)
def __getstate__(self): | |||
d = self.__dict__ | |||
for k in Module.__dict__: | |||
@@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||
assert active_module_tracer() is None | |||
try: | |||
set_module_tracing() | |||
set_active_module_tracer(module_tracer()) | |||
global_scope = InternalGraph() | |||
active_module_tracer().push_scope(global_scope) | |||
set_active_module_tracer(module_tracer(_wrapped_function)) | |||
with active_module_tracer().patcher: | |||
global_scope = InternalGraph() | |||
active_module_tracer().push_scope(global_scope) | |||
builder = TracedModuleBuilder(mod) | |||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
builder = TracedModuleBuilder(mod) | |||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
for _, i in enumerate(inputs): | |||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
for k, v in kwargs.items(): | |||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
for _, i in enumerate(inputs): | |||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
for k, v in kwargs.items(): | |||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
builder(*inputs, **kwargs) | |||
active_module_tracer().pop_scope() | |||
builder(*inputs, **kwargs) | |||
active_module_tracer().pop_scope() | |||
return builder.build() | |||
return builder.build() | |||
finally: | |||
set_active_module_tracer(None) | |||
unset_module_tracing() |