You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_lr_scheduler.py 866 B

123456789101112131415161718192021222324252627282930313233
  1. # -*- coding: utf-8 -*-
  2. from bisect import bisect_right
  3. import numpy as np
  4. from megengine import Parameter, tensor
  5. from megengine.module import Module
  6. from megengine.optimizer import SGD, MultiStepLR
  7. class Simple(Module):
  8. def __init__(self):
  9. super().__init__()
  10. self.a = Parameter([1.23], dtype=np.float32)
  11. def forward(self, x):
  12. x = x * self.a
  13. return x
  14. def test_multi_step_lr():
  15. net = Simple()
  16. opt = SGD(net.parameters(), lr=0.01, momentum=0.9)
  17. scheduler = MultiStepLR(opt, [3, 6, 8])
  18. lr = np.array(0.01, dtype=np.float32)
  19. for i in range(10):
  20. for group in opt.param_groups:
  21. np.testing.assert_almost_equal(
  22. np.array(group["lr"], dtype=np.float32),
  23. (lr * 0.1 ** bisect_right([3, 6, 8], i)).astype(np.float32),
  24. )
  25. scheduler.step()