Browse Source

feat(imperative): add traced module

GitOrigin-RevId: 28c3503f2e
release-1.6
Megvii Engine Team 4 years ago
parent
commit
763c56f3b9
7 changed files with 699 additions and 0 deletions
  1. +1
    -0
      imperative/python/megengine/__init__.py
  2. +1
    -0
      imperative/python/megengine/experimental/__init__.py
  3. +12
    -0
      imperative/python/megengine/experimental/traced_module/__init__.py
  4. +215
    -0
      imperative/python/megengine/experimental/traced_module/expr.py
  5. +52
    -0
      imperative/python/megengine/experimental/traced_module/module_tracer.py
  6. +123
    -0
      imperative/python/megengine/experimental/traced_module/node.py
  7. +295
    -0
      imperative/python/megengine/experimental/traced_module/traced_module.py

+ 1
- 0
imperative/python/megengine/__init__.py View File

@@ -130,3 +130,4 @@ import megengine.optimizer
import megengine.quantization import megengine.quantization
import megengine.random import megengine.random
import megengine.utils import megengine.utils
import megengine.experimental

+ 1
- 0
imperative/python/megengine/experimental/__init__.py View File

@@ -6,4 +6,5 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import traced_module
from .weight_scaler import get_scaled_model from .weight_scaler import get_scaled_model

+ 12
- 0
imperative/python/megengine/experimental/traced_module/__init__.py View File

@@ -5,3 +5,15 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ...core._imperative_rt.core2 import set_cpp_apply_module_trace
from .traced_module import (
TracedModule,
_register_all_builtin_module,
cpp_apply_module_trace,
register_as_builtin,
trace_module,
)

# Import-time setup: register the Module classes found in megengine.module and
# its qat / quantized sub-packages as builtin, so tracing treats them as opaque.
_register_all_builtin_module()
# Hand the Python hook to the C++ runtime.
# NOTE(review): presumably the runtime invokes it for every op applied while
# module tracing is enabled -- confirm against core2.
set_cpp_apply_module_trace(cpp_apply_module_trace)

+ 215
- 0
imperative/python/megengine/experimental/traced_module/expr.py View File

@@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


import collections
from typing import List

from ...core._imperative_rt import OpDef
from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...core._imperative_rt.core2 import apply, set_module_tracing, unset_module_tracing
from ...core.ops.special import Const
from ...tensor import Tensor
from .module_tracer import active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode


class Expr:
    """
    Base class of every traced operation.

    An ``Expr`` consumes the ``Node`` objects listed in ``inputs`` and produces
    the ``Node`` objects listed in ``outputs``. The concrete kinds are
    ``Input``, ``GetAttr``, ``Call``, ``Apply`` and ``Constant``.
    """

    # Both attributes default to None on the class; concrete expressions
    # replace them with per-instance lists.
    outputs = None  # type: List[Node]
    inputs = None  # type: List[Node]


# expr: None (i.e. fake expression which is used to mark input)
class Input(Expr):
    """
    Placeholder expression that marks an input of a graph.

    It consumes nothing and produces exactly one node.

    param name: name given to the produced node (also kept on the expression)
    param type: Node subclass used for the produced node; ``Node`` when falsy
    """

    name = None

    def __init__(self, name=None, type=None):
        output_cls = type or Node
        self.inputs = []
        self.outputs = [output_cls(self, name=name)]
        self.name = name

    @classmethod
    def make(cls, *args, **kwargs):
        """
        Create the expression, register its node as an input of the scope that
        is currently being traced, and return that node (not the expression).
        """
        created = cls(*args, **kwargs)
        produced = created.outputs[0]
        active_module_tracer().current_scope().add_input(produced)
        return produced

    def __repr__(self):
        produced = self.outputs[0]
        return "{} = Input({})".format(produced, self.name)


# expr: outputs = getattr(inputs[0], self.name)
class GetAttr(Expr):
    """
    Expression that reads attribute ``name`` from a module.

    param module: the ``ModuleNode`` whose attribute is read (sole input)
    param name: attribute name
    param type: Node subclass used for the produced node; ``Node`` when falsy
    """

    name = None

    def __init__(self, module, name, type=None):
        assert isinstance(module, ModuleNode)
        self.inputs = [module]
        self.name = name
        output_cls = type or Node
        self.outputs = [output_cls(self)]

    @classmethod
    def make(cls, *args, **kwargs):
        """
        Create the expression, append it to the scope currently being traced,
        name the produced node after the attribute and return that node.
        """
        created = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(created)
        produced = created.outputs[0]
        produced._name = created.name
        return produced

    def interpret(self, *inputs):
        """Return a 1-tuple holding ``getattr(inputs[0], self.name)``."""
        owner = inputs[0]
        return (getattr(owner, self.name),)

    def __repr__(self):
        owner_node = self.inputs[0]
        produced = self.outputs[0]
        return '{} = GetAttr({}, "{}")'.format(produced, owner_node, self.name)


# expr: outputs = inputs[0].__call__(*inputs[1:])
class Call(Expr):
    """
    Expression that calls a module: ``outputs = inputs[0](*inputs[1:])``.

    ``inputs[0]`` is always the ``ModuleNode`` being called; the call
    arguments are appended afterwards with ``add_input`` and the results are
    attached with ``add_outputs``.

    param module: the ``ModuleNode`` that is called
    """

    def __init__(self, module):
        assert isinstance(module, ModuleNode)
        self.inputs = [
            module,
        ]

    def add_input(self, node):
        """Append ``node`` as a further argument of the call."""
        self.inputs.append(node)

    def add_outputs(self, references):
        """
        Create one output node per produced value.

        param references: a single value or a sequence of values returned by
            the call; a non-sequence is treated as a single output. The node
            class of each output is chosen from the value's type.
        """
        # ``collections.Sequence`` was only a deprecated alias and is gone
        # since Python 3.10; the ABC lives in ``collections.abc``.
        from collections.abc import Sequence

        if not isinstance(references, Sequence):
            references = (references,)

        self.outputs = [NodeMixin.get_wrapped_type(ref)(self) for ref in references]

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, append it to the current scope and return it."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        """
        Call ``inputs[0]`` with the remaining inputs.

        A single raw tensor result is wrapped into a 1-tuple so that the
        caller can always zip the result with ``self.outputs``.
        """
        mod = inputs[0]
        args = inputs[1:]
        outputs = mod(*args)
        if isinstance(outputs, RawTensor):
            outputs = (outputs,)
        return outputs

    def __repr__(self):
        return "{} = Call({})({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.inputs[0],
            ", ".join(str(i) for i in self.inputs[1:]),
        )


# expr: outputs = apply(self.opdef, *inputs)
class Apply(Expr):
    """
    Expression that applies an operator: ``outputs = apply(self.opdef, *inputs)``.

    param opdef: the ``OpDef`` describing the operator
    """

    opdef = None

    def __init__(self, opdef):
        assert isinstance(opdef, OpDef)
        self.opdef = opdef
        self.inputs = []

    def add_input(self, node):
        """Append ``node`` as a further operand."""
        self.inputs.append(node)

    def add_outputs(self, references):
        """
        Create one output node per produced value.

        param references: a single value or a sequence of values produced by
            the operator; a non-sequence is treated as a single output.
        """
        # ``collections.Sequence`` was only a deprecated alias and is gone
        # since Python 3.10; the ABC lives in ``collections.abc``.
        from collections.abc import Sequence

        if not isinstance(references, Sequence):
            references = (references,)

        self.outputs = [NodeMixin.get_wrapped_type(ref)(self) for ref in references]

    @classmethod
    def make(cls, *args, **kwargs):
        """Create the expression, append it to the current scope and return it."""
        expr = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(expr)
        return expr

    def interpret(self, *inputs):
        """Re-run the recorded operator on ``inputs``."""
        return apply(self.opdef, *inputs)

    def __repr__(self):
        return "{} = {}({})".format(
            ", ".join(str(i) for i in self.outputs),
            self.opdef,
            ", ".join(str(i) for i in self.inputs),
        )

    @classmethod
    def apply_module_trace_hook(cls, opdef, *inputs):
        """
        Record one operator application in the current scope and execute it.

        param opdef: the operator being applied
        param inputs: the operands; any operand without a node yet is captured
            as a ``Constant``
        return: list of the values produced by the operator, each bound to the
            matching output node of the recorded expression
        """
        for i in inputs:
            if NodeMixin.get(i, None) is None:  # capture as constant
                NodeMixin.wrap_safe(i, Constant.make(i))
        apply_node = cls.make(opdef)
        for i in inputs:
            apply_node.add_input(NodeMixin.get(i))

        # Run the real operator with tracing switched off so that it is not
        # recorded a second time; always switch tracing back on, even when the
        # operator raises, so a caller that handles the error keeps tracing.
        unset_module_tracing()
        try:
            outputs = apply(opdef, *inputs)
        finally:
            set_module_tracing()

        apply_node.add_outputs(outputs)
        for n, v in zip(apply_node.outputs, outputs):
            NodeMixin.wrap_safe(v, n)
        return list(outputs)


# expr: outputs = self.value
class Constant(Expr):
    """
    Expression that yields a value captured during tracing.

    param c: the captured value
    """

    value = None
    # TODO: constant cache to reduce the size of dumped model
    _constant_cache = {}

    def __init__(self, c):
        # TODO: type check, since not all types should be captured as constant
        self.inputs = []
        self.value = c
        output_cls = NodeMixin.get_wrapped_type(c)
        self.outputs = [output_cls(self)]

    @classmethod
    def make(cls, *args, **kwargs):
        """
        Create the expression, append it to the scope currently being traced
        and return the produced node (not the expression).
        """
        created = cls(*args, **kwargs)
        active_module_tracer().current_scope().insert(created)
        return created.outputs[0]

    def interpret(self, *inputs):
        """
        Return the captured value.

        A non-tensor value comes back as a 1-tuple. A raw tensor is rebuilt
        from its numpy data through the ``Const`` op and the result of
        applying that op is returned as is.
        """
        captured = self.value
        if not isinstance(captured, RawTensor):
            return (captured,)
        return Const(captured.numpy())()

    def __repr__(self):
        produced = self.outputs[0]
        return "{} = Constant({})".format(produced, self.value)

    def __getstate__(self):
        # Pickle a shallow copy of the instance dict; a raw tensor value is
        # replaced by ``Tensor(value)`` in that copy only.
        # NOTE(review): presumably because the raw tensor itself cannot be
        # pickled -- confirm.
        snapshot = dict(self.__dict__)
        if isinstance(self.value, RawTensor):
            snapshot["value"] = Tensor(self.value)
        return snapshot

+ 52
- 0
imperative/python/megengine/experimental/traced_module/module_tracer.py View File

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from ...module import Module

# The tracer of the trace in progress, or None when nothing is being traced.
# Read through active_module_tracer() and replaced by set_active_module_tracer().
_active_module_tracer = None


def active_module_tracer():
    """Return the module-level active tracer, or ``None`` when none is set."""
    return _active_module_tracer


def set_active_module_tracer(tracer):
    """
    Replace the module-level active tracer.

    param tracer: the tracer to make active, or ``None`` to clear it
    """
    global _active_module_tracer
    _active_module_tracer = tracer


class module_tracer:
    """
    Keeps the stack of scopes being recorded during one trace, plus the
    class-wide registry of Module classes that are treated as builtin.
    """

    # Shared by every tracer: Module classes registered as builtin.
    _opaque_types = set()

    # Per-instance stack of scopes; the innermost scope is the last element.
    _active_scopes = None

    def __init__(self):
        self._active_scopes = []

    @classmethod
    def register_as_builtin(cls, mod):
        """Add Module subclass ``mod`` to the builtin registry and return it."""
        assert issubclass(mod, Module)
        cls._opaque_types.add(mod)
        return mod

    @classmethod
    def is_builtin(cls, mod):
        """Tell whether the exact type of ``mod`` has been registered as builtin."""
        registered = cls._opaque_types
        return type(mod) in registered

    def push_scope(self, scope):
        """Make ``scope`` the innermost active scope."""
        self._active_scopes.append(scope)

    def pop_scope(self):
        """Drop the innermost active scope; nothing is returned."""
        self._active_scopes.pop()

    def current_scope(self):
        """Return the innermost active scope, or ``None`` when the stack is empty."""
        return self._active_scopes[-1] if self._active_scopes else None

+ 123
- 0
imperative/python/megengine/experimental/traced_module/node.py View File

@@ -0,0 +1,123 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Any, Dict, Tuple, Type

import numpy

from ...core._imperative_rt.core2 import Tensor as RawTensor
from ...module import Module
from ...tensor import Tensor


class Node:
    """
    ``Node`` stands for a variable (Tensor/Module/other python object) used in a
    Module's forward method. Nodes are the inputs/outputs of ``Expr`` objects.

    param expr: the Expr which produces the node
    param name: the name of the node
    """

    expr = None
    __total_id = 0
    _id = None
    _name = None

    def __init__(self, expr: "Expr", name: str = None):
        self._name = name
        self.expr = expr
        # Ids come from one counter kept on ``Node`` itself, so nodes of every
        # subclass draw from the same sequence.
        self._id = Node.__total_id
        Node.__total_id += 1

    def __repr__(self):
        # Show the name when one was given, the numeric id otherwise.
        label = self._id if self._name is None else self._name
        return "%{}".format(label)


class ModuleNode(Node):
    """
    ``ModuleNode`` stands for a Module object.

    Attributes:
    module_type: type of the Module corresponding to the ModuleNode
    graph: the InternalGraph which will be interpreted when the Module's forward method is called
    attr_type_map: records the type of the Module's attributes
    """

    module_type = Module  # type: Type[Module]
    graph = None
    attr_type_map = None  # type: Dict[str, Type[Any]]

    def __repr__(self):
        # Show the name when one was given, the numeric id otherwise.
        label = self._id if self._name is None else self._name
        return "%{}({})".format(label, self.module_type.__name__)


class TensorNode(Node):
    """
    ``TensorNode`` stands for a Tensor object.

    ``shape`` and ``dtype`` stay ``None`` until they are filled in from a
    concrete tensor.
    """

    shape = None  # type: Tuple[int]
    dtype = None  # type: numpy.dtype

    def __repr__(self):
        # Show the name when one was given, the numeric id otherwise.
        label = self._id if self._name is None else self._name
        return "%{}(Tensor)".format(label)


class NodeMixin:
    """
    Helpers that attach a ``Node`` to a runtime value and read it back.

    The node is stored on the value under the attribute ``_NodeMixin__node``.
    """

    __node = None

    @staticmethod
    def _bind(value, node):
        # Copy dtype/shape onto the node when the value is a raw tensor, then
        # store the node on the value.
        if isinstance(value, RawTensor):
            node.dtype = value.dtype
            node.shape = (
                value._tuple_shape if isinstance(value, Tensor) else value.shape
            )
        setattr(value, "_NodeMixin__node", node)

    @classmethod
    def wrap(cls, value, node):
        """
        Attach a node to ``value`` if ``value`` can carry one.

        param value: the runtime value; anything that is neither a
            ``NodeMixin`` nor a raw tensor is silently left untouched
        param node: a ``Node``, or a callable that is invoked (only when
            ``value`` can carry a node) to create it
        """
        if not isinstance(value, (NodeMixin, RawTensor)):
            return
        if isinstance(node, Node):
            target = node
        else:
            assert callable(node)
            target = node()
        NodeMixin._bind(value, target)

    @classmethod
    def wrap_safe(cls, value, node):
        """
        Attach ``node`` to ``value``; ``value`` must be a ``NodeMixin`` or a
        raw tensor, which is asserted.
        """
        assert isinstance(value, (NodeMixin, RawTensor))
        NodeMixin._bind(value, node)

    @classmethod
    def get(cls, value, *default):
        """
        Return the node attached to ``value``.

        With a ``default`` supplied it is returned when no node is attached;
        without one, ``AttributeError`` propagates, as with ``getattr``.
        """
        return getattr(value, "_NodeMixin__node", *default)

    @classmethod
    def get_wrapped_type(cls, value):
        """Pick the Node class that matches the type of ``value``."""
        if isinstance(value, RawTensor):
            node_cls = TensorNode
        elif isinstance(value, (Module, NodeMixin)):
            node_cls = ModuleNode
        else:
            node_cls = Node
        return node_cls

+ 295
- 0
imperative/python/megengine/experimental/traced_module/traced_module.py View File

@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import copy
from typing import List, Type

from ... import module as M
from ...core._imperative_rt.core2 import set_module_tracing, unset_module_tracing
from ...module import Module
from ...tensor import Tensor
from .expr import Apply, Call, Constant, Expr, GetAttr, Input
from .module_tracer import active_module_tracer, module_tracer, set_active_module_tracer
from .node import ModuleNode, Node, NodeMixin, TensorNode


class InternalGraph:
    """
    ``InternalGraph`` is a graph made of ``Node`` and ``Expr`` objects; it
    describes how a Module's forward method executes.

    Attributes:
    _exprs: List of Exprs in order of execution
    _inputs: Input Nodes of InternalGraph
    _outputs: Output Nodes of InternalGraph
    """

    _exprs = None  # type: List[Expr]
    _inputs = None  # type: List[Node]
    _outputs = None  # type: List[Node]

    def __init__(self):
        self._inputs = []
        self._exprs = []
        self._outputs = []

    def insert(self, expr):
        """Append ``expr`` to the end of the execution order."""
        self._exprs.append(expr)

    def add_input(self, i):
        """Append node ``i`` to the graph inputs."""
        self._inputs.append(i)

    def add_output(self, o):
        """Append node ``o`` to the graph outputs."""
        self._outputs.append(o)

    def interpret(self, *inputs):
        """
        Execute the recorded expressions in order.

        param inputs: runtime values, matched positionally with the input nodes
        return: list with the runtime value of every output node
        """
        # TODO: support kwargs ?
        # TODO: skip expressions which are independent and have no side effect
        env = dict(zip(self._inputs, inputs))
        for step in self._exprs:
            operands = [env[node] for node in step.inputs]
            results = step.interpret(*operands)
            env.update(zip(step.outputs, results))
        return [env[node] for node in self._outputs]

    def __repr__(self):
        signature = ", ".join(str(i) for i in self._inputs)
        body = "\n\t".join(str(i) for i in self._exprs)
        returned = ", ".join(str(i) for i in self._outputs)
        return "InternalGraph ({}) {{\n\t{}\n\treturn {}\n}}".format(
            signature, body, returned
        )


class TracedModuleBuilder(NodeMixin):
    """
    Stand-in for a ``Module`` while its forward method is being traced.

    The builder is passed as ``self`` to the wrapped module's ``forward``.
    Attribute reads on it are recorded as ``GetAttr`` expressions and calls
    are recorded as ``Call`` expressions. ``build`` turns the recording into
    a ``TracedModule``, or hands back the original module when it is builtin.

    param mod: the module being traced
    """

    _mod = None  # type: Module
    _body = None  # type: InternalGraph
    _is_builtin = None  # type: bool

    # Names that belong to the builder itself. ``__getattribute__`` returns
    # them without recording anything and ``build`` does not copy them onto
    # the traced module.
    __builder_attributes__ = [
        "_mod",
        "_body",
        "_NodeMixin__node",
        "_is_builtin",
        "_is_traced",
        "build",
    ]

    def __init__(self, mod):
        super(TracedModuleBuilder, self).__init__()
        self._mod = mod
        self._body = InternalGraph()
        self._is_traced = False
        self._is_builtin = module_tracer.is_builtin(mod)

    def build(self):
        """
        Finish the trace of this module.

        return: the original module when it is builtin (its type is stored on
            the node); otherwise a new ``TracedModule`` that owns the recorded
            graph and every attribute cached on the builder during tracing
        """
        if self._is_builtin:
            node = NodeMixin.get(self)
            node.module_type = type(self._mod)
            return self._mod
        else:
            node = NodeMixin.get(self)
            node.graph = self._body
            node.attr_type_map = {}
            traced_module = TracedModule(node)
            # The instance dict holds the builder's own fields plus the
            # attributes cached by __getattr__; only the latter are copied.
            for k, v in self.__dict__.items():
                if k not in TracedModuleBuilder.__builder_attributes__:
                    if isinstance(v, TracedModuleBuilder):
                        # sub-modules were wrapped in builders; build them first
                        v = v.build()
                    setattr(traced_module, k, v)
                    traced_module.m_node.attr_type_map[k] = type(v)
            return traced_module

    def __call__(self, *inputs, **kwargs):
        """
        Record a call of the wrapped module in the current scope and run it.

        A builtin or already traced module is executed for real with tracing
        switched off. Any other module has its ``forward`` run with this
        builder as ``self`` so that its body is recorded into ``self._body``.

        return: whatever the module's forward returned
        """
        assert isinstance(self._mod, Module)

        # prepare args and kwargs for inner graph
        def mark_constant(x):
            node = NodeMixin.get(x, None)
            if node is None:  # capture as constant
                # NodeMixin.wrap only invokes the factory, and so only records
                # the Constant, when x is a NodeMixin or a raw tensor.
                NodeMixin.wrap(x, lambda: Constant.make(x))

        for i in inputs:
            mark_constant(i)
        for k, v in kwargs.items():
            mark_constant(v)
        callnode = Call.make(NodeMixin.get(self))

        def add_input(x):
            # NOTE(review): NodeMixin.get has no default here, so an argument
            # that could not carry a node (e.g. a plain python number) looks
            # like it raises AttributeError -- confirm intended.
            callnode.add_input(NodeMixin.get(x))

        for i in inputs:
            add_input(i)
        for k, v in kwargs.items():
            add_input(v)

        if self._is_builtin or self._is_traced:
            # Run the real module with tracing off so that its internals are
            # not recorded; only the Call expression above remains.
            # NOTE(review): tracing is not switched back on if the module raises.
            unset_module_tracing()
            outputs = self._mod(*inputs, **kwargs)
            set_module_tracing()
            if self._is_builtin:
                self._body = None
        else:
            active_module_tracer().push_scope(self._body)
            # rebind self to new input node
            orig_self = NodeMixin.get(self)
            NodeMixin.wrap_safe(
                self, Input.make("self", NodeMixin.get_wrapped_type(self))
            )
            # prepare args and kwargs for inner graph
            def wrap(x):
                # A shallow copy gets the inner Input node, so the caller's
                # object keeps the node it has in the outer graph.
                wrapped = copy.copy(x)  # FIXME
                NodeMixin.wrap(
                    wrapped,
                    lambda: Input.make(type=NodeMixin.get_wrapped_type(wrapped)),
                )
                return wrapped

            args = []
            for i in inputs:
                args.append(wrap(i))
            # only values of existing keys are replaced while iterating
            for k, v in kwargs.items():
                kwargs[k] = wrap(v)

            # Call forward unbound with the builder as ``self``.
            outputs = type(self._mod).forward(self, *args, **kwargs)

            for i in (
                outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,)
            ):
                active_module_tracer().current_scope().add_output(NodeMixin.get(i))

            # restore the node this builder has in the outer graph
            NodeMixin.wrap_safe(self, orig_self)
            self._is_traced = True
            active_module_tracer().pop_scope()

        # rebind output to outer graph
        callnode.add_outputs(outputs)
        for i, node in zip(
            outputs if isinstance(outputs, collections.abc.Sequence) else (outputs,),
            callnode.outputs,
        ):
            NodeMixin.wrap_safe(i, node)
        return outputs

    def __getattr__(self, name):
        """
        Resolve an attribute that is not yet present on the builder.

        A name missing from the wrapped module's instance dict is looked up
        on the module's class and bound to the builder. A name found in the
        instance dict is cached on the builder (sub-modules wrapped in a
        builder of their own) and recorded as a ``GetAttr`` expression.
        """
        if name not in self._mod.__dict__:
            # NOTE(review): assumes the class attribute is a descriptor
            # (function/property); a plain class-level value has no __get__.
            attr = getattr(type(self._mod), name).__get__(self, type(self))
        else:
            attr = getattr(self._mod, name)
            if isinstance(attr, Module):
                attr = TracedModuleBuilder(attr)
            # cache on the builder so that later reads bypass __getattr__
            setattr(self, name, attr)
            NodeMixin.wrap(
                attr,
                lambda: GetAttr.make(
                    NodeMixin.get(self), name, type=NodeMixin.get_wrapped_type(attr)
                ),
            )
        return attr

    def __getattribute__(self, name):
        """
        Intercept every attribute read on the builder.

        The builder's own attributes are returned directly. For any other
        name the normal lookup runs first; if the name is an instance
        attribute of the wrapped module and the value has no node attached,
        a ``GetAttr`` expression is recorded for it.
        """
        if name in TracedModuleBuilder.__builder_attributes__:
            return super().__getattribute__(name)
        else:
            # An AttributeError raised here makes Python fall back to __getattr__.
            wrapped = super().__getattribute__(name)
            if name in self._mod.__dict__ and not NodeMixin.get(wrapped, None):
                assert not self._is_builtin
                NodeMixin.wrap(
                    wrapped,
                    lambda: GetAttr.make(
                        NodeMixin.get(self),
                        name,
                        type=NodeMixin.get_wrapped_type(wrapped),
                    ),
                )
            return wrapped


class TracedModule(Module):
    """
    ``TracedModule`` is the Module created by tracing a normal module. It owns a
    ModuleNode (``m_node``) and interprets ``m_node.graph`` when it is called.

    param node: the ModuleNode that carries the recorded graph
    """

    m_node = None  # type: ModuleNode

    def __init__(self, node):
        super(TracedModule, self).__init__()
        self.m_node = node

    def forward(self, *inputs):
        """
        Interpret the recorded graph, passing this module as its first input.

        return: the single output when the graph has exactly one, otherwise
            the list of outputs
        """
        rst = self.m_node.graph.interpret(self, *inputs)
        if len(rst) == 1:
            rst = rst[0]
        return rst

    def __getstate__(self):
        """
        Return the state to pickle: the instance dict without any key that
        is also a name defined on the ``Module`` class.
        """
        # Work on a copy: popping from ``self.__dict__`` itself would strip
        # those attributes from the live object as a side effect of pickling.
        d = self.__dict__.copy()
        for k in Module.__dict__:
            d.pop(k, None)
        return d


def cpp_apply_module_trace(opdef, *args):
    """
    Hook handed to ``set_cpp_apply_module_trace`` by the package ``__init__``.

    It forwards the operator and its operands to
    ``Apply.apply_module_trace_hook`` and returns that call's result.

    param opdef: the operator being applied
    param args: the operands of the operator
    """
    return Apply.apply_module_trace_hook(opdef, *args)


def register_as_builtin(mod_cls: Type[Module]) -> None:
    """
    Registers class ``mod_cls`` (subclass of megengine.module.Module) as builtin module.

    param mod_cls: the Module class which will be treated as builtin module in tracing
    """
    # The registry's return value is dropped, so this function returns None
    # and cannot be used as a class decorator.
    module_tracer.register_as_builtin(mod_cls)


def _register_all_builtin_module():
    """
    Register every Module subclass exposed by ``megengine.module`` and its
    ``qat`` and ``quantized`` sub-packages as builtin; ``Sequential`` is
    skipped.
    """
    from inspect import getmembers, isclass

    for namespace in (M, M.qat, M.quantized):
        for _, member in getmembers(namespace, isclass):
            if not issubclass(member, M.Module) or member is M.Sequential:
                continue
            module_tracer.register_as_builtin(member)


def trace_module(mod: Module, *inputs: Tensor, **kwargs: Tensor) -> TracedModule:
    """
    Traces module ``mod`` and returns the corresponding TracedModule.

    param mod: the module that will be converted to a TracedModule
    param inputs: the positional arguments passed to the forward method of ``mod``
    param kwargs: the keyword arguments passed to the forward method of ``mod``
    """
    assert active_module_tracer() is None
    try:
        set_module_tracing()
        set_active_module_tracer(module_tracer())

        # Outermost scope: receives the top-level inputs and the single Call
        # of the top module.
        top_scope = InternalGraph()
        active_module_tracer().push_scope(top_scope)

        top_builder = TracedModuleBuilder(mod)
        top_node = Input.make("TopModule", ModuleNode)
        NodeMixin.wrap_safe(top_builder, top_node)

        for position, tensor in enumerate(inputs):
            arg_node = Input.make("arg_{}".format(position))
            NodeMixin.wrap_safe(tensor, arg_node)
        for key, tensor in kwargs.items():
            kwarg_node = Input.make("kwarg_{}".format(key))
            NodeMixin.wrap_safe(tensor, kwarg_node)

        top_builder(*inputs, **kwargs)
        active_module_tracer().pop_scope()

        return top_builder.build()
    finally:
        # Always leave tracing switched off and the active tracer cleared,
        # whether the trace succeeded or not.
        set_active_module_tracer(None)
        unset_module_tracing()

Loading…
Cancel
Save