Browse Source

refactor(mge/traced_module): refactor Node naming rule and merge GetAttr

GitOrigin-RevId: a43ad1273c
release-1.7
Megvii Engine Team 3 years ago
parent
commit
a6fe7f7ff2
6 changed files with 802 additions and 483 deletions
  1. +87
    -91
      imperative/python/megengine/traced_module/expr.py
  2. +1
    -2
      imperative/python/megengine/traced_module/module_tracer.py
  3. +76
    -25
      imperative/python/megengine/traced_module/node.py
  4. +346
    -359
      imperative/python/megengine/traced_module/traced_module.py
  5. +97
    -6
      imperative/python/test/unit/traced_module/test_modification.py
  6. +195
    -0
      imperative/python/test/unit/traced_module/test_qat_module.py

+ 87
- 91
imperative/python/megengine/traced_module/expr.py View File

@@ -11,7 +11,7 @@ import collections
import copy
import inspect
import re
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional, Union

from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor
@@ -32,6 +32,43 @@ def rstrip(s: str, __chars: str):
return s


def get_suffix_name(prefix: str, name: str):
if prefix == name:
return ""
matchd = re.compile("^%s\.(.*)" % prefix).match(name)
if matchd is None:
return None
return matchd.group(1)


def is_call_module(expr):
return (
isinstance(expr, CallMethod)
and isinstance(expr.inputs[0], ModuleNode)
and expr.method == "__call__"
)


def is_call_tensor_method(expr):
return isinstance(expr, CallMethod) and not is_call_module(expr)


def is_call_function(expr):
return isinstance(expr, CallFunction)


def is_constant(expr):
return isinstance(expr, Constant)


def is_getattr(expr):
return isinstance(expr, GetAttr)


def is_apply_def(expr):
return isinstance(expr, Apply)


class Expr:
r"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
@@ -76,50 +113,19 @@ class Expr:
self.const_val.append((idx, val))

def add_outputs(self, outputs):
assert active_module_tracer() is not None
self.outputs = []
if outputs is not None:
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)

name = None
orig_name = None
if isinstance(self, CallMethod):
name = self.inputs[0]._name
orig_name = self.inputs[0]._orig_name
assert isinstance(name, str), "The name of ({}) must be a str".format(
self.inputs[0]
)
assert isinstance(
orig_name, str
), "The orig_name of ({}) must be a str".format(self.inputs[0])
name = rstrip(name, "_out")
if self.method == "__call__":
name += "_out"
orig_name += "_out"
else:
strip_method = self.method.strip("_")
name = "%s_out" % strip_method
orig_name = name
elif isinstance(self, CallFunction):
name = self.func.__name__ + "_out"
elif isinstance(self, Apply):
name = str(self.opdef).lower() + "_out"

for i in outputs:
assert isinstance(i, RawTensor), "The output must be a Tensor"
o_name = (
active_module_tracer().current_scope()._create_unique_name(name)
)
self.outputs.append(
NodeMixin.get_wrapped_type(i)(
expr=self,
name=o_name,
orig_name=orig_name if orig_name else o_name,
)
)

for i, node in zip(outputs, self.outputs,):
NodeMixin.wrap_safe(i, node)
if outputs is None:
return
current_graph = active_module_tracer().current_scope()
if not isinstance(outputs, collections.Sequence):
outputs = (outputs,)
for i in outputs:
assert isinstance(i, RawTensor), "The output must be a Tensor"
node = NodeMixin.get_wrapped_type(i)(expr=self, name="", qualname="",)
NodeMixin.wrap_safe(i, node)
self.outputs.append(node)
current_graph._namespace.auto_naming_for_outputs(self)

def unflatten_args(self, inputs):
if self.arg_def is not None:
@@ -152,9 +158,7 @@ class Expr:
), "({}) must be generated before ({})".format(repl_node, self)
idx = self.inputs.index(node)
self.inputs[idx] = repl_node
user_idx = node.users.index(self)
assert user_idx >= 0
node.users.pop(user_idx)
node.users.remove(self)
repl_node.users.append(self)

@property
@@ -197,26 +201,23 @@ class Input(Expr):
r"""A fake Expr which is used to mark the input of graph."""
name = None

def __init__(self, name=None, type=None, orig_name=None):
def __init__(self, type: List[Node], name: str = "args", qualname: str = ""):
super().__init__()
assert type in [ModuleNode, TensorNode]
assert name and qualname
self.inputs = []
node_cls = type if type else Node
if orig_name is None:
orig_name = name
self.outputs = [
node_cls(self, name=name, orig_name=orig_name),
node_cls(self, name=name, qualname=qualname),
]
self.name = name

@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs)
oup_node = expr.outputs[0]
name = (
active_module_tracer().current_scope()._create_unique_name(oup_node._name)
)
oup_node._name = name
active_module_tracer().current_scope()._add_input(oup_node)
out_node = expr.outputs[0]
active_module_tracer().current_scope()._add_input(out_node)
return expr.outputs[0]

def __repr__(self):
@@ -230,34 +231,41 @@ class GetAttr(Expr):
name = None
r"""name: the qualified name of the attribute to be retrieved."""

def __init__(self, module, name, type=None, orig_name=None):
def __init__(
self, module: ModuleNode, type: Union[Node], attr_name: str, name: str = "",
):
super().__init__()
assert isinstance(module, ModuleNode)
assert type in [TensorNode, ModuleNode]
self.inputs = [
module,
]
module.users.append(self)
self.name = name
node_cls = type if type else Node
self.name = attr_name
self.outputs = [
node_cls(self, name=name, orig_name=orig_name),
type(self, name=name, qualname="{}.{}".format(module.qualname, attr_name)),
]

@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
current_graph = active_module_tracer().current_scope()
expr = cls(*args, **kwargs)
module = expr.inputs[0]
oup_name = expr.name
while module._name != "self":
oup_name = module._name + "_" + oup_name
module = module.expr.inputs[0]
oup_name = active_module_tracer().current_scope()._create_unique_name(oup_name)
expr.outputs[0]._name = oup_name
active_module_tracer().current_scope()._insert(expr)
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._insert(expr)
return expr.outputs[0]

def interpret(self, *inputs):
return (getattr(inputs[0], self.name),)
mod = inputs[0]
module_path, _, name = self.name.rpartition(".")
if module_path == "":
return (getattr(mod, name),)
module_names = module_path.split(".")
for item in module_names:
mod = getattr(mod, item)
if not isinstance(mod, Module):
raise AttributeError("`{}` is not an Module".format(item))
return (getattr(mod, name),)

def __repr__(self):
out_type = "Tensor"
@@ -297,6 +305,7 @@ class CallMethod(Expr):

@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs)
active_module_tracer().current_scope()._insert(expr)
return expr
@@ -362,6 +371,7 @@ class Apply(Expr):

@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs)
active_module_tracer().current_scope()._insert(expr)
return expr
@@ -435,6 +445,7 @@ class CallFunction(Expr):

@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs)
active_module_tracer().current_scope()._insert(expr)
return expr
@@ -474,7 +485,7 @@ class Constant(Expr):
# TODO: constant cache to reduce the size of dumped model
_constant_cache = {}

def __init__(self, c, name=None):
def __init__(self, c, name: str = "", qualname: str = ""):
super().__init__()
assert isinstance(c, (RawTensor, Module))
if isinstance(c, Module):
@@ -484,31 +495,16 @@ class Constant(Expr):
self.inputs = []
node_cls = NodeMixin.get_wrapped_type(c)
self.outputs = [
node_cls(self, name=name, orig_name=name),
node_cls(self, name=name, qualname=qualname),
]
self.outputs[0]._name = name if name else "const_" + str(self._id)

@classmethod
def make(cls, *args, **kwargs):
assert active_module_tracer() is not None
expr = cls(*args, **kwargs)
name = "const_module" if isinstance(expr.value, Module) else "const_tensor"
full_name = name
if (
isinstance(expr.value, RawTensor)
and id(expr.value) in active_module_tracer().id2name
):
full_name = active_module_tracer().id2name[id(expr.value)]
scope_name = active_module_tracer().current_scope()._module_name
if full_name and scope_name:
full_name = ("self." + full_name)[len(scope_name) + 1 :]
else:
full_name = name
else:
full_name = name
name = active_module_tracer().current_scope()._create_unique_name(full_name)
expr.outputs[0]._name = name
expr.outputs[0]._orig_name = full_name
active_module_tracer().current_scope()._insert(expr)
current_graph = active_module_tracer().current_scope()
current_graph._namespace.auto_naming_for_outputs(expr)
current_graph._insert(expr)
return expr.outputs[0]

def interpret(self, *inputs):


+ 1
- 2
imperative/python/megengine/traced_module/module_tracer.py View File

@@ -128,10 +128,9 @@ class module_tracer:

_active_scopes = None

def __init__(self, wrap_fn, id2name):
def __init__(self, wrap_fn):
self._active_scopes = []
self.patcher = Patcher(wrap_fn)
self.id2name = id2name

@classmethod
def register_as_builtin(cls, mod):


+ 76
- 25
imperative/python/megengine/traced_module/node.py View File

@@ -29,17 +29,15 @@ class Node:
__total_id = 0 # type: int
_id = None # type: int
_top_graph = None # type: weakref.ReferenceType
_name = None # type: str
_orig_name = None # type: str
_format_spec = "" # type: str

def __init__(self, expr, name: str, orig_name: str):
def __init__(self, expr, name: str, qualname: str):
self.expr = expr
self.users = [] # List[Expr]
self._id = Node.__total_id
Node.__total_id += 1
self._name = name
self._orig_name = orig_name
self._qualname = qualname
self.actual_node = [] # type: List[Node]

def __repr__(self):
@@ -54,21 +52,10 @@ class Node:
name = ""
if format_spec in ["i", "p", "ip", "pi"]:
if "p" in format_spec:
graph = self.top_graph
prefix_name = ""
if graph is not None:
prefix_name = graph._name
if graph._prefix_name:
prefix_name = "{}_{}".format(
graph._prefix_name, prefix_name.lstrip("_")
)
if name:
name = "_" + name.lstrip("_")
name = "{}{}".format(prefix_name, name)
prefix_name = self.top_graph._name
name = "{}_{}".format(prefix_name, name)
if "i" in format_spec:
if name:
name = "_" + name.lstrip("_")
name = "%{}{}".format(self._id, name)
name = "%{}_{}".format(self._id, name)
return name
else:
return name if name else ("%d" % self._id)
@@ -80,15 +67,62 @@ class Node:

@name.setter
def name(self, new_name: str):
r"""Set a new name to this Node."""
graph = self.top_graph
assert graph is not None, "The parent graph of this Node cannot be None."
assert new_name not in graph._used_names, (
assert new_name not in graph._namespace.used_names, (
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
new_name = graph._create_unique_name(new_name)
new_name = graph._namespace.create_unique_name(new_name)
self._name = new_name
self._orig_name = new_name

@property
def qualname(self):
r"""Get the `qualname` of this Node. The `qualname` can be used to get the
submodule from the traced Module or Module.

Example:
.. code-block::

import megengine.module as M
import megengine.functional as F
import megengine.traced_module as tm
import megengine as mge

class block(M.Module):
def __init__(self):
super().__init__()
self.param = mge.Tensor([1.])
self.relu = M.ReLU()

def forward(self, x):
x = x + self.param
return self.relu(F.relu(x))

class module(M.Module):
def __init__(self):
super().__init__()
self.block = block()

def forward(self, x):
x = self.block(x)
return x

net = module()
traced_net = tm.trace_module(net, mge.Tensor([0.]))
traced_net = traced_net.flatten()
out_node = traced_net.graph.outputs[0]

# qualname : "module.block.relu.[out]"
qualname = out_node.qualname
# qualname : "block.relu"
qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]

assert qualname in list(map(lambda x: x[0], net.named_modules()))
assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
"""
return self._qualname

@property
def top_graph(self):
@@ -120,8 +154,8 @@ class ModuleNode(Node):
r"""The type of the Module correspending to the ModuleNode."""
_owner = None # type: weakref.ReferenceType

def __init__(self, expr, name: str = None, orig_name: str = None):
super().__init__(expr, name, orig_name)
def __init__(self, expr, name: str = None, qualname: str = None):
super().__init__(expr, name, qualname)

def __getstate__(self):
return {
@@ -129,10 +163,15 @@ class ModuleNode(Node):
"users": self.users,
"_id": self._id,
"_name": self._name,
"_orig_name": self._orig_name,
"_qualname": self._qualname,
"module_type": self.module_type,
}

def __setstate__(self, state):
if "_orig_name" in state:
state["_qualname"] = state.pop("_orig_name")
self.__dict__.update(state)

@property
def owner(self):
r"""Get the ``Module`` corresponding to this ``ModuleNode``.
@@ -161,9 +200,21 @@ class TensorNode(Node):
"_dtype": self._dtype,
"_device": self._device,
"_name": self._name,
"_orig_name": self._orig_name,
"_qualname": self._qualname,
}

def __setstate__(self, state):
if "_orig_name" in state:
qualname = state.pop("_orig_name")
modulepath, comma, qualname = qualname.rpartition(".")
expr_name = state["expr"].__class__.__name__
if expr_name not in ["GetAttr"]:
qualname = "[{}]".format(qualname)
if comma:
qualname = "{}.{}".format(modulepath, qualname)
state["_qualname"] = qualname
self.__dict__.update(state)

@property
def shape(self):
r"""Get the shape of this Node."""


+ 346
- 359
imperative/python/megengine/traced_module/traced_module.py
File diff suppressed because it is too large
View File


+ 97
- 6
imperative/python/test/unit/traced_module/test_modification.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import pickle
from itertools import chain

import numpy as np

@@ -13,8 +14,8 @@ import megengine.functional as F
import megengine.module as M
from megengine.module.identity import Identity
from megengine.traced_module import trace_module
from megengine.traced_module.expr import CallFunction, Expr, GetAttr
from megengine.traced_module.node import Node
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input
from megengine.traced_module.node import ModuleNode, Node


class IdentityMod(M.Module):
@@ -85,6 +86,34 @@ def test_search():
relu_expr = graph.get_function_by_type(F.relu).as_unique()
assert isinstance(relu_expr, CallFunction) and relu_expr.func == F.relu

conv_node = graph.get_module_by_type(M.Conv2d).as_unique()
assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d

add_expr = graph.get_method_by_type("__add__").as_unique()
assert isinstance(add_expr, CallMethod) and add_expr.method == "__add__"

conv_node = graph.get_node_by_name("MyBlock_conv1").as_unique()
assert isinstance(conv_node, ModuleNode) and conv_node.module_type == M.Conv2d


def test_producer_and_users():
traced_module, *_ = _init_module()

def _check(exprs):
for expr in exprs:
for n in chain(expr.inputs, expr.outputs):
if not isinstance(n.expr, Input):
assert n.expr in exprs
for e in n.users:
assert e in exprs
assert n in e.inputs

for mod in traced_module.modules():
if not hasattr(mod, "argdef_graph_map"):
continue
for g in mod.argdef_graph_map.values():
_check(g._exprs)


def test_insert():
traced_module, x, expect = _init_block()
@@ -97,6 +126,54 @@ def test_insert():
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)


def test_insert_module():
class Neg(M.Module):
def forward(self, x):
return F.neg(x)

traced_module, x, expect = _init_block()
graph = traced_module.graph
relu_out = graph.get_function_by_type(F.relu).as_unique().outputs[0]
self = graph.inputs[0]
setattr(traced_module, "neg", Neg())
with graph.insert_exprs():
neg_out = self.neg(relu_out)
graph.replace_node({relu_out: neg_out})
graph.compile()
np.testing.assert_allclose(expect - 1, 1 - traced_module(x), atol=1e-6)
assert traced_module.neg.graph is not None
assert len(traced_module.neg.graph._exprs) == 1


def test_add_input_and_output():
traced_module, x, y = _init_module()

data_node = traced_module.graph.add_input_node(shape=(1, 3, 224, 224), name="data")
traced_module.graph.add_output_node(data_node)

assert data_node.name == "data"
assert traced_module.graph.inputs[-1] == data_node
assert len(traced_module.graph.inputs) == 3
assert len(traced_module.graph.outputs) == 2

y1, y2 = traced_module(x, x)
np.testing.assert_equal(y1.numpy(), y.numpy())
np.testing.assert_equal(y2.numpy(), x.numpy())

y1, y2 = traced_module(x, y)
np.testing.assert_equal(y2.numpy(), y.numpy())

traced_module.graph.reset_outputs(
({"orig_out": traced_module.graph.outputs[0]}, traced_module.graph.outputs[1])
)

out = traced_module(x, x)
assert isinstance(out, tuple)
assert isinstance(out[0], dict)
np.testing.assert_equal(out[0]["orig_out"].numpy(), y.numpy())
np.testing.assert_equal(out[1].numpy(), x.numpy())


def test_delete():
traced_module, x, expect = _init_block()
graph = traced_module.graph
@@ -117,8 +194,10 @@ def test_delete():
def test_flatten():
traced_module, x, expect = _init_module()
traced_module = traced_module.flatten()
traced_module.graph.compile()
assert all(not isinstance(i, GetAttr) for i in traced_module.graph._exprs)
assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())

traced_module = traced_module.flatten()
assert len(traced_module.graph._exprs) == 12
np.testing.assert_equal(expect.numpy(), traced_module(x).numpy())

@@ -128,7 +207,7 @@ def test_id_and_name():
_total_ids = traced_module.graph._total_ids
node_ids = [n._id for n in traced_module.graph.nodes().as_list()]
assert len(set(node_ids)) == len(node_ids)
assert max(node_ids) + 1 == len(node_ids)
assert max(node_ids) + 1 == _total_ids[0]

expr_ids = [n._id for n in traced_module.graph.exprs().as_list()]
assert len(set(expr_ids)) == len(expr_ids)
@@ -177,7 +256,7 @@ def test_id_and_name():
_check_name(flattened_module)


def test_set_name():
def test_set_node_name():
traced_module, x, expect = _init_module()
graph = traced_module.graph
output_node = graph.outputs[0]
@@ -190,6 +269,18 @@ def test_set_name():
np.testing.assert_equal(str(graph.outputs[0]), "output")


def test_set_graph_name():
traced_module, x, expect = _init_module()
graph = traced_module.graph
output_node = graph.outputs[0]

node_name = output_node.name

graph.name = "Top"
node = graph.get_node_by_name("{}_{}".format("Top", node_name)).as_unique()
assert node is output_node


def test_extra_block():
class PostProcess(M.Module):
def forward(self, x):


+ 195
- 0
imperative/python/test/unit/traced_module/test_qat_module.py View File

@@ -0,0 +1,195 @@
import io
from functools import partial
from itertools import chain
from typing import Callable

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
from megengine.traced_module import TracedModule, trace_module


def get_subattr(self: M.Module, name: str):
if name == "":
return self
module_path, _, name = name.rpartition(".")
if module_path == "":
return getattr(self, name)
module_names = module_path.split(".")
for item in module_names:
self = getattr(self, item)
if not isinstance(self, M.Module):
raise AttributeError("`{}` is not an Module".format(item))
return getattr(self, name)


class Myblcok(M.Module):
def __init__(self,):
super().__init__()
self.conv0 = M.ConvBnRelu2d(3, 3, 3, 1, 1)
self.conv1 = M.ConvBn2d(3, 3, 1, 1, 0)
self.conv2 = M.ConvBn2d(3, 3, 1, 1, 0)
self.add = M.Elemwise("FUSE_ADD_RELU")

def forward(self, x):
x = self.conv0(x)
x0 = self.conv1(x)
x1 = self.conv2(x)
o = self.add(x0, x1)
return o


class MyModule(M.Module):
def __init__(self):
super().__init__()
self.block0 = Myblcok()
self.block1 = Myblcok()

def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x


class MyMinMaxObserver(Q.MinMaxObserver):
pass


class MyTQT(Q.TQT):
pass


def get_lsq_config(lsq_cls):
return Q.QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(lsq_cls, dtype="qint8_narrow"),
act_fake_quant=partial(lsq_cls, dtype="qint8"),
)


def get_observer_config(observer_cls):
return Q.QConfig(
weight_observer=partial(observer_cls, dtype="qint8_narrow"),
act_observer=partial(observer_cls, dtype="qint8"),
weight_fake_quant=None,
act_fake_quant=None,
)


def get_qparams(mod: QATModule):
weight_qparams, act_qparams = None, None
if mod.act_observer is not None:
act_qparams = mod.act_observer.get_qparams()
if mod.act_fake_quant:
act_qparams = mod.act_fake_quant.get_qparams()

if mod.weight_observer is not None:
weight_qparams = mod.weight_observer.get_qparams()
if mod.weight_fake_quant:
weight_qparams = mod.weight_fake_quant.get_qparams()

return weight_qparams, act_qparams


def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
assert qparmsa.dtype_meta == qparmsb.dtype_meta
assert qparmsa.mode == qparmsb.mode
np.testing.assert_equal(qparmsa.scale.numpy(), qparmsb.scale.numpy())
if qparmsa.zero_point is not None:
np.testing.assert_equal(qparmsa.zero_point.numpy(), qparmsb.zero_point.numpy())


def build_observered_net(net: M.Module, observer_cls):
qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls))
Q.enable_observer(qat_net)
for _ in range(5):
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
qat_net(inp)
Q.disable_observer(qat_net)
return qat_net


def build_fakequanted_net(net: QATModule, fakequant_cls):
qat_net = Q.reset_qconfig(net, get_lsq_config(fakequant_cls))
return qat_net


def test_trace_qat():
def _check_qat_module(qat_net: QATModule):
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
traced_net = trace_module(qat_net, inp)

for name, qat_module in qat_net.named_modules():
if not isinstance(qat_module, QATModule):
continue
traced_qat_module = get_subattr(traced_net, name)
weight_qparams, act_qparams = get_qparams(qat_module)
traced_weight_qparams, traced_act_qparams = get_qparams(traced_qat_module)
if weight_qparams:
check_qparams(weight_qparams, traced_weight_qparams)
if act_qparams:
check_qparams(act_qparams, traced_act_qparams)

_check_qat_module(build_observered_net(MyModule(), Q.MinMaxObserver))
_check_qat_module(build_observered_net(MyModule(), MyMinMaxObserver))
_check_qat_module(
build_fakequanted_net(build_observered_net(MyModule(), Q.MinMaxObserver), Q.TQT)
)
_check_qat_module(
build_fakequanted_net(build_observered_net(MyModule(), Q.MinMaxObserver), MyTQT)
)


def test_load_param():
def _check_param(moda: M.Module, modb: M.Module):
for name, attr in chain(moda.named_parameters(), moda.named_buffers()):
traced_attr = get_subattr(modb, name)
np.testing.assert_equal(attr.numpy(), traced_attr.numpy())

def _check_module(build_func: Callable):
net = build_func()
buffer = io.BytesIO()
mge.save(net.state_dict(), buffer)
buffer.seek(0)

inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
traced_net = trace_module(build_func(), inp)
traced_net.load_state_dict(mge.load(buffer))

_check_param(net, traced_net)

buffer.seek(0)
traced_net = trace_module(build_func(), inp).flatten()
traced_net.load_state_dict(mge.load(buffer))

_check_param(net, traced_net)

_check_module(lambda: MyModule())
_check_module(lambda: build_observered_net(MyModule(), Q.MinMaxObserver))


def test_qualname():
def _check_qualname(net):
inp = Tensor(np.random.random(size=(5, 3, 32, 32)))
traced_net = trace_module(net, inp)
base_qualname = traced_net.graph.qualname
for node in traced_net.graph.nodes():
qualname = node.qualname
qualname = qualname[len(base_qualname) + 1 :]
if qualname.endswith("]"):
qualname = qualname.rsplit(".", 1)[0]
if qualname.startswith("["):
qualname = ""
traced_attr = get_subattr(traced_net, qualname)
orig_attr = get_subattr(net, qualname)
assert traced_attr is not None
assert orig_attr is not None

_check_qualname(MyModule())
_check_qualname(build_observered_net(MyModule(), Q.MinMaxObserver))

Loading…
Cancel
Save