From 80d3a06f9efb24228d1266ec51a03ad69822eea7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 27 Mar 2020 19:37:12 +0800 Subject: [PATCH] fix(mge/param_pack): fix param pack with just one param GitOrigin-RevId: 0d28a12e594c43e42a7486857dace201da2a3dfc --- python_module/megengine/module/parampack.py | 2 +- python_module/test/integration/test_parampack.py | 36 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python_module/megengine/module/parampack.py b/python_module/megengine/module/parampack.py index d91aee83..fad1e27d 100644 --- a/python_module/megengine/module/parampack.py +++ b/python_module/megengine/module/parampack.py @@ -82,7 +82,7 @@ class ParamPack(Module): group = group[idx:] if idx == 1: # ignore param packs with only one item - self._packed_params.append(params[0]) + self._packed_params.append(params[0]['tensor']) self._grouped_params.append(params) continue diff --git a/python_module/test/integration/test_parampack.py b/python_module/test/integration/test_parampack.py index 678971a2..0a1b3780 100644 --- a/python_module/test/integration/test_parampack.py +++ b/python_module/test/integration/test_parampack.py @@ -111,6 +111,42 @@ def test_static_graph_parampack(): assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" @pytest.mark.slow +def test_nopack_parampack(): + net = XORNet() + net = ParamPack(net, + max_size_per_group=0, + max_nr_params_per_group=0) + opt = SGD( + net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4 + ) + + @trace(symbolic=True) + def train(data, label): + pred = net(data) + opt.zero_grad() + loss = cross_entropy_with_softmax(pred, label) + opt.backward(loss) + return loss + + @trace(symbolic=True) + def infer(data): + return net(data) + + train_dataset = minibatch_generator() + losses = [] + + for data, label in itertools.islice(train_dataset, 2000): + loss = train(data, label) + loss = loss[0][0] + opt.step() + losses.append(loss.numpy()) + assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough" + + data, _ = next(train_dataset) + pred = infer(data).numpy() + assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough" + +@pytest.mark.slow def test_dynamic_graph_parampack(): net = XORNet() net = ParamPack(net,