GitOrigin-RevId: ad2cdc1b61
release-1.6
@@ -9,12 +9,13 @@ | |||||
import collections | import collections | ||||
from typing import List | |||||
from typing import Callable, List | |||||
from ...core._imperative_rt import OpDef | from ...core._imperative_rt import OpDef | ||||
from ...core._imperative_rt.core2 import Tensor as RawTensor | from ...core._imperative_rt.core2 import Tensor as RawTensor | ||||
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | ||||
from ...core.ops.special import Const | from ...core.ops.special import Const | ||||
from ...module import Module | |||||
from ...tensor import Tensor | from ...tensor import Tensor | ||||
from .module_tracer import active_module_tracer | from .module_tracer import active_module_tracer | ||||
from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
@@ -22,12 +23,66 @@ from .node import ModuleNode, Node, NodeMixin, TensorNode | |||||
class Expr: | class Expr: | ||||
""" | """ | ||||
``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``. | |||||
``Expr`` represents the operations(i.e. CallMethod, CallFunction, Apply, GetAttr, Input, Constant) on ``Node``. | |||||
""" | """ | ||||
inputs = None # type: List[Node] | inputs = None # type: List[Node] | ||||
outputs = None # type: List[Node] | outputs = None # type: List[Node] | ||||
def add_input(self, node): | |||||
self.inputs.append(node) | |||||
def add_outputs(self, outputs):
    """
    Create one output node per traced result and bind every result to its node.

    :param outputs: a single result object or a sequence of result objects;
        a non-sequence value is treated as a one-element tuple.
    """
    self.outputs = []
    # ``collections.Sequence`` was a deprecated alias that Python 3.10 removed;
    # the ABC has to be taken from ``collections.abc``.
    if not isinstance(outputs, collections.abc.Sequence):
        outputs = (outputs,)
    for value in outputs:
        # The node type is chosen from the result object, the node is owned by this expr.
        self.outputs.append(NodeMixin.get_wrapped_type(value)(self))
    for value, node in zip(outputs, self.outputs):
        NodeMixin.wrap_safe(value, node)
@classmethod
def get_args_node(cls, arg):
    """
    Create nodes by ``arg``, which may be a container.
    Return the same structure with arg.

    If ``arg`` was not Tensor or Module, it will be stored as const.

    :param arg: tensor, module or const.
    :return: the node bound to ``arg``, a container of the same type holding
        the converted items, or ``arg`` itself when it is a constant.
    """
    # ``str`` is a Sequence whose items are again one-character ``str`` objects,
    # so recursing into it never terminates. ``bytes`` would only be rebuilt
    # unchanged. Both are plain constants.
    if isinstance(arg, (str, bytes)):
        return arg
    if isinstance(arg, (RawTensor, Module)):
        # Objects that were not produced by a traced expr become Constant nodes.
        if not NodeMixin.get(arg, None):
            NodeMixin.wrap_safe(arg, Constant.make(arg))
        return NodeMixin.get(arg)
    if isinstance(arg, collections.abc.Sequence):
        seq_cls = type(arg)
        return seq_cls([cls.get_args_node(a) for a in arg])
    # TODO: assert arg type
    return arg  # as const
@classmethod
def get_arg_value(cls, inp_node, node2value):
    """
    Get values from node2value by inp_node, which may be a container.
    Return the same structure with inp_node.

    If ``inp_node`` was not in node2value, it is a const.

    :param inp_node: nodes.
    :param node2value: dict from node to tensor and module.
    :return: the mapped value, a container of the same type holding the
        resolved items, or ``inp_node`` itself when it is a constant.
    """
    try:
        if inp_node in node2value:
            return node2value[inp_node]
    except TypeError:
        # Unhashable containers (e.g. ``list``) can never be keys of the
        # mapping; their items are resolved one by one below.
        pass
    # ``str`` is a Sequence of one-character ``str`` objects: recursing into
    # it would never terminate, so text and bytes are returned as constants.
    if isinstance(inp_node, (str, bytes)):
        return inp_node
    if isinstance(inp_node, collections.abc.Sequence):
        seq_cls = type(inp_node)
        return seq_cls([cls.get_arg_value(i, node2value) for i in inp_node])
    return inp_node
# expr: None (i.e. fake expression which is used to mark input) | # expr: None (i.e. fake expression which is used to mark input) | ||||
class Input(Expr): | class Input(Expr): | ||||
@@ -83,23 +138,22 @@ class GetAttr(Expr): | |||||
# expr: outputs = inputs[0].__call__(*inputs[1:]) | # expr: outputs = inputs[0].__call__(*inputs[1:]) | ||||
class Call(Expr): | |||||
def __init__(self, module): | |||||
assert isinstance(module, ModuleNode) | |||||
class CallMethod(Expr): | |||||
def __init__(self, module, method="__call__"): | |||||
assert isinstance(module, (TensorNode, ModuleNode)) | |||||
self.inputs = [ | self.inputs = [ | ||||
module, | module, | ||||
] | ] | ||||
self.method = method | |||||
self.arg_names = [] | |||||
self.kwargs = {} # const kwargs | |||||
def add_input(self, node): | |||||
def add_input(self, node, arg_name=None): | |||||
if arg_name == "self": # FIXME: <XP> | |||||
return | |||||
self.inputs.append(node) | self.inputs.append(node) | ||||
def add_outputs(self, references): | |||||
self.outputs = [] | |||||
if not isinstance(references, collections.Sequence): | |||||
references = (references,) | |||||
for i in references: | |||||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
if arg_name is not None: | |||||
self.arg_names.append(arg_name) | |||||
@classmethod | @classmethod | ||||
def make(cls, *args, **kwargs): | def make(cls, *args, **kwargs): | ||||
@@ -110,15 +164,16 @@ class Call(Expr): | |||||
def interpret(self, *inputs): | def interpret(self, *inputs): | ||||
mod = inputs[0] | mod = inputs[0] | ||||
args = inputs[1:] | args = inputs[1:] | ||||
outputs = mod(*args) | |||||
outputs = getattr(mod, self.method)(*args, **self.kwargs) | |||||
if isinstance(outputs, RawTensor): | if isinstance(outputs, RawTensor): | ||||
outputs = (outputs,) | outputs = (outputs,) | ||||
return outputs | return outputs | ||||
def __repr__(self): | def __repr__(self): | ||||
return "{} = Call({})({})".format( | |||||
return "{} = CallMethod({}, {})({})".format( | |||||
", ".join(str(i) for i in self.outputs), | ", ".join(str(i) for i in self.outputs), | ||||
self.inputs[0], | self.inputs[0], | ||||
self.method, | |||||
", ".join(str(i) for i in self.inputs[1:]), | ", ".join(str(i) for i in self.inputs[1:]), | ||||
) | ) | ||||
@@ -132,17 +187,6 @@ class Apply(Expr): | |||||
self.opdef = opdef | self.opdef = opdef | ||||
self.inputs = [] | self.inputs = [] | ||||
def add_input(self, node): | |||||
self.inputs.append(node) | |||||
def add_outputs(self, references): | |||||
self.outputs = [] | |||||
if not isinstance(references, collections.Sequence): | |||||
references = (references,) | |||||
for i in references: | |||||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||||
@classmethod | @classmethod | ||||
def make(cls, *args, **kwargs): | def make(cls, *args, **kwargs): | ||||
expr = cls(*args, **kwargs) | expr = cls(*args, **kwargs) | ||||
@@ -179,6 +223,40 @@ class Apply(Expr): | |||||
return list(outputs) | return list(outputs) | ||||
# expr: outputs = self.func(**dict(zip(arg_names, inputs)), **self.kwargs)
class CallFunction(Expr):
    """
    Expression that records a call to a plain function.

    Every input node is stored together with the argument name it was passed
    under, so that ``interpret`` can replay the call with keyword arguments.
    """

    def __init__(self, func):
        assert isinstance(func, Callable)
        self.func = func
        self.inputs = []
        self.arg_names = []  # arg_names[i] is the parameter name of inputs[i]
        self.kwargs = {}  # const kwargs

    def add_input(self, node, arg_name):
        """Append ``node`` as an input that is passed to ``func`` as ``arg_name``."""
        self.inputs.append(node)
        self.arg_names.append(arg_name)

    @classmethod
    def make(cls, *args, **kwargs):
        """Build the expr and insert it into the current scope of the active tracer."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        """
        Call ``func`` with ``inputs`` bound to the recorded argument names.

        :param inputs: runtime values, in the same order as ``self.inputs``.
        :return: the result of ``func``; a non-sequence result is wrapped in a
            one-element tuple.
        """
        inp_dict = dict(zip(self.arg_names, inputs))
        outputs = self.func(**inp_dict, **self.kwargs)
        if not isinstance(outputs, collections.abc.Sequence):
            outputs = (outputs,)
        return outputs

    def __repr__(self):
        return "{} = {}({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.func.__module__ + "." + self.func.__name__,
            ", ".join(str(i) for i in self.inputs),
        )
# expr outputs = self.value | # expr outputs = self.value | ||||
class Constant(Expr): | class Constant(Expr): | ||||
value = None | value = None | ||||
@@ -6,7 +6,11 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | |||||
from ... import Tensor | |||||
from ... import functional as F | |||||
from ...core.tensor.array_method import ArrayMethodMixin | |||||
from ...module import Module | from ...module import Module | ||||
_active_module_tracer = None | _active_module_tracer = None | ||||
@@ -23,12 +27,14 @@ def set_active_module_tracer(tracer): | |||||
class module_tracer: | class module_tracer: | ||||
# builtin types | |||||
_opaque_types = set() | _opaque_types = set() | ||||
_active_scopes = None | _active_scopes = None | ||||
def __init__(self): | |||||
def __init__(self, wrap_fn): | |||||
self._active_scopes = [] | self._active_scopes = [] | ||||
self.patcher = Patcher(wrap_fn) | |||||
@classmethod | @classmethod | ||||
def register_as_builtin(cls, mod): | def register_as_builtin(cls, mod): | ||||
@@ -50,3 +56,105 @@ class module_tracer: | |||||
if self._active_scopes: | if self._active_scopes: | ||||
return self._active_scopes[-1] | return self._active_scopes[-1] | ||||
return None | return None | ||||
class PatchedFn:
    """
    Remember where a callable is bound so it can be replaced and restored later.

    ``frame_dict`` is either a mapping, accessed by key, or any other object,
    accessed by attribute. ``origin_fn`` keeps the binding that was found
    under ``name`` when the instance was created.
    """

    frame_dict = None
    name = None
    origin_fn = None

    def __init__(self, frame_dict, name):
        self.frame_dict = frame_dict
        self.name = name
        # Capture the current binding before anything replaces it.
        if isinstance(frame_dict, collections.abc.Mapping):
            self.origin_fn = frame_dict[name]
        else:
            self.origin_fn = getattr(frame_dict, name)

    def set_func(self, func):
        """Bind ``func`` under ``name`` in ``frame_dict`` (by key or by attribute)."""
        if isinstance(self.frame_dict, collections.abc.Mapping):
            self.frame_dict[self.name] = func
        else:
            setattr(self.frame_dict, self.name, func)
class Patcher:
    """
    Replace functions and methods by wrapped versions and restore them on exit.

    All builtin modules, classes and functions listed on the class are patched
    when the instance is created; leaving the ``with`` block puts every
    original binding back.
    """

    # ids of the original callables that were patched; class-level, so shared
    # by all instances and not cleared by ``__exit__``.
    patched_fn_ids = set()
    _builtin_functions = []
    _builtin_modules = [
        F,
        F.distributed,
        F.elemwise,
        F.inplace,
        F.loss,
        F.math,
        F.metric,
        F.nn,
        F.quantized,
        F.tensor,
        F.utils,
        F.vision,
    ]
    _builtin_methods = [
        Tensor,
        ArrayMethodMixin,
    ]

    def __init__(self, wrap_fn):
        """
        :param wrap_fn: callable that takes an original function and returns
            its replacement.
        """
        self.patched_fn = []
        self.visited_frames_ids = set()
        self.wrap_fn = wrap_fn
        for module in self._builtin_modules:
            self.patch_module(module)

        for cls in self._builtin_methods:
            self.patch_cls(cls)

        for i, j in self._builtin_functions:
            if id(i) not in self.visited_frames_ids:
                self.patch_function(i, j, self.wrap_fn)

    def patch_function(self, frame_dict, fn, wrap_fn):
        """Replace the binding ``fn`` in ``frame_dict`` by ``wrap_fn(original)``."""
        patched_fn = PatchedFn(frame_dict, fn)
        self.patched_fn_ids.add(id(patched_fn.origin_fn))
        patched_fn.set_func(wrap_fn(patched_fn.origin_fn))
        self.patched_fn.append(patched_fn)

    def patch_method(self, cls, name, wrap_fn):
        """Replace the attribute ``name`` of ``cls`` by its wrapped version."""
        self.patch_function(cls, name, wrap_fn)

    def patch_cls(self, cls):
        """Patch every public plain function defined directly on ``cls``."""
        import inspect

        if id(cls) not in self.visited_frames_ids:
            # Iterate over a snapshot: patching assigns to the class while we loop.
            for k, v in list(cls.__dict__.items()):
                if inspect.isfunction(v) and not k.startswith("_"):
                    self.patch_function(cls, k, self.wrap_fn)
            self.visited_frames_ids.add(id(cls))

    def patch_module(self, module):
        """Patch every public plain function found in ``module.__dict__``."""
        import inspect

        if id(module.__dict__) not in self.visited_frames_ids:
            # Iterate over a snapshot: patching assigns into this very dict.
            for k, v in list(module.__dict__.items()):
                if inspect.isfunction(v) and not k.startswith("_"):
                    self.patch_function(module.__dict__, k, self.wrap_fn)
            self.visited_frames_ids.add(id(module.__dict__))

    def auto_patch(self, frame_dict):
        """Patch every entry of ``frame_dict`` whose value is an already patched original."""
        if id(frame_dict) not in self.visited_frames_ids:
            # Iterate over a snapshot: patching assigns into this very dict.
            for k, v in list(frame_dict.items()):
                if id(v) in self.patched_fn_ids:
                    self.patch_function(frame_dict, k, self.wrap_fn)
            self.visited_frames_ids.add(id(frame_dict))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Undo in reverse order of patching so every binding ends up with its
        # original value; exceptions are not suppressed (returns None).
        while self.patched_fn:
            pf = self.patched_fn.pop()
            pf.set_func(pf.origin_fn)
        self.visited_frames_ids.clear()
@@ -34,6 +34,10 @@ class Node: | |||||
Node.__total_id += 1 | Node.__total_id += 1 | ||||
self._name = name | self._name = name | ||||
def __setstate__(self, d): | |||||
self.__dict__ = d | |||||
Node.__total_id = max(Node.__total_id, self._id) + 1 | |||||
def __repr__(self): | def __repr__(self): | ||||
if self._name is None: | if self._name is None: | ||||
return "%{}".format(self._id) | return "%{}".format(self._id) | ||||
@@ -8,14 +8,25 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | import collections | ||||
import copy | import copy | ||||
import functools | |||||
from typing import List, Type | from typing import List, Type | ||||
from ... import module as M | from ... import module as M | ||||
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing | |||||
from ...core._imperative_rt.core2 import ( | |||||
is_tracing_module, | |||||
set_module_tracing, | |||||
unset_module_tracing, | |||||
) | |||||
from ...core.tensor.array_method import ArrayMethodMixin | |||||
from ...module import Module | from ...module import Module | ||||
from ...tensor import Tensor | from ...tensor import Tensor | ||||
from .expr import Apply, Call, Constant, Expr, GetAttr, Input | |||||
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer | |||||
from .expr import Apply, CallFunction, CallMethod, Constant, Expr, GetAttr, Input | |||||
from .module_tracer import ( | |||||
Patcher, | |||||
active_module_tracer, | |||||
module_tracer, | |||||
set_active_module_tracer, | |||||
) | |||||
from .node import ModuleNode, Node, NodeMixin, TensorNode | from .node import ModuleNode, Node, NodeMixin, TensorNode | ||||
@@ -54,7 +65,9 @@ class InternalGraph: | |||||
for n, v in zip(self._inputs, inputs): | for n, v in zip(self._inputs, inputs): | ||||
node2value[n] = v | node2value[n] = v | ||||
for expr in self._exprs: | for expr in self._exprs: | ||||
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||||
values = expr.interpret( | |||||
*list(Expr.get_arg_value(i, node2value) for i in expr.inputs) | |||||
) | |||||
for n, v in zip(expr.outputs, values): | for n, v in zip(expr.outputs, values): | ||||
node2value[n] = v | node2value[n] = v | ||||
return list(node2value[i] for i in self._outputs) | return list(node2value[i] for i in self._outputs) | ||||
@@ -67,6 +80,41 @@ class InternalGraph: | |||||
) | ) | ||||
def _wrapped_function(orig_func): | |||||
@functools.wraps(orig_func) | |||||
def wrapped_fn(*inputs, **kwargs): | |||||
if is_tracing_module(): | |||||
unset_module_tracing() | |||||
const_kwargs = {} | |||||
arg_names = orig_func.__code__.co_varnames | |||||
if orig_func.__qualname__.split(".").__len__() > 1: | |||||
# FIXME: a robust way to distinguish method and function. <XP> | |||||
self = inputs[0] | |||||
call_node = CallMethod.make(NodeMixin.get(self), orig_func.__name__) | |||||
else: | |||||
call_node = CallFunction.make(orig_func) | |||||
def add_input(inp, varname=None): | |||||
node = Expr.get_args_node(inp) | |||||
if node is not None: | |||||
call_node.add_input(node, varname) | |||||
else: | |||||
const_kwargs[varname] = inp | |||||
for ind, inp in enumerate(inputs): | |||||
add_input(inp, arg_names[ind]) | |||||
for k, v in kwargs.items(): | |||||
add_input(v, k) | |||||
call_node.kwargs = const_kwargs | |||||
outputs = orig_func(*inputs, **kwargs) | |||||
call_node.add_outputs(outputs) | |||||
set_module_tracing() | |||||
return outputs | |||||
return orig_func(*inputs, **kwargs) | |||||
return wrapped_fn | |||||
class TracedModuleBuilder(NodeMixin): | class TracedModuleBuilder(NodeMixin): | ||||
_mod = None # type: Module | _mod = None # type: Module | ||||
@@ -120,7 +168,7 @@ class TracedModuleBuilder(NodeMixin): | |||||
mark_constant(i) | mark_constant(i) | ||||
for k, v in kwargs.items(): | for k, v in kwargs.items(): | ||||
mark_constant(v) | mark_constant(v) | ||||
callnode = Call.make(NodeMixin.get(self)) | |||||
callnode = CallMethod.make(NodeMixin.get(self)) | |||||
def add_input(x): | def add_input(x): | ||||
callnode.add_input(NodeMixin.get(x)) | callnode.add_input(NodeMixin.get(x)) | ||||
@@ -145,7 +193,8 @@ class TracedModuleBuilder(NodeMixin): | |||||
) | ) | ||||
# prepare args and kwargs for inner graph | # prepare args and kwargs for inner graph | ||||
def wrap(x): | def wrap(x): | ||||
wrapped = copy.copy(x) # FIXME | |||||
# wrapped = copy.copy(x) # FIXME | |||||
wrapped = x # FIXME: <XP> | |||||
NodeMixin.wrap( | NodeMixin.wrap( | ||||
wrapped, | wrapped, | ||||
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | ||||
@@ -157,7 +206,9 @@ class TracedModuleBuilder(NodeMixin): | |||||
args.append(wrap(i)) | args.append(wrap(i)) | ||||
for k, v in kwargs.items(): | for k, v in kwargs.items(): | ||||
kwargs[k] = wrap(v) | kwargs[k] = wrap(v) | ||||
active_module_tracer().patcher.auto_patch( | |||||
getattr(getattr(self._mod, "forward", self._mod), "__globals__", {}) | |||||
) | |||||
outputs = type(self._mod).forward(self, *args, **kwargs) | outputs = type(self._mod).forward(self, *args, **kwargs) | ||||
for i in ( | for i in ( | ||||
@@ -171,11 +222,6 @@ class TracedModuleBuilder(NodeMixin): | |||||
# rebind output to outer graph | # rebind output to outer graph | ||||
callnode.add_outputs(outputs) | callnode.add_outputs(outputs) | ||||
for i, node in zip( | |||||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), | |||||
callnode.outputs, | |||||
): | |||||
NodeMixin.wrap_safe(i, node) | |||||
return outputs | return outputs | ||||
def __getattr__(self, name): | def __getattr__(self, name): | ||||
@@ -229,6 +275,55 @@ class TracedModule(Module): | |||||
rst = rst[0] | rst = rst[0] | ||||
return rst | return rst | ||||
@property | |||||
def all_exprs(self): | |||||
""" | |||||
Visit all ``Expr``s in the graph recursively. | |||||
:return: List[Expr] | |||||
""" | |||||
in_nodes = [i.expr for i in self.m_node.graph._inputs if not i is self] | |||||
def _flatten_submodule(module, call=None): | |||||
if not isinstance(module, TracedModule): | |||||
call.inputs[0] = module | |||||
return (call,) | |||||
exprs = [] | |||||
graph = module.m_node.graph | |||||
for expr in graph._exprs: | |||||
# replace inputs for submodule's expr | |||||
for idx, inp in enumerate(expr.inputs): | |||||
if call and inp in graph._inputs: | |||||
expr.inputs[idx] = call.inputs[idx] | |||||
# replace outputs for submodule's expr | |||||
for idx, outp in enumerate(expr.outputs): | |||||
if call and outp in graph._outputs: | |||||
expr.outputs[idx] = call.outputs[idx] | |||||
if isinstance(expr, GetAttr): | |||||
# replace GetAttr with Constant | |||||
if isinstance(expr.outputs[0], TensorNode): | |||||
const = Constant(getattr(module, expr.name)) | |||||
const.outputs = expr.outputs | |||||
exprs.append(const) | |||||
elif isinstance(expr, CallMethod): | |||||
obj_node = expr.inputs[0] | |||||
if isinstance(obj_node, ModuleNode): | |||||
(obj,) = expr.inputs[0].expr.interpret(module) | |||||
exprs.extend(_flatten_submodule(obj, expr)) | |||||
else: | |||||
exprs.append(expr) | |||||
else: | |||||
exprs.append(expr) | |||||
return exprs | |||||
return in_nodes + _flatten_submodule(self) | |||||
def __getstate__(self): | def __getstate__(self): | ||||
d = self.__dict__ | d = self.__dict__ | ||||
for k in Module.__dict__: | for k in Module.__dict__: | ||||
@@ -273,23 +368,23 @@ def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule | |||||
assert active_module_tracer() is None | assert active_module_tracer() is None | ||||
try: | try: | ||||
set_module_tracing() | set_module_tracing() | ||||
set_active_module_tracer(module_tracer()) | |||||
global_scope = InternalGraph() | |||||
active_module_tracer().push_scope(global_scope) | |||||
set_active_module_tracer(module_tracer(_wrapped_function)) | |||||
with active_module_tracer().patcher: | |||||
global_scope = InternalGraph() | |||||
active_module_tracer().push_scope(global_scope) | |||||
builder = TracedModuleBuilder(mod) | |||||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||||
builder = TracedModuleBuilder(mod) | |||||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||||
for _, i in enumerate(inputs): | |||||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||||
for k, v in kwargs.items(): | |||||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||||
for _, i in enumerate(inputs): | |||||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||||
for k, v in kwargs.items(): | |||||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||||
builder(*inputs, **kwargs) | |||||
active_module_tracer().pop_scope() | |||||
builder(*inputs, **kwargs) | |||||
active_module_tracer().pop_scope() | |||||
return builder.build() | |||||
return builder.build() | |||||
finally: | finally: | ||||
set_active_module_tracer(None) | set_active_module_tracer(None) | ||||
unset_module_tracing() | unset_module_tracing() |