
test(mge/optimizer): fix test for new optimizer api

GitOrigin-RevId: 482ee62652
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
commit 66b6daf777
5 changed files with 21 additions and 17 deletions
  1. imperative/python/megengine/distributed/functional.py (+1 -1)
  2. imperative/python/test/integration/test_dp_correctness.py (+1 -1)
  3. imperative/python/test/integration/test_optimizer.py (+6 -5)
  4. imperative/python/test/integration/test_sgd_momentum.py (+8 -7)
  5. imperative/python/test/integration/test_trace_dump.py (+5 -3)

imperative/python/megengine/distributed/functional.py (+1 -1)

@@ -40,7 +40,7 @@ __all__ = [
 ]


-@apply.add
+@apply.register()
 def _(op: RemoteSend, *args: Tensor):
     ret = tensor_apply(op, *args)
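This hunk moves the RemoteSend overload from @apply.add to the new @apply.register() spelling. As a rough illustration of why register() takes parentheses (it is a decorator factory that returns the actual decorator), here is a minimal, hypothetical single-dispatch registry; the Dispatcher class below is an illustration only, not MegEngine's real apply implementation:

    # Hypothetical sketch only -- MegEngine's real `apply` dispatcher is more involved.
    import inspect

    class Dispatcher:
        def __init__(self):
            self._handlers = {}  # maps op type -> handler function

        def register(self):
            # register() returns the decorator, hence @apply.register() with parentheses
            def decorator(fn):
                params = inspect.signature(fn).parameters
                op_type = next(iter(params.values())).annotation  # type of first argument
                self._handlers[op_type] = fn
                return fn
            return decorator

        def __call__(self, op, *args):
            return self._handlers[type(op)](op, *args)

    apply = Dispatcher()

    class RemoteSend:  # stand-in op class for this sketch
        pass

    @apply.register()
    def _(op: RemoteSend, *args):
        return ("handled RemoteSend", args)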



imperative/python/test/integration/test_dp_correctness.py (+1 -1)

@@ -133,7 +133,7 @@ def update_model(model_path):
     data = Tensor(checkpoint["data"], dtype=np.float32)
     label = Tensor(checkpoint["label"], dtype=np.int32)

-    opt.zero_grad()
+    opt.clear_grad()
     loss = train(data, label, net=net, opt=opt)
     opt.step()



imperative/python/test/integration/test_optimizer.py (+6 -5)

@@ -73,17 +73,18 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     for symbolic in (False, True):

         @trace(symbolic=symbolic)
-        def train_func(data, *, opt=None):
-            opt.zero_grad()
-            with opt.record():
+        def train_func(data, *, opt=None, gm=None):
+            opt.clear_grad()
+            with gm.record():
                 pred = net(data)
                 loss = pred.sum()
-                opt.backward(loss)
+                gm.backward(loss)
             opt.step()

         # reset net and opt
         net = Simple()
         opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
+        gm = ad.GradManager().register(net.parameters())
         check_func = check_class(net, **test_case)
         step = 0
         for i in range(iter_num):
@@ -96,7 +97,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
             for param in net.parameters():
                 ori_params[param] = np.copy(param.numpy())

-            train_func(np.random.random(data_shape).astype(np.float32), opt=opt)
+            train_func(np.random.random(data_shape).astype(np.float32), opt=opt, gm=gm)
             step += 1
             check_func(ori_params, net.parameters(), step)
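Taken together, the two hunks above capture the new split of responsibilities: the optimizer keeps only clear_grad() and step(), while gradient recording and backpropagation move to megengine.autodiff (imported as ad in this test). A condensed sketch of the resulting training step, using only calls that appear in this diff; the Simple module and the data shape below are placeholder stand-ins for the test's own definitions:

    import numpy as np
    import megengine.autodiff as ad
    import megengine.module as M
    import megengine.optimizer as optimizer

    class Simple(M.Module):  # minimal stand-in for the toy module in test_optimizer.py
        def __init__(self):
            super().__init__()
            self.dense = M.Linear(28, 1)  # layer sizes are placeholders

        def forward(self, x):
            return self.dense(x)

    net = Simple()
    opt = optimizer.SGD(net.parameters(), lr=0.01)
    gm = ad.GradManager().register(net.parameters())  # rc-era API, as in this diff

    def train_step(data):
        opt.clear_grad()       # renamed from zero_grad()
        with gm.record():      # recording now belongs to the GradManager
            pred = net(data)
            loss = pred.sum()
            gm.backward(loss)  # backward moved off the optimizer
        opt.step()             # the optimizer still applies the update
        return loss

    train_step(np.random.random((2, 28)).astype(np.float32))  # placeholder batch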



imperative/python/test/integration/test_sgd_momentum.py (+8 -7)

@@ -67,23 +67,24 @@ def test_sgd_momentum_trace():
     for symbolic in (True, False):

         @trace(symbolic=symbolic)
-        def train_func(data, *, model=None, optim=None):
-            optim.zero_grad()
-            with optim.record():
+        def train_func(data, *, model=None, optim=None, gm=None):
+            optim.clear_grad()
+            with gm.record():
                 loss = net(data)
-                optim.backward(loss)
+                gm.backward(loss)
             optim.step()
             return loss

         @trace(symbolic=symbolic)
-        def eval_func(data, *, model=None, optim=None):
+        def eval_func(data, *, model=None, optim=None, gm=None):
             loss = net(data)
             return loss

         net = Simple()
         optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
+        gm = ad.GradManager().register(net.parameters())
         data = tensor([2.34])
-        train_func(data, model=net, optim=optim)
+        train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(
             optim._state[net.a]["momentum_buffer"].numpy(), 2.34
         )
@@ -97,7 +98,7 @@ def test_sgd_momentum_trace():
         )

         # do a step of train
-        train_func(data, model=net, optim=optim)
+        train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
         np.testing.assert_almost_equal(
             optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
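The constants asserted above follow the standard SGD momentum recurrence buffer_t = momentum * buffer_(t-1) + grad_t. With momentum=0.9 and, as the test's assertions imply, a gradient of 2.34 with respect to net.a at each step, a standalone arithmetic check:

    # Standalone check of the momentum_buffer values asserted in the test.
    momentum, grad = 0.9, 2.34
    buf = 0.0
    buf = momentum * buf + grad        # after step 1: 2.34
    assert abs(buf - 2.34) < 1e-6
    buf = momentum * buf + grad        # after step 2: 0.9 * 2.34 + 2.34
    assert abs(buf - (0.9 * 2.34 + 2.34)) < 1e-6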


imperative/python/test/integration/test_trace_dump.py (+5 -3)

@@ -17,6 +17,7 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
 from megengine import tensor
+from megengine.autodiff import GradManager
 from megengine.jit import trace


@@ -61,17 +62,18 @@ class XORNet(M.Module):
 def test_xornet_trace_dump():
     net = XORNet()
     opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    gm = GradManager().register(net.parameters(requires_grad=True))
     batch_size = 64
     train_dataset = minibatch_generator(batch_size)
     val_dataset = minibatch_generator(batch_size)

     @trace
     def train_fun(data, label):
-        with opt.record():
+        with gm.record():
             net.train()
             pred = net(data)
             loss = F.cross_entropy_with_softmax(pred, label)
-            opt.backward(loss)
+            gm.backward(loss)
         return pred, loss

     @trace
@@ -95,7 +97,7 @@ def test_xornet_trace_dump():
             break
         data = tensor(minibatch["data"])
         label = tensor(minibatch["label"])
-        opt.zero_grad()
+        opt.clear_grad()
         _, loss = train_fun(data, label)
         train_loss.append((step, loss.numpy()))
         if step % 50 == 0:
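This last file shows the same migration interacting with @trace: gm.record() and gm.backward() sit inside the traced function, while opt.clear_grad() is called outside it, once per minibatch. A condensed sketch of that loop shape, using only names visible in this diff; net, opt, and train_dataset stand in for the test's own XORNet, SGD optimizer, and minibatch_generator, and the step() placement follows the pattern in the earlier files:

    from megengine import tensor
    from megengine.autodiff import GradManager
    from megengine.jit import trace
    import megengine.functional as F

    # Assumed to exist, as in the test: net (a Module), opt (an SGD optimizer),
    # and train_dataset (an iterable of {"data": ..., "label": ...} minibatches).
    gm = GradManager().register(net.parameters(requires_grad=True))

    @trace
    def train_fun(data, label):
        with gm.record():                 # record/backward run inside the traced function
            pred = net(data)
            loss = F.cross_entropy_with_softmax(pred, label)
            gm.backward(loss)
        return pred, loss

    for minibatch in train_dataset:
        opt.clear_grad()                  # grads cleared outside, once per minibatch
        _, loss = train_fun(tensor(minibatch["data"]), tensor(minibatch["label"]))
        opt.step()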

