@@ -130,3 +130,4 @@ import megengine.optimizer | |||
import megengine.quantization | |||
import megengine.random | |||
import megengine.utils | |||
import megengine.experimental |
@@ -6,4 +6,5 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from . import traced_module | |||
from .weight_scaler import get_scaled_model |
@@ -5,3 +5,15 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ...core._imperative_rt.core2 import set_cpp_apply_module_trace | |||
from .traced_module import ( | |||
TracedModule, | |||
_register_all_builtin_module, | |||
cpp_apply_module_trace, | |||
register_as_builtin, | |||
trace_module, | |||
) | |||
_register_all_builtin_module() | |||
set_cpp_apply_module_trace(cpp_apply_module_trace) |
@@ -0,0 +1,215 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from typing import List | |||
from ...core._imperative_rt import OpDef | |||
from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing | |||
from ...core.ops.special import Const | |||
from ...tensor import Tensor | |||
from .module_tracer import active_module_tracer | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
class Expr: | |||
""" | |||
``Expr`` represents the operations(i.e. Call, Apply, GetAttr, Input, Constant) on ``Node``. | |||
""" | |||
inputs = None # type: List[Node] | |||
outputs = None # type: List[Node] | |||
# expr: None (i.e. fake expression which is used to mark input) | |||
class Input(Expr): | |||
name = None | |||
def __init__(self, name=None, type=None): | |||
self.inputs = [] | |||
node_cls = type if type else Node | |||
self.outputs = [ | |||
node_cls(self, name=name), | |||
] | |||
self.name = name | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
expr = cls(*args, **kwargs) | |||
active_module_tracer().current_scope().add_input(expr.outputs[0]) | |||
return expr.outputs[0] | |||
def __repr__(self): | |||
return "{} = Input({})".format(self.outputs[0], self.name) | |||
# expr: outputs = getattr(inputs[0], self.name) | |||
class GetAttr(Expr): | |||
name = None | |||
def __init__(self, module, name, type=None): | |||
assert isinstance(module, ModuleNode) | |||
self.inputs = [ | |||
module, | |||
] | |||
self.name = name | |||
node_cls = type if type else Node | |||
self.outputs = [ | |||
node_cls(self), | |||
] | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
expr = cls(*args, **kwargs) | |||
active_module_tracer().current_scope().insert(expr) | |||
expr.outputs[0]._name = expr.name | |||
return expr.outputs[0] | |||
def interpret(self, *inputs): | |||
return (getattr(inputs[0], self.name),) | |||
def __repr__(self): | |||
return '{} = GetAttr({}, "{}")'.format( | |||
self.outputs[0], self.inputs[0], self.name | |||
) | |||
# expr: outputs = inputs[0].__call__(*inputs[1:]) | |||
class Call(Expr): | |||
def __init__(self, module): | |||
assert isinstance(module, ModuleNode) | |||
self.inputs = [ | |||
module, | |||
] | |||
def add_input(self, node): | |||
self.inputs.append(node) | |||
def add_outputs(self, references): | |||
self.outputs = [] | |||
if not isinstance(references, collections.Sequence): | |||
references = (references,) | |||
for i in references: | |||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
expr = cls(*args, **kwargs) | |||
active_module_tracer().current_scope().insert(expr) | |||
return expr | |||
def interpret(self, *inputs): | |||
mod = inputs[0] | |||
args = inputs[1:] | |||
outputs = mod(*args) | |||
if isinstance(outputs, RawTensor): | |||
outputs = (outputs,) | |||
return outputs | |||
def __repr__(self): | |||
return "{} = Call({})({})".format( | |||
", ".join(str(i) for i in self.outputs), | |||
self.inputs[0], | |||
", ".join(str(i) for i in self.inputs[1:]), | |||
) | |||
# expr: outputs = apply(self.opdef, *inputs) | |||
class Apply(Expr): | |||
opdef = None | |||
def __init__(self, opdef): | |||
assert isinstance(opdef, OpDef) | |||
self.opdef = opdef | |||
self.inputs = [] | |||
def add_input(self, node): | |||
self.inputs.append(node) | |||
def add_outputs(self, references): | |||
self.outputs = [] | |||
if not isinstance(references, collections.Sequence): | |||
references = (references,) | |||
for i in references: | |||
self.outputs.append(NodeMixin.get_wrapped_type(i)(self)) | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
expr = cls(*args, **kwargs) | |||
active_module_tracer().current_scope().insert(expr) | |||
return expr | |||
def interpret(self, *inputs): | |||
return apply(self.opdef, *inputs) | |||
def __repr__(self): | |||
return "{} = {}({})".format( | |||
", ".join(str(i) for i in self.outputs), | |||
self.opdef, | |||
", ".join(str(i) for i in self.inputs), | |||
) | |||
@classmethod | |||
def apply_module_trace_hook(cls, opdef, *inputs): | |||
for i in inputs: | |||
node = NodeMixin.get(i, None) | |||
if node is None: # capture as constant | |||
NodeMixin.wrap_safe(i, Constant.make(i)) | |||
apply_node = cls.make(opdef) | |||
for i in inputs: | |||
apply_node.add_input(NodeMixin.get(i)) | |||
unset_module_tracing() | |||
outputs = apply(opdef, *inputs) | |||
set_module_tracing() | |||
apply_node.add_outputs(outputs) | |||
for n, v in zip(apply_node.outputs, outputs): | |||
NodeMixin.wrap_safe(v, n) | |||
return list(outputs) | |||
# expr outputs = self.value | |||
class Constant(Expr): | |||
value = None | |||
# TODO: constant cache to reduce the size of dumped model | |||
_constant_cache = {} | |||
def __init__(self, c): | |||
# TODO: type check, since not all types should be captured as constant | |||
self.value = c | |||
self.inputs = [] | |||
node_cls = NodeMixin.get_wrapped_type(c) | |||
self.outputs = [ | |||
node_cls(self), | |||
] | |||
@classmethod | |||
def make(cls, *args, **kwargs): | |||
expr = cls(*args, **kwargs) | |||
active_module_tracer().current_scope().insert(expr) | |||
return expr.outputs[0] | |||
def interpret(self, *inputs): | |||
if isinstance(self.value, RawTensor): | |||
return Const(self.value.numpy())() | |||
return (self.value,) | |||
def __repr__(self): | |||
return "{} = Constant({})".format(self.outputs[0], self.value) | |||
def __getstate__(self): | |||
state = self.__dict__.copy() | |||
if isinstance(self.value, RawTensor): | |||
state["value"] = Tensor(self.value) | |||
return state |
@@ -0,0 +1,52 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ...module import Module | |||
_active_module_tracer = None | |||
def active_module_tracer(): | |||
return _active_module_tracer | |||
def set_active_module_tracer(tracer): | |||
global _active_module_tracer | |||
_active_module_tracer = tracer | |||
class module_tracer: | |||
_opaque_types = set() | |||
_active_scopes = None | |||
def __init__(self): | |||
self._active_scopes = [] | |||
@classmethod | |||
def register_as_builtin(cls, mod): | |||
assert issubclass(mod, Module) | |||
cls._opaque_types.add(mod) | |||
return mod | |||
@classmethod | |||
def is_builtin(cls, mod): | |||
return type(mod) in cls._opaque_types | |||
def push_scope(self, scope): | |||
self._active_scopes.append(scope) | |||
def pop_scope(self): | |||
self._active_scopes.pop() | |||
def current_scope(self): | |||
if self._active_scopes: | |||
return self._active_scopes[-1] | |||
return None |
@@ -0,0 +1,123 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from typing import Any, Dict, Tuple, Type | |||
import numpy | |||
from ...core._imperative_rt.core2 import Tensor as RawTensor | |||
from ...module import Module | |||
from ...tensor import Tensor | |||
class Node: | |||
""" | |||
``Node`` represents the variables (Tensor/Module/other python object) used in Module's forward method. They are inputs/outputs of Expr(the operations on variables). | |||
param expr: the Expr which produces the node | |||
param name: the name of the node | |||
""" | |||
expr = None | |||
__total_id = 0 | |||
_id = None | |||
_name = None | |||
def __init__(self, expr: "Expr", name: str = None): | |||
self.expr = expr | |||
self._id = Node.__total_id | |||
Node.__total_id += 1 | |||
self._name = name | |||
def __repr__(self): | |||
if self._name is None: | |||
return "%{}".format(self._id) | |||
else: | |||
return "%{}".format(self._name) | |||
class ModuleNode(Node): | |||
""" | |||
``ModuleNode`` represents the Module objects. | |||
Attributes: | |||
module_type: type of the Module correspending to the ModuleNode | |||
graph: the InternalGraph which will be interpreted when call Module's forward method | |||
attr_type_map: record the type of Module's attributes | |||
""" | |||
module_type = Module # type: Type[Module] | |||
graph = None | |||
attr_type_map = None # type: Dict[str, Type[Any]] | |||
def __repr__(self): | |||
if self._name is None: | |||
return "%{}({})".format(self._id, self.module_type.__name__) | |||
else: | |||
return "%{}({})".format(self._name, self.module_type.__name__) | |||
class TensorNode(Node): | |||
""" | |||
``TensorNode`` represents the Tensor objects. | |||
""" | |||
shape = None # type: Tuple[int] | |||
dtype = None # type: numpy.dtype | |||
def __repr__(self): | |||
if self._name is None: | |||
return "%{}(Tensor)".format(self._id) | |||
else: | |||
return "%{}(Tensor)".format(self._name) | |||
class NodeMixin: | |||
__node = None | |||
@classmethod | |||
def wrap(cls, value, node): | |||
if isinstance(value, (NodeMixin, RawTensor)): | |||
if isinstance(node, Node): | |||
if isinstance(value, RawTensor): | |||
node.dtype = value.dtype | |||
node.shape = ( | |||
value._tuple_shape if isinstance(value, Tensor) else value.shape | |||
) | |||
setattr(value, "_NodeMixin__node", node) | |||
else: | |||
assert callable(node) | |||
n = node() | |||
if isinstance(value, RawTensor): | |||
n.dtype = value.dtype | |||
n.shape = ( | |||
value._tuple_shape if isinstance(value, Tensor) else value.shape | |||
) | |||
setattr(value, "_NodeMixin__node", n) | |||
@classmethod | |||
def wrap_safe(cls, value, node): | |||
assert isinstance(value, (NodeMixin, RawTensor)) | |||
if isinstance(value, RawTensor): | |||
node.dtype = value.dtype | |||
node.shape = ( | |||
value._tuple_shape if isinstance(value, Tensor) else value.shape | |||
) | |||
setattr(value, "_NodeMixin__node", node) | |||
@classmethod | |||
def get(cls, value, *default): | |||
return getattr(value, "_NodeMixin__node", *default) | |||
@classmethod | |||
def get_wrapped_type(cls, value): | |||
if isinstance(value, RawTensor): | |||
return TensorNode | |||
if isinstance(value, (Module, NodeMixin)): | |||
return ModuleNode | |||
return Node |
@@ -0,0 +1,295 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
import copy | |||
from typing import List, Type | |||
from ... import module as M | |||
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing | |||
from ...module import Module | |||
from ...tensor import Tensor | |||
from .expr import Apply, Call, Constant, Expr, GetAttr, Input | |||
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer | |||
from .node import ModuleNode, Node, NodeMixin, TensorNode | |||
class InternalGraph: | |||
""" | |||
``InternalGraph`` is a graph consist of ``Node`` and ``Expr``, it is used to represent the execution procedure of Module's forward method. | |||
Attributes: | |||
_exprs: List of Exprs in order of execution | |||
_inputs: Input Nodes of InternalGraph | |||
_outputs: Output Nodes of InternalGraph | |||
""" | |||
_exprs = None # type: List[Expr] | |||
_inputs = None # type: List[Node] | |||
_outputs = None # type: List[Node] | |||
def __init__(self): | |||
self._exprs = [] | |||
self._inputs = [] | |||
self._outputs = [] | |||
def insert(self, expr): | |||
self._exprs.append(expr) | |||
def add_input(self, i): | |||
self._inputs.append(i) | |||
def add_output(self, o): | |||
self._outputs.append(o) | |||
def interpret(self, *inputs): | |||
# TODO: support kwargs ? | |||
# TODO: skip expressions which are independent and have no side effect | |||
node2value = {} | |||
for n, v in zip(self._inputs, inputs): | |||
node2value[n] = v | |||
for expr in self._exprs: | |||
values = expr.interpret(*list(node2value[i] for i in expr.inputs)) | |||
for n, v in zip(expr.outputs, values): | |||
node2value[n] = v | |||
return list(node2value[i] for i in self._outputs) | |||
def __repr__(self): | |||
return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format( | |||
", ".join(str(i) for i in self._inputs), | |||
"\n\t".join(str(i) for i in self._exprs), | |||
", ".join(str(i) for i in self._outputs), | |||
) | |||
class TracedModuleBuilder(NodeMixin): | |||
_mod = None # type: Module | |||
_body = None # type: InternalGraph | |||
_is_builtin = None # type: bool | |||
__builder_attributes__ = [ | |||
"_mod", | |||
"_body", | |||
"_NodeMixin__node", | |||
"_is_builtin", | |||
"_is_traced", | |||
"build", | |||
] | |||
def __init__(self, mod): | |||
super(TracedModuleBuilder, self).__init__() | |||
self._mod = mod | |||
self._body = InternalGraph() | |||
self._is_traced = False | |||
self._is_builtin = module_tracer.is_builtin(mod) | |||
def build(self): | |||
if self._is_builtin: | |||
node = NodeMixin.get(self) | |||
node.module_type = type(self._mod) | |||
return self._mod | |||
else: | |||
node = NodeMixin.get(self) | |||
node.graph = self._body | |||
node.attr_type_map = {} | |||
traced_module = TracedModule(node) | |||
for k, v in self.__dict__.items(): | |||
if k not in TracedModuleBuilder.__builder_attributes__: | |||
if isinstance(v, TracedModuleBuilder): | |||
v = v.build() | |||
setattr(traced_module, k, v) | |||
traced_module.m_node.attr_type_map[k] = type(v) | |||
return traced_module | |||
def __call__(self, *inputs, **kwargs): | |||
assert isinstance(self._mod, Module) | |||
# prepare args and kwargs for inner graph | |||
def mark_constant(x): | |||
node = NodeMixin.get(x, None) | |||
if node is None: # capture as constant | |||
NodeMixin.wrap(x, lambda: Constant.make(x)) | |||
for i in inputs: | |||
mark_constant(i) | |||
for k, v in kwargs.items(): | |||
mark_constant(v) | |||
callnode = Call.make(NodeMixin.get(self)) | |||
def add_input(x): | |||
callnode.add_input(NodeMixin.get(x)) | |||
for i in inputs: | |||
add_input(i) | |||
for k, v in kwargs.items(): | |||
add_input(v) | |||
if self._is_builtin or self._is_traced: | |||
unset_module_tracing() | |||
outputs = self._mod(*inputs, **kwargs) | |||
set_module_tracing() | |||
if self._is_builtin: | |||
self._body = None | |||
else: | |||
active_module_tracer().push_scope(self._body) | |||
# rebind self to new input node | |||
orig_self = NodeMixin.get(self) | |||
NodeMixin.wrap_safe( | |||
self, Input.make("self", NodeMixin.get_wrapped_type(self)) | |||
) | |||
# prepare args and kwargs for inner graph | |||
def wrap(x): | |||
wrapped = copy.copy(x) # FIXME | |||
NodeMixin.wrap( | |||
wrapped, | |||
lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)), | |||
) | |||
return wrapped | |||
args = [] | |||
for i in inputs: | |||
args.append(wrap(i)) | |||
for k, v in kwargs.items(): | |||
kwargs[k] = wrap(v) | |||
outputs = type(self._mod).forward(self, *args, **kwargs) | |||
for i in ( | |||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,) | |||
): | |||
active_module_tracer().current_scope().add_output(NodeMixin.get(i)) | |||
NodeMixin.wrap_safe(self, orig_self) | |||
self._is_traced = True | |||
active_module_tracer().pop_scope() | |||
# rebind output to outer graph | |||
callnode.add_outputs(outputs) | |||
for i, node in zip( | |||
outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,), | |||
callnode.outputs, | |||
): | |||
NodeMixin.wrap_safe(i, node) | |||
return outputs | |||
def __getattr__(self, name): | |||
if name not in self._mod.__dict__: | |||
attr = getattr(type(self._mod), name).__get__(self, type(self)) | |||
else: | |||
attr = getattr(self._mod, name) | |||
if isinstance(attr, Module): | |||
attr = TracedModuleBuilder(attr) | |||
setattr(self, name, attr) | |||
NodeMixin.wrap( | |||
attr, | |||
lambda: GetAttr.make( | |||
NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr) | |||
), | |||
) | |||
return attr | |||
def __getattribute__(self, name): | |||
if name in TracedModuleBuilder.__builder_attributes__: | |||
return super().__getattribute__(name) | |||
else: | |||
wrapped = super().__getattribute__(name) | |||
if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None): | |||
assert not self._is_builtin | |||
NodeMixin.wrap( | |||
wrapped, | |||
lambda: GetAttr.make( | |||
NodeMixin.get(self), | |||
name, | |||
type=NodeMixin.get_wrapped_type(wrapped), | |||
), | |||
) | |||
return wrapped | |||
class TracedModule(Module): | |||
""" | |||
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node), and will interpret the m_node.graph when it is called. | |||
""" | |||
m_node = None # type: ModuleNode | |||
def __init__(self, node): | |||
super(TracedModule, self).__init__() | |||
self.m_node = node | |||
def forward(self, *inputs): | |||
rst = self.m_node.graph.interpret(self, *inputs) | |||
if len(rst) == 1: | |||
rst = rst[0] | |||
return rst | |||
def __getstate__(self): | |||
d = self.__dict__ | |||
for k in Module.__dict__: | |||
d.pop(k, None) | |||
return d | |||
def cpp_apply_module_trace(opdef, *args): | |||
return Apply.apply_module_trace_hook(opdef, *args) | |||
def register_as_builtin(mod_cls: Type[Module]) -> None: | |||
""" | |||
Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module. | |||
param mod_cls: the Module class which will be threated as builtin module in tracing | |||
""" | |||
module_tracer.register_as_builtin(mod_cls) | |||
def _register_all_builtin_module(): | |||
from inspect import getmembers, isclass | |||
for sub_mod in [M, M.qat, M.quantized]: | |||
for m in getmembers(sub_mod): | |||
if ( | |||
isclass(m[1]) | |||
and issubclass(m[1], M.Module) | |||
and m[1] is not M.Sequential | |||
): | |||
module_tracer.register_as_builtin(m[1]) | |||
def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule: | |||
""" | |||
Traces module ``mod`` and returns corresponding TracedModule. | |||
param mod: the module will be converted to TracedModule | |||
param input: the positional arguments passed to forward method of ``mod`` | |||
param kwargs: the keyword arguments passed to forward method of ``mod`` | |||
""" | |||
assert active_module_tracer() is None | |||
try: | |||
set_module_tracing() | |||
set_active_module_tracer(module_tracer()) | |||
global_scope = InternalGraph() | |||
active_module_tracer().push_scope(global_scope) | |||
builder = TracedModuleBuilder(mod) | |||
NodeMixin.wrap_safe(builder, Input.make("TopModule", ModuleNode)) | |||
for _, i in enumerate(inputs): | |||
NodeMixin.wrap_safe(i, Input.make("arg_{}".format(_))) | |||
for k, v in kwargs.items(): | |||
NodeMixin.wrap_safe(v, Input.make("kwarg_{}".format(k))) | |||
builder(*inputs, **kwargs) | |||
active_module_tracer().pop_scope() | |||
return builder.build() | |||
finally: | |||
set_active_module_tracer(None) | |||
unset_module_tracing() |