
test(traced_module): add some testcases for traced module

GitOrigin-RevId: 0d6bb20b2b
release-1.6
Megvii Engine Team 4 years ago
commit 442b4f6c26
7 changed files with 147 additions and 20 deletions
  1. +2  -1    imperative/python/megengine/experimental/traced_module/expr.py
  2. +27 -4    imperative/python/megengine/experimental/traced_module/pytree.py
  3. +9  -4    imperative/python/test/integration/test_converge.py
  4. +8  -4    imperative/python/test/integration/test_converge_with_gradient_clip.py
  5. +1  -0    imperative/python/test/integration/test_trace_dump.py
  6. +41 -7    imperative/python/test/unit/module/test_module.py
  7. +59 -0    imperative/python/test/unit/traced_module/test_jit_trace.py

+2 -1    imperative/python/megengine/experimental/traced_module/expr.py

@@ -201,7 +201,8 @@ class Apply(Expr):
                 NodeMixin.wrap_safe(i, Constant.make(i))
         apply_node = cls.make(opdef)
         for i in inputs:
-            apply_node.add_input(NodeMixin.get(i))
+            assert isinstance(i, RawTensor)
+            apply_node.inputs.append(NodeMixin.get(i))
 
         unset_module_tracing()
         outputs = apply(opdef, *inputs)


+27 -4    imperative/python/megengine/experimental/traced_module/pytree.py

@@ -1,3 +1,13 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import collections
 from typing import Callable, NamedTuple
 
 SUPPORTED_TYPE = {}
@@ -9,11 +19,22 @@ def register_supported_type(type, flatten, unflatten):
     SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
 
 
+def _dict_flatten(inp):
+    aux_data = []
+    results = []
+    for key, value in sorted(inp.items()):
+        results.append(value)
+        aux_data.append(key)
+    return results, aux_data
+
+
+def _dict_unflatten(inps, aux_data):
+    return dict(zip(aux_data, inps))
+
+
 register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
 register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
-register_supported_type(
-    dict, lambda x: (list(x.values()), list(x.keys())), lambda x, y: dict(zip(y, x))
-)
+register_supported_type(dict, _dict_flatten, _dict_unflatten)
 register_supported_type(
     slice,
     lambda x: ([x.start, x.stop, x.step], None),
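
A quick, illustrative round trip through the two helpers added above (not part of the commit): keys are emitted in sorted order, so flattening a dict is deterministic, and unflatten restores the original mapping.

    # Hypothetical usage of _dict_flatten / _dict_unflatten defined in this hunk.
    leaves, keys = _dict_flatten({"b": 2, "a": 1})
    assert leaves == [1, 2] and keys == ["a", "b"]             # values ordered by sorted key
    assert _dict_unflatten(leaves, keys) == {"a": 1, "b": 2}   # round trip restores the dict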
@@ -68,6 +89,8 @@ class TreeDef:


 class LeafDef(TreeDef):
     def __init__(self, type):
+        if not isinstance(type, collections.abc.Sequence):
+            type = (type,)
         super().__init__(type, None, [])
         self.num_leaves = 1


@@ -77,4 +100,4 @@ class LeafDef(TreeDef):
         return leaves[0]
 
     def __repr__(self):
-        return "Leaf({})".format(self.type.__name__)
+        return "Leaf({})".format(", ".join(t.__name__ for t in self.type))

+9 -4    imperative/python/test/integration/test_converge.py

@@ -14,6 +14,7 @@ import megengine as mge
 import megengine.autodiff as ad
 import megengine.functional as F
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD


@@ -71,8 +72,13 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
+
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())


@@ -105,9 +111,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision
     )
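
The parametrized test above relies on trace_module returning an object that can be trained exactly like the original Module. A minimal sketch of that drop-in pattern (SimpleNet and the shapes are illustrative, not part of the commit):

    import numpy as np
    import megengine.module as M
    from megengine import Tensor
    from megengine.experimental.traced_module import trace_module

    class SimpleNet(M.Module):           # hypothetical stand-in for XORNet
        def __init__(self):
            super().__init__()
            self.fc = M.Linear(2, 2)

        def forward(self, x):
            return self.fc(x)

    net = SimpleNet()
    # Trace with a representative input; the result still exposes parameters(),
    # so SGD and GradManager attach to it unchanged.
    net = trace_module(net, Tensor(np.random.random((14, 2)).astype("float32")))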

+8 -4    imperative/python/test/integration/test_converge_with_gradient_clip.py

@@ -15,6 +15,7 @@ import megengine.autodiff as ad
 import megengine.functional as F
 import megengine.optimizer as optim
 from megengine import Tensor
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace
 from megengine.module import Linear, Module
 from megengine.optimizer import SGD
@@ -73,8 +74,12 @@ class XORNet(Module):
         return x
 
 
-def test_training_converge():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_training_converge(test_traced_module):
     net = XORNet()
+    if test_traced_module:
+        inp = Tensor(np.random.random((14, 2)))
+        net = trace_module(net, inp)
     opt = SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     gm = ad.GradManager().attach(net.parameters())


@@ -110,9 +115,8 @@ def test_training_converge():
     xx = xx.reshape((ngrid * ngrid, 1))
     yy = yy.reshape((ngrid * ngrid, 1))
     data = mge.tensor(np.concatenate((xx, yy), axis=1).astype(np.float32))
-
-    pred = infer(data).numpy()
-    precision = calculate_precision(data.numpy(), pred)
+    pred = infer(data)
+    precision = calculate_precision(data.numpy(), pred.numpy())
     print("precision=", precision)
     assert precision == 1.0, "Test precision must be high enough, get {}".format(
         precision


+1 -0    imperative/python/test/integration/test_trace_dump.py

@@ -19,6 +19,7 @@ import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
+from megengine.experimental.traced_module import trace_module
 from megengine.jit import trace






+41 -7    imperative/python/test/unit/module/test_module.py

@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine import Parameter, Tensor, tensor
+from megengine.experimental.traced_module import TracedModule, trace_module
 from megengine.module import (
     BatchNorm1d,
     BatchNorm2d,
@@ -67,8 +68,18 @@ class MyModule(Module):
         return x
 
 
-def test_module_api():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        buff = m.buff
+        param = m.param
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
+        assert "buff" not in m.__dict__
+        assert "param" not in m.__dict__
+        m.buff = buff
+        m.param = param
+
     assert list(m.children()) == [m.bn, m.i]
     assert list(m.named_children()) == [("bn", m.bn), ("i", m.i)]
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
@@ -141,8 +152,11 @@ def test_module_api():
     assert m.bn.training == False and m.i.bn.training == False
 
 
-def test_module_api_reuse_submodule():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_reuse_submodule(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     m.h = m.i  # pylint: disable=attribute-defined-outside-init
     assert list(m.modules()) == [m, m.bn, m.i, m.i.bn]
     assert list(m.named_modules()) == [
@@ -153,15 +167,21 @@ def test_module_api_reuse_submodule():
     ]
 
 
-def test_module_api_iterable_stability():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_iterable_stability(test_traced_module):
     m = MyModule()
+    if test_traced_module:
+        m = trace_module(m, Tensor(np.random.random((1, 4, 16, 16))))
     l = list(m.modules())
     for _ in range(100):
         assert list(m.modules()) == l
 
 
-def test_module_api_hooks():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_module_api_hooks(test_traced_module):
     net = MyModule()
+    if test_traced_module:
+        net = trace_module(net, Tensor(np.zeros((1, 4, 1, 1))))
     pre_hook_num = 0
     post_hook_num = 0
     hooks = []
@@ -383,11 +403,16 @@ class Simple(Module):
         self.conv1.weight = self.conv0.weight
 
     def forward(self, inputs):
-        pass
+        x = self.conv0(inputs)
+        y = self.conv1(inputs)
+        return x + y
 
 
-def test_shared_param():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_shared_param(test_traced_module):
     net = Simple()
+    if test_traced_module:
+        net = trace_module(net, tensor(np.random.random((1, 1, 8, 8))))
     assert net.conv0.weight is net.conv1.weight
     data = tensor(np.random.random((1, 1, 8, 8)).astype(np.float32))
     np.testing.assert_allclose(net.conv0(data).numpy(), net.conv1(data).numpy())
@@ -449,15 +474,21 @@ def test_shared_param_1d():
     np.testing.assert_allclose(conv0(data).numpy(), conv1(data).numpy())
 
 
-def test_pickle_module():
+@pytest.mark.parametrize("test_traced_module", [True, False])
+def test_pickle_module(test_traced_module):
     data_shape = (2, 28)
     data = tensor(np.random.random(data_shape))
     mlp = MLP()
+    pred_gt = mlp(data)
+    if test_traced_module:
+        mlp = trace_module(mlp, data)
     # pickle before forward
     with BytesIO() as fout:
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred0 = mlp1(data)
 
     pred1 = mlp(data)
@@ -467,8 +498,11 @@ def test_pickle_module():
         mge.save(mlp, fout)
         fout.seek(0)
         mlp1 = mge.load(fout)
+        if test_traced_module:
+            assert type(mlp1) == TracedModule
         pred2 = mlp1(data)
 
+    np.testing.assert_allclose(pred_gt.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred1.numpy(), atol=5e-6)
     np.testing.assert_allclose(pred0.numpy(), pred2.numpy(), atol=5e-6)
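
The new assertions check that a TracedModule survives pickling unchanged. The save/load pattern being exercised, as a standalone sketch (mlp and data stand for any module and a matching input, as in the test above):

    from io import BytesIO

    import numpy as np
    import megengine as mge
    from megengine.experimental.traced_module import TracedModule, trace_module

    traced = trace_module(mlp, data)      # trace first, then serialize
    with BytesIO() as fout:
        mge.save(traced, fout)
        fout.seek(0)
        restored = mge.load(fout)
    assert type(restored) == TracedModule   # tracing survives the save/load round trip
    np.testing.assert_allclose(restored(data).numpy(), traced(data).numpy(), atol=5e-6)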




+59 -0    imperative/python/test/unit/traced_module/test_jit_trace.py

@@ -0,0 +1,59 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import io

import numpy as np

import megengine.functional as F
import megengine.module as M
import megengine.utils.comp_graph_tools as cgtools
from megengine.experimental.traced_module import trace_module
from megengine.jit import trace
from megengine.module import Module


class MyBlock(Module):
def __init__(self, in_channels, channels):
super(MyBlock, self).__init__()
self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=1, bias=False)
self.bn1 = M.BatchNorm2d(channels)

def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x) + 1
return x


class MyModule(Module):
def __init__(self):
super(MyModule, self).__init__()
self.block0 = MyBlock(8, 4)
self.block1 = MyBlock(4, 2)

def forward(self, x):
x = self.block0(x)
x = self.block1(x)
return x


def test_jit_trace():
module = MyModule()
module.eval()
x = F.ones((1, 8, 14, 14))
expect = module(x)
traced_module = trace_module(module, x)
func = trace(traced_module, capture_as_const=True)
np.testing.assert_array_equal(func(x), expect)
model = io.BytesIO()
func.dump(model)
model.seek(0)
infer_cg = cgtools.GraphInference(model)
np.testing.assert_allclose(
list(infer_cg.run(x.numpy()).values())[0], expect, atol=1e-6
)
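
For context, the tail of the new test exercises the dump-and-reload flow: the traced function is serialized with dump and reloaded through comp_graph_tools for pure inference. A condensed sketch of that flow (func and x as defined in the test above):

    import io
    import megengine.utils.comp_graph_tools as cgtools

    buf = io.BytesIO()
    func.dump(buf)                           # serialize the captured graph
    buf.seek(0)
    infer_cg = cgtools.GraphInference(buf)   # rebuild a graph for inference only
    outputs = infer_cg.run(x.numpy())        # dict: output var name -> numpy array
    result = list(outputs.values())[0]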
