Browse Source

fix(mge/param_pack): fix param pack with just one param

GitOrigin-RevId: 0d28a12e59
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
80d3a06f9e
2 changed files with 37 additions and 1 deletions
  1. +1
    -1
      python_module/megengine/module/parampack.py
  2. +36
    -0
      python_module/test/integration/test_parampack.py

+ 1
- 1
python_module/megengine/module/parampack.py View File

@@ -82,7 +82,7 @@ class ParamPack(Module):
group = group[idx:]
if idx == 1:
# ignore param packs with only one item
self._packed_params.append(params[0])
self._packed_params.append(params[0]['tensor'])
self._grouped_params.append(params)
continue



+ 36
- 0
python_module/test/integration/test_parampack.py View File

@@ -111,6 +111,42 @@ def test_static_graph_parampack():
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"

@pytest.mark.slow
def test_nopack_parampack():
net = XORNet()
net = ParamPack(net,
max_size_per_group=0,
max_nr_params_per_group=0)
opt = SGD(
net.parameters(requires_grad=True), lr=0.01, momentum=0.9, weight_decay=5e-4
)

@trace(symbolic=True)
def train(data, label):
pred = net(data)
opt.zero_grad()
loss = cross_entropy_with_softmax(pred, label)
opt.backward(loss)
return loss

@trace(symbolic=True)
def infer(data):
return net(data)

train_dataset = minibatch_generator()
losses = []

for data, label in itertools.islice(train_dataset, 2000):
loss = train(data, label)
loss = loss[0][0]
opt.step()
losses.append(loss.numpy())
assert np.mean(losses[-100:]) < 0.1, "Final training Loss must be low enough"

data, _ = next(train_dataset)
pred = infer(data).numpy()
assert calculate_precision(data, pred) > 0.95, "Test precision must be high enough"

@pytest.mark.slow
def test_dynamic_graph_parampack():
net = XORNet()
net = ParamPack(net,


Loading…
Cancel
Save