You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_save_load.py 2.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. import numpy as np
  3. import megengine as mge
  4. import megengine.autodiff as ad
  5. import megengine.module as M
  6. import megengine.optimizer as optimizer
  7. from megengine import Parameter, tensor
  8. from megengine.module import Module
  class Simple(Module):
      """Tiny module that multiplies its input by a single ``Parameter``."""

      def __init__(self):
          super().__init__()
          # One-element float32 parameter, initialised to 1.23.
          self.a = Parameter([1.23], dtype=np.float32)

      def forward(self, x):
          """Return ``x`` multiplied by the parameter ``a``."""
          return x * self.a
  class Net(Module):
      """Module with one ``M.Linear(1, 1)`` layer; forward returns a scalar."""

      def __init__(self):
          super().__init__()
          self.fc = M.Linear(1, 1)

      def forward(self, images):
          """Return the mean of ``fc(images)`` multiplied by 10000."""
          projected = self.fc(images)
          return projected.mean() * 10000
  def test_load_state_dict_no_cache(monkeypatch):
      """Run one SGD step on ``Net`` after ``load_state_dict``.

      The environment variable ``MEGENGINE_INPLACE_UPDATE`` is set to "1"
      for the duration of the test. The test has no explicit assertions; it
      passes when the load / forward / backward / step sequence does not
      raise.
      """
      # All-zero replacement values for the linear layer's weight and bias.
      zero_state = {
          "fc.weight": np.array([[0]], dtype=np.float32),
          "fc.bias": np.array([0.0], dtype=np.float32),
      }

      with monkeypatch.context() as patched:
          patched.setenv("MEGENGINE_INPLACE_UPDATE", "1")
          model = Net()

          # The optimizer and grad manager are created *before* the state
          # dict is loaded; keep this ordering.
          sgd = optimizer.SGD(model.parameters(), lr=0.1)
          grad_manager = ad.GradManager().attach(model.parameters())
          model.load_state_dict(zero_state)

          batch = mge.tensor([[0]], dtype=np.float32)
          with grad_manager:
              scalar_loss = model(batch)
              grad_manager.backward(scalar_loss)
              sgd.step()
              sgd.clear_grad()
  def test_save_load():
      """Save a trained ``Simple`` net plus optimizer state and reload on cpu0.

      Steps:
        1. Run one SGD-with-momentum step on a fresh ``Simple`` net.
        2. Save the module and optimizer state dicts to ``simple.pkl``.
        3. Load the checkpoint with ``map_location="cpu0"``, switch the
           default device to "cpu0", rebuild net and optimizer from the
           checkpoint, and run one more step.

      The test has no explicit assertions; it passes when none of the steps
      raise. The default device is always restored and the checkpoint file
      always removed, even when a step fails, so a failure here cannot leak
      state into other tests.
      """
      net = Simple()

      optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
      optim.clear_grad()
      gm = ad.GradManager().attach(net.parameters())

      data = tensor([2.34])

      with gm:
          loss = net(data)
          gm.backward(loss)

      optim.step()

      model_name = "simple.pkl"
      # Captured before anything can fail so the finally block can restore it.
      device_save = mge.get_default_device()

      try:
          mge.save(
              {
                  "name": "simple",
                  "state_dict": net.state_dict(),
                  "opt_state": optim.state_dict(),
              },
              model_name,
          )

          # Load param to cpu
          checkpoint = mge.load(model_name, map_location="cpu0")
          mge.set_default_device("cpu0")
          net = Simple()
          net.load_state_dict(checkpoint["state_dict"])
          optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
          optim.load_state_dict(checkpoint["opt_state"])

          # NOTE(review): ``gm`` was attached to the parameters of the first
          # ``Simple`` instance, not the reloaded one - confirm this is
          # intended.
          with gm:
              loss = net([1.23])
              gm.backward(loss)

          optim.step()
      finally:
          # Restore device
          mge.set_default_device(device_save)
          # Remove the checkpoint if it was written; it lives in the cwd.
          if os.path.exists(model_name):
              os.remove(model_name)