Browse Source

refactor(mge/traced_module): refactor Node naming rule and merge GetAttr

GitOrigin-RevId: a43ad1273c
release-1.7
Megvii Engine Team 3 years ago
parent
commit
a6fe7f7ff2
6 changed files with 802 additions and 483 deletions
  1. +87
    -91
      imperative/python/megengine/traced_module/expr.py
  2. +1
    -2
      imperative/python/megengine/traced_module/module_tracer.py
  3. +76
    -25
      imperative/python/megengine/traced_module/node.py
  4. +346
    -359
      imperative/python/megengine/traced_module/traced_module.py
  5. +97
    -6
      imperative/python/test/unit/traced_module/test_modification.py
  6. +195
    -0
      imperative/python/test/unit/traced_module/test_qat_module.py

+ 87
- 91
imperative/python/megengine/traced_module/expr.py View File

@@ -11,7 +11,7 @@ import collections
import copy import copy
import inspect import inspect
import re import re
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional, Union


from ..core._imperative_rt import OpDef from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor from ..core._imperative_rt.core2 import Tensor as RawTensor
@@ -32,6 +32,43 @@ def rstrip(s: str, __chars: str):
return s return s




def get_suffix_name(prefix: str, name: str):
    r"""Return the part of ``name`` that follows ``prefix`` and a ``"."`` separator.

    Args:
        prefix: the leading dotted path, matched literally.
        name: the full dotted name to inspect.

    Returns:
        ``""`` when ``name`` equals ``prefix``; the text after ``prefix + "."``
        when ``name`` starts with it; ``None`` otherwise.
    """
    if prefix == name:
        return ""
    # ``prefix`` must be compared literally: without escaping, characters such
    # as ``.`` or ``[`` inside it would be interpreted as regex syntax and
    # could match (or fail to match) unrelated names.
    matched = re.match(r"^%s\.(.*)" % re.escape(prefix), name)
    if matched is None:
        return None
    return matched.group(1)


def is_call_module(expr):
    r"""Tell whether ``expr`` is a ``CallMethod`` invoking ``__call__`` on a ``ModuleNode``."""
    if not isinstance(expr, CallMethod):
        return False
    # The first input of a CallMethod is the object the method is called on.
    if not isinstance(expr.inputs[0], ModuleNode):
        return False
    return expr.method == "__call__"


def is_call_tensor_method(expr):
    r"""Tell whether ``expr`` is a ``CallMethod`` that is not a module call."""
    if not isinstance(expr, CallMethod):
        return False
    return not is_call_module(expr)


def is_call_function(expr):
    r"""Tell whether ``expr`` is an instance of ``CallFunction``."""
    matches = isinstance(expr, CallFunction)
    return matches


def is_constant(expr):
    r"""Tell whether ``expr`` is an instance of ``Constant``."""
    matches = isinstance(expr, Constant)
    return matches


def is_getattr(expr):
    r"""Tell whether ``expr`` is an instance of ``GetAttr``."""
    matches = isinstance(expr, GetAttr)
    return matches


def is_apply_def(expr):
    r"""Tell whether ``expr`` is an instance of ``Apply``."""
    matches = isinstance(expr, Apply)
    return matches


class Expr: class Expr:
r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``, r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``. ``GetAttr``, ``Input``, ``Constant``) on ``Node``.
@@ -76,50 +113,19 @@ class Expr:
self.const_val.append((idx, val)) self.const_val.append((idx, val))


def add_outputs(self, outputs): def add_outputs(self, outputs):
assert active_module_tracer() is not None
self.outputs = [] self.outputs = []
if outputs is not None:
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)

name = None
orig_name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
orig_name = self.inputs[0]._orig_name
assert isinstance(name, str), "The name of ({}) must be a str".format(
self.inputs[0]
)
assert isinstance(
orig_name, str
), "The orig_name of ({}) must be a str".format(self.inputs[0])
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
orig_name += "_out"
else:
strip_method = self.method.strip("_")
name = "%s_out" % strip_method
orig_name = name
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"

for i in outputs:
assert isinstance(i, RawTensor), "The output must be a Tensor"
o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(
expr=self,
name=o_name,
orig_name=orig_name if orig_name else o_name,
)
)

for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
if outputs is None:
return
current_graph = active_module_tracer().current_scope()
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
for i in outputs:
assert isinstance(i, RawTensor), "The output must be a Tensor"
node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
NodeMixin.wrap_safe(i, node)
self.outputs.append(node)
current_graph._namespace.auto_naming_for_outputs(self)


def unflatten_args(self, inputs): def unflatten_args(self, inputs):
if self.arg_def is not None: if self.arg_def is not None:
@@ -152,9 +158,7 @@ class Expr:
), "({}) must be generated before ({})".format(repl_node, self) ), "({}) must be generated before ({})".format(repl_node, self)
idx = self.inputs.index(node) idx = self.inputs.index(node)
self.inputs[idx] = repl_node self.inputs[idx] = repl_node
user_idx = node.users.index(self)
assert user_idx >= 0
node.users.pop(user_idx)
node.users.remove(self)
repl_node.users.append(self) repl_node.users.append(self)


@property @property
@@ -197,26 +201,23 @@ class Input(Expr):
r"""A fake Expr which is used to mark the input of graph.""" r"""A fake Expr which is used to mark the input of graph."""
name = None name = None


def __init__(self, name=None, type=None, orig_name=None):
def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
super().__init__() super().__init__()
assert type in [ModuleNode, TensorNode]
assert name and qualname
self.inputs = [] self.inputs = []
node_cls = type if type else Node node_cls = type if type else Node
if orig_name is None:
orig_name = name
self.outputs = [ self.outputs = [
node_cls(self, name=name, orig_name=orig_name),
node_cls(self, name=name, qualname=qualname),
] ]
self.name = name self.name = name


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
oup_node = expr.outputs[0]
name = (
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope()._add_input(oup_node)
out_node = expr.outputs[0]
active_module_tracer().current_scope()._add_input(out_node)
return expr.outputs[0] return expr.outputs[0]


def __repr__(self): def __repr__(self):
@@ -230,34 +231,41 @@ class GetAttr(Expr):
name = None name = None
r"""name: the qualified name of the attribute to be retrieved.""" r"""name: the qualified name of the attribute to be retrieved."""


def __init__(self, module, name, type=None, orig_name=None):
def __init__(
self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
):
super().__init__() super().__init__()
assert isinstance(module, ModuleNode) assert isinstance(module, ModuleNode)
assert type in [TensorNode, ModuleNode]
self.inputs = [ self.inputs = [
module, module,
] ]
module.users.append(self) module.users.append(self)
self.name = name
node_cls = type if type else Node
self.name = attr_name
self.outputs = [ self.outputs = [
node_cls(self, name=name, orig_name=orig_name),
type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)),
] ]


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
current_graph = active_module_tracer().current_scope()
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
module = expr.inputs[0]
oup_name = expr.name
while module._name != "self":
oup_name = module._name + "_" + oup_name
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope()._insert(expr)
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._insert(expr)
return expr.outputs[0] return expr.outputs[0]


def interpret(self, *inputs): def interpret(self, *inputs):
return (getattr(inputs[0], self.name),)
mod = inputs[0]
module_path, _, name = self.name.rpartition(".")
if module_path == "":
return (getattr(mod, name),)
module_names = module_path.split(".")
for item in module_names:
mod = getattr(mod, item)
if not isinstance(mod, Module):
raise AttributeError("`{}` is not an Module".format(item))
return (getattr(mod, name),)


def __repr__(self): def __repr__(self):
out_type = "Tensor" out_type = "Tensor"
@@ -297,6 +305,7 @@ class CallMethod(Expr):


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope()._insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr return expr
@@ -362,6 +371,7 @@ class Apply(Expr):


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope()._insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr return expr
@@ -435,6 +445,7 @@ class CallFunction(Expr):


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
active_module_tracer().current_scope()._insert(expr) active_module_tracer().current_scope()._insert(expr)
return expr return expr
@@ -474,7 +485,7 @@ class Constant(Expr):
# TODO: constant cache to reduce the size of dumped model # TODO: constant cache to reduce the size of dumped model
_constant_cache = {} _constant_cache = {}


def __init__(self, c, name=None):
def __init__(self, c, name: str = "", qualname: str = ""):
super().__init__() super().__init__()
assert isinstance(c, (RawTensor, Module)) assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module): if isinstance(c, Module):
@@ -484,31 +495,16 @@ class Constant(Expr):
self.inputs = [] self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c) node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [ self.outputs = [
node_cls(self, name=name, orig_name=name),
node_cls(self, name=name, qualname=qualname),
] ]
self.outputs[0]._name = name if name else "const_" + str(self._id)


@classmethod @classmethod
def make(cls, *args, **kwargs): def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs) expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
full_name = name
if (
isinstance(expr.value, RawTensor)
and id(expr.value) in active_module_tracer().id2name
):
full_name = active_module_tracer().id2name[id(expr.value)]
scope_name = active_module_tracer().current_scope()._module_name
if full_name and scope_name:
full_name = ("self." + full_name)[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
name = active_module_tracer().current_scope()._create_unique_name(full_name)
expr.outputs[0]._name = name
expr.outputs[0]._orig_name = full_name
active_module_tracer().current_scope()._insert(expr)
current_graph = active_module_tracer().current_scope()
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._insert(expr)
return expr.outputs[0] return expr.outputs[0]


def interpret(self, *inputs): def interpret(self, *inputs):


+ 1
- 2
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -128,10 +128,9 @@ class module_tracer:


_active_scopes = None _active_scopes = None


def __init__(self, wrap_fn, id2name):
def __init__(self, wrap_fn):
self._active_scopes = [] self._active_scopes = []
self.patcher = Patcher(wrap_fn) self.patcher = Patcher(wrap_fn)
self.id2name = id2name


@classmethod @classmethod
def register_as_builtin(cls, mod): def register_as_builtin(cls, mod):


+ 76
- 25
imperative/python/megengine/traced_module/node.py View File

@@ -29,17 +29,15 @@ class Node:
__total_id = 0 # type: int __total_id = 0 # type: int
_id = None # type: int _id = None # type: int
_top_graph = None # type: weakref.ReferenceType _top_graph = None # type: weakref.ReferenceType
_name = None # type: str
_orig_name = None # type: str
_format_spec = "" # type: str _format_spec = "" # type: str


def __init__(self, expr, name: str, orig_name: str):
def __init__(self, expr, name: str, qualname: str):
self.expr = expr self.expr = expr
self.users = [] # List[Expr] self.users = [] # List[Expr]
self._id = Node.__total_id self._id = Node.__total_id
Node.__total_id += 1 Node.__total_id += 1
self._name = name self._name = name
self._orig_name = orig_name
self._qualname = qualname
self.actual_node = [] # type: List[Node] self.actual_node = [] # type: List[Node]


def __repr__(self): def __repr__(self):
@@ -54,21 +52,10 @@ class Node:
name = "" name = ""
if format_spec in ["i", "p", "ip", "pi"]: if format_spec in ["i", "p", "ip", "pi"]:
if "p" in format_spec: if "p" in format_spec:
graph = self.top_graph
prefix_name = ""
if graph is not None:
prefix_name = graph._name
if graph._prefix_name:
prefix_name = "{}_{}".format(
graph._prefix_name, prefix_name.lstrip("_")
)
if name:
name = "_" + name.lstrip("_")
name = "{}{}".format(prefix_name, name)
prefix_name = self.top_graph._name
name = "{}_{}".format(prefix_name, name)
if "i" in format_spec: if "i" in format_spec:
if name:
name = "_" + name.lstrip("_")
name = "%{}{}".format(self._id, name)
name = "%{}_{}".format(self._id, name)
return name return name
else: else:
return name if name else ("%d" % self._id) return name if name else ("%d" % self._id)
@@ -80,15 +67,62 @@ class Node:


@name.setter @name.setter
def name(self, new_name: str): def name(self, new_name: str):
r"""Set a new name to this Node."""
graph = self.top_graph graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None." assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._used_names, (
assert new_name not in graph._namespace.used_names, (
"The name(%s) is already in use. Please try a different one again." "The name(%s) is already in use. Please try a different one again."
% (new_name) % (new_name)
) )
new_name = graph._create_unique_name(new_name)
new_name = graph._namespace.create_unique_name(new_name)
self._name = new_name self._name = new_name
self._orig_name = new_name

@property
def qualname(self):
    r"""Get the ``qualname`` of this Node (read-only; returns ``self._qualname``).

    As the example below illustrates, the ``qualname`` is a dotted path whose
    middle part can be used to look up the corresponding submodule on either
    the original Module or the traced Module.

    Example:
        .. code-block::

            import megengine.module as M
            import megengine.functional as F
            import megengine.traced_module as tm
            import megengine as mge

            class block(M.Module):
                def __init__(self):
                    super().__init__()
                    self.param = mge.Tensor([1.])
                    self.relu = M.ReLU()

                def forward(self, x):
                    x = x + self.param
                    return self.relu(F.relu(x))

            class module(M.Module):
                def __init__(self):
                    super().__init__()
                    self.block = block()

                def forward(self, x):
                    x = self.block(x)
                    return x

            net = module()
            traced_net = tm.trace_module(net, mge.Tensor([0.]))
            traced_net = traced_net.flatten()
            out_node = traced_net.graph.outputs[0]

            # full qualname of the output node: "module.block.relu.[out]"
            qualname = out_node.qualname
            # drop the leading graph name and the trailing "[out]": "block.relu"
            qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

            assert qualname in list(map(lambda x: x[0], net.named_modules()))
            assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
    """
    return self._qualname


@property @property
def top_graph(self): def top_graph(self):
@@ -120,8 +154,8 @@ class ModuleNode(Node):
r"""The type of the Module correspending to the ModuleNode.""" r"""The type of the Module correspending to the ModuleNode."""
_owner = None # type: weakref.ReferenceType _owner = None # type: weakref.ReferenceType


def __init__(self, expr, name: str = None, orig_name: str = None):
super().__init__(expr, name, orig_name)
def __init__(self, expr, name: str = None, qualname: str = None):
super().__init__(expr, name, qualname)


def __getstate__(self): def __getstate__(self):
return { return {
@@ -129,10 +163,15 @@ class ModuleNode(Node):
"users": self.users, "users": self.users,
"_id": self._id, "_id": self._id,
"_name": self._name, "_name": self._name,
"_orig_name": self._orig_name,
"_qualname": self._qualname,
"module_type": self.module_type, "module_type": self.module_type,
} }


def __setstate__(self, state):
    r"""Restore the node from ``state``, accepting the legacy ``_orig_name`` key.

    When ``state`` carries ``_orig_name``, that entry is renamed to
    ``_qualname`` (``state`` is modified in place) before everything is merged
    into ``__dict__``.
    """
    legacy_key = "_orig_name"
    if legacy_key in state:
        state["_qualname"] = state.pop(legacy_key)
    self.__dict__.update(state)

@property @property
def owner(self): def owner(self):
r"""Get the ``Module`` corresponding to this ``ModuleNode``. r"""Get the ``Module`` corresponding to this ``ModuleNode``.
@@ -161,9 +200,21 @@ class TensorNode(Node):
"_dtype": self._dtype, "_dtype": self._dtype,
"_device": self._device, "_device": self._device,
"_name": self._name, "_name": self._name,
"_orig_name": self._orig_name,
"_qualname": self._qualname,
} }


def __setstate__(self, state):
    r"""Restore the node from ``state``, converting a legacy ``_orig_name`` entry.

    A legacy ``_orig_name`` is removed from ``state`` and rewritten into
    ``_qualname``: its last dotted component is wrapped in ``[...]`` unless the
    class of ``state["expr"]`` is named ``GetAttr``, and any leading dotted
    path is kept in front of it.
    """
    if "_orig_name" in state:
        legacy_name = state.pop("_orig_name")
        module_path, dot, leaf = legacy_name.rpartition(".")
        produced_by_getattr = state["expr"].__class__.__name__ in ["GetAttr"]
        if not produced_by_getattr:
            leaf = "[{}]".format(leaf)
        # ``dot`` is empty when the legacy name had no module path at all.
        state["_qualname"] = "{}.{}".format(module_path, leaf) if dot else leaf
    self.__dict__.update(state)

@property @property
def shape(self): def shape(self):
r"""Get the shape of this Node.""" r"""Get the shape of this Node."""


+ 346
- 359
imperative/python/megengine/traced_module/traced_module.py
File diff suppressed because it is too large
View File


+ 97
- 6
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle import pickle
from itertools import chain


import numpy as np import numpy as np


@@ -13,8 +14,8 @@ import megengine.functional as F
import megengine.module as M import megengine.module as M
from megengine.module.identity import Identity from megengine.module.identity import Identity
from megengine.traced_module import trace_module from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallFunction, Expr, GetAttr
from megengine.traced_module.node import Node
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
from megengine.traced_module.node import ModuleNode, Node




class IdentityMod(M.Module): class IdentityMod(M.Module):
@@ -85,6 +86,34 @@ def test_search():
relu_expr = graph.get_function_by_type(F.relu).as_unique() relu_expr = graph.get_function_by_type(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu


conv_node = graph.get_module_by_type(M.Conv2d).as_unique()
assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d

add_expr = graph.get_method_by_type("__add__").as_unique()
assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__"

conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique()
assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d


def test_producer_and_users():
    """Check that producer/user links of every node stay inside the owning graph."""
    traced_module, *_rest = _init_module()

    def _check(exprs):
        for expr in exprs:
            for node in chain(expr.inputs, expr.outputs):
                producer = node.expr
                # Nodes produced by an ``Input`` expr are exempt from the
                # "producer is listed in exprs" check.
                assert isinstance(producer, Input) or producer in exprs
                for user in node.users:
                    assert user in exprs
                    assert node in user.inputs

    for mod in traced_module.modules():
        if hasattr(mod, "argdef_graph_map"):
            for graph in mod.argdef_graph_map.values():
                _check(graph._exprs)



def test_insert(): def test_insert():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
@@ -97,6 +126,54 @@ def test_insert():
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6) np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)




def test_insert_module():
    """Insert a call to a newly attached submodule and check the rewritten graph."""

    class Neg(M.Module):
        def forward(self, x):
            return F.neg(x)

    traced_module, x, expect = _init_block()
    graph = traced_module.graph
    relu_expr = graph.get_function_by_type(F.relu).as_unique()
    relu_out = relu_expr.outputs[0]
    self_node = graph.inputs[0]
    traced_module.neg = Neg()
    with graph.insert_exprs():
        neg_out = self_node.neg(relu_out)
    graph.replace_node({relu_out: neg_out})
    graph.compile()
    np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
    neg_graph = traced_module.neg.graph
    assert neg_graph is not None
    assert len(traced_module.neg.graph._exprs) == 1


def test_add_input_and_output():
    """Add an extra graph input/output, then reset outputs to a nested structure."""
    traced_module, x, y = _init_module()

    data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data")
    traced_module.graph.add_output_node(data_node)

    assert data_node.name == "data"
    assert traced_module.graph.inputs[-1] == data_node
    assert len(traced_module.graph.inputs) == 3
    assert len(traced_module.graph.outputs) == 2

    # The new input is forwarded unchanged as the second output.
    first, second = traced_module(x, x)
    np.testing.assert_equal(first.numpy(), y.numpy())
    np.testing.assert_equal(second.numpy(), x.numpy())

    first, second = traced_module(x, y)
    np.testing.assert_equal(second.numpy(), y.numpy())

    new_outputs = (
        {"orig_out": traced_module.graph.outputs[0]},
        traced_module.graph.outputs[1],
    )
    traced_module.graph.reset_outputs(new_outputs)

    out = traced_module(x, x)
    assert isinstance(out, tuple)
    assert isinstance(out[0], dict)
    np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy())
    np.testing.assert_equal(out[1].numpy(), x.numpy())


def test_delete(): def test_delete():
traced_module, x, expect = _init_block() traced_module, x, expect = _init_block()
graph = traced_module.graph graph = traced_module.graph
@@ -117,8 +194,10 @@ def test_delete():
def test_flatten(): def test_flatten():
traced_module, x, expect = _init_module() traced_module, x, expect = _init_module()
traced_module = traced_module.flatten() traced_module = traced_module.flatten()
traced_module.graph.compile()
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())

traced_module = traced_module.flatten()
assert len(traced_module.graph._exprs) == 12 assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy()) np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())


@@ -128,7 +207,7 @@ def test_id_and_name():
_total_ids = traced_module.graph._total_ids _total_ids = traced_module.graph._total_ids
node_ids = [n._id for n in traced_module.graph.nodes().as_list()] node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
assert len(set(node_ids)) == len(node_ids) assert len(set(node_ids)) == len(node_ids)
assert max(node_ids) + 1 == len(node_ids)
assert max(node_ids) + 1 == _total_ids[0]


expr_ids = [n._id for n in traced_module.graph.exprs().as_list()] expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
assert len(set(expr_ids)) == len(expr_ids) assert len(set(expr_ids)) == len(expr_ids)
@@ -177,7 +256,7 @@ def test_id_and_name():
_check_name(flattened_module) _check_name(flattened_module)




def test_set_name():
def test_set_node_name():
traced_module, x, expect = _init_module() traced_module, x, expect = _init_module()
graph = traced_module.graph graph = traced_module.graph
output_node = graph.outputs[0] output_node = graph.outputs[0]
@@ -190,6 +269,18 @@ def test_set_name():
np.testing.assert_equal(str(graph.outputs[0]), "output") np.testing.assert_equal(str(graph.outputs[0]), "output")




def test_set_graph_name():
    """After renaming the graph, a node is found under ``<graph name>_<node name>``."""
    traced_module, _x, _expect = _init_module()
    graph = traced_module.graph
    output_node = graph.outputs[0]
    node_name = output_node.name

    graph.name = "Top"
    prefixed_name = "{}_{}".format("Top", node_name)
    found = graph.get_node_by_name(prefixed_name).as_unique()
    assert found is output_node


def test_extra_block(): def test_extra_block():
class PostProcess(M.Module): class PostProcess(M.Module):
def forward(self, x): def forward(self, x):


+ 195
- 0
imperative/python/test/unit/traced_module/test_qat_module.py View File

@@ -0,0 +1,195 @@
import io
from functools import partial
from itertools import chain
from typing import Callable

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
from megengine.traced_module import TracedModule, trace_module


def get_subattr(self: M.Module, name: str):
    """Resolve the dotted attribute path ``name`` starting from ``self``.

    An empty ``name`` yields ``self``. Every component before the last one must
    resolve to an ``M.Module``; otherwise ``AttributeError`` is raised.
    """
    if name == "":
        return self
    module_path, _, attr_name = name.rpartition(".")
    owner = self
    if module_path:
        for part in module_path.split("."):
            owner = getattr(owner, part)
            if not isinstance(owner, M.Module):
                raise AttributeError("`{}` is not an Module".format(part))
    return getattr(owner, attr_name)


class Myblcok(M.Module):
    """Test block: a ConvBnRelu stem feeding two ConvBn branches joined by FUSE_ADD_RELU."""

    def __init__(self):
        super().__init__()
        self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1)
        self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
        self.add = M.Elemwise("FUSE_ADD_RELU")

    def forward(self, x):
        stem = self.conv0(x)
        branch0 = self.conv1(stem)
        branch1 = self.conv2(stem)
        return self.add(branch0, branch1)


class MyModule(M.Module):
    """Test network made of two ``Myblcok`` blocks applied in sequence."""

    def __init__(self):
        super().__init__()
        self.block0 = Myblcok()
        self.block1 = Myblcok()

    def forward(self, x):
        hidden = self.block0(x)
        return self.block1(hidden)


class MyMinMaxObserver(Q.MinMaxObserver):
    """User-defined subclass of ``Q.MinMaxObserver`` that adds nothing of its own."""


class MyTQT(Q.TQT):
    """User-defined subclass of ``Q.TQT`` that adds nothing of its own."""


def get_lsq_config(lsq_cls):
    """Build a ``Q.QConfig`` that uses ``lsq_cls`` for fake-quant and no observers."""
    weight_fake_quant = partial(lsq_cls, dtype="qint8_narrow")
    act_fake_quant = partial(lsq_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=weight_fake_quant,
        act_fake_quant=act_fake_quant,
    )


def get_observer_config(observer_cls):
    """Build a ``Q.QConfig`` that uses ``observer_cls`` for observers and no fake-quant."""
    weight_observer = partial(observer_cls, dtype="qint8_narrow")
    act_observer = partial(observer_cls, dtype="qint8")
    return Q.QConfig(
        weight_observer=weight_observer,
        act_observer=act_observer,
        weight_fake_quant=None,
        act_fake_quant=None,
    )


def get_qparams(mod: QATModule):
    """Return ``(weight_qparams, act_qparams)`` of ``mod``.

    For each kind, the fake-quant's qparams take precedence over the
    observer's; the value is ``None`` when neither is available.
    """

    def _pick(observer, fake_quant):
        qparams = None
        if observer is not None:
            qparams = observer.get_qparams()
        if fake_quant:
            qparams = fake_quant.get_qparams()
        return qparams

    act_qparams = _pick(mod.act_observer, mod.act_fake_quant)
    weight_qparams = _pick(mod.weight_observer, mod.weight_fake_quant)
    return weight_qparams, act_qparams


def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
    """Assert that two qparams agree on dtype_meta, mode, scale and zero_point."""
    for field in ("dtype_meta", "mode"):
        assert getattr(qparmsa, field) == getattr(qparmsb, field)
    np.testing.assert_equal(qparmsa.scale.numpy(), qparmsb.scale.numpy())
    if qparmsa.zero_point is None:
        return
    np.testing.assert_equal(qparmsa.zero_point.numpy(), qparmsb.zero_point.numpy())


def build_observered_net(net: M.Module, observer_cls):
    """Convert ``net`` to QAT with ``observer_cls`` and run five random batches through it."""
    qconfig = get_observer_config(observer_cls)
    qat_net = Q.quantize_qat(net, qconfig=qconfig)
    Q.enable_observer(qat_net)
    input_shape = (5, 3, 32, 32)
    for _ in range(5):
        qat_net(Tensor(np.random.random(size=input_shape)))
    Q.disable_observer(qat_net)
    return qat_net


def build_fakequanted_net(net: QATModule, fakequant_cls):
    """Reset the qconfig of ``net`` to a fake-quant config built from ``fakequant_cls``."""
    return Q.reset_qconfig(net, get_lsq_config(fakequant_cls))


def test_trace_qat():
    """Tracing a QAT network must preserve the qparams of every QAT submodule."""

    def _check_qat_module(qat_net: QATModule):
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(qat_net, inp)

        for name, qat_module in qat_net.named_modules():
            if not isinstance(qat_module, QATModule):
                continue
            traced_qat_module = get_subattr(traced_net, name)
            expected_pair = get_qparams(qat_module)
            traced_pair = get_qparams(traced_qat_module)
            # Pairs are ordered (weight, act); missing qparams are skipped.
            for expected, traced in zip(expected_pair, traced_pair):
                if expected:
                    check_qparams(expected, traced)

    for observer_cls in (Q.MinMaxObserver, MyMinMaxObserver):
        _check_qat_module(build_observered_net(MyModule(), observer_cls))
    for fakequant_cls in (Q.TQT, MyTQT):
        observed_net = build_observered_net(MyModule(), Q.MinMaxObserver)
        _check_qat_module(build_fakequanted_net(observed_net, fakequant_cls))


def test_load_param():
    """A saved state dict must load into both the traced and the flattened traced module."""

    def _check_param(moda: M.Module, modb: M.Module):
        named_tensors = chain(moda.named_parameters(), moda.named_buffers())
        for name, attr in named_tensors:
            traced_attr = get_subattr(modb, name)
            np.testing.assert_equal(attr.numpy(), traced_attr.numpy())

    def _check_module(build_func: Callable):
        net = build_func()
        buffer = io.BytesIO()
        mge.save(net.state_dict(), buffer)

        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        for flatten in (False, True):
            traced_net = trace_module(build_func(), inp)
            if flatten:
                traced_net = traced_net.flatten()
            # Rewind so the same serialized state dict is read each time.
            buffer.seek(0)
            traced_net.load_state_dict(mge.load(buffer))
            _check_param(net, traced_net)

    _check_module(MyModule)
    _check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver))


def test_qualname():
    """Every node qualname must map to an attribute on both the original and traced net."""

    def _attr_path(qualname, base_qualname):
        # Strip the graph's own qualname plus the separating dot.
        path = qualname[len(base_qualname) + 1 :]
        if path.endswith("]"):
            path = path.rsplit(".", 1)[0]
        if path.startswith("["):
            path = ""
        return path

    def _check_qualname(net):
        inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
        traced_net = trace_module(net, inp)
        base_qualname = traced_net.graph.qualname
        for node in traced_net.graph.nodes():
            path = _attr_path(node.qualname, base_qualname)
            traced_attr = get_subattr(traced_net, path)
            orig_attr = get_subattr(net, path)
            assert traced_attr is not None
            assert orig_attr is not None

    _check_qualname(MyModule())
    _check_qualname(build_observered_net(MyModule(), Q.MinMaxObserver))

Loading…
Cancel
Save