
test_optimizer.py

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

import megengine.autodiff as ad
import megengine.functional as F
from megengine import Parameter, optimizer
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.tensor import Tensor

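# The test_* functions below are pytest-style tests and can be run directly with
# a test runner, e.g. `pytest test_optimizer.py`.
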
class MLP(Module):
    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        x = self.dense0(x)
        x = F.relu(x)
        x = self.dense1(x)
        return x

class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)

    def forward(self, x):
        x = x * self.a
        return x

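# _test_optimizer runs a named optimizer from megengine.optimizer on the Simple
# module for a few iterations, first eagerly under GradManager and then under
# trace(symbolic=False/True), and after each step compares the updated parameters
# against a pure-NumPy reference update implemented by `check_class`.
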
def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
    iter_num = 3
    net = Simple()
    opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
    check_func = check_class(net, **test_case)
    gm = ad.GradManager().attach(net.parameters())

    step = 0
    data_shape = (2, 28)

    for i in range(iter_num):
        if update_lr and i == 1:  # change learning rate
            for group in opt.param_groups:
                group["lr"] += 0.01
            check_func.lr += 0.01
        data = Tensor(np.random.random(data_shape).astype(np.float32))

        opt.clear_grad()
        with gm:
            pred = net(data)
            loss = pred.sum()
            gm.backward(loss)

        ori_params = {}
        for param in net.parameters():
            assert param._tuple_shape is ()
            ori_params[param] = np.copy(param.numpy())
        opt.step()
        step += 1
        check_func(ori_params, net.parameters(), step)

    # static graph
    for symbolic in (False, True):

        @trace(symbolic=symbolic)
        def train_func(data, *, opt=None, gm=None):
            opt.clear_grad()
            with gm:
                pred = net(data)
                loss = pred.sum()
                gm.backward(loss)
            opt.step()

        # reset net and opt
        net = Simple()
        opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
        gm = ad.GradManager().attach(net.parameters())
        check_func = check_class(net, **test_case)
        step = 0
        for i in range(iter_num):
            if update_lr and i == 1:  # change learning rate
                for group in opt.param_groups:
                    group["lr"] += 0.01
                check_func.lr += 0.01

            ori_params = {}
            for param in net.parameters():
                assert param._tuple_shape is ()
                ori_params[param] = np.copy(param.numpy())

            train_func(
                Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
            )
            step += 1
            check_func(ori_params, net.parameters(), step)

    try_state_dict = {
        "net": net.state_dict(),
        "opt": opt.state_dict(),
    }

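# Reference update checked by test_sgd (as implemented in CheckValue below):
#   with momentum:  slot = grad + momentum * slot;  param -= lr * slot
#   plain SGD:      param -= lr * grad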
def test_sgd():
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.slots = {}
            for param in net.parameters():
                self.slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                if hasattr(self, "momentum"):
                    self.slots[param] = grad + self.slots[param] * self.momentum
                    delta = -self.lr * self.slots[param]
                else:
                    delta = -self.lr * grad
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    cases = [
        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
        {"lr": 0.01},  # simple SGD
        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
    ]
    for case in cases:
        _test_optimizer("SGD", case, CheckValue)
        _test_optimizer("SGD", case, CheckValue, update_lr=True)

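# Reference update checked by test_adam (bias-corrected Adam, as in CheckValue below):
#   m = beta1 * m + (1 - beta1) * grad
#   v = beta2 * v + (1 - beta2) * grad ** 2
#   param -= lr * (m / (1 - beta1 ** step)) / (sqrt(v / (1 - beta2 ** step)) + eps)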
def test_adam():
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.m_slots = {}
            self.v_slots = {}
            for param in net.parameters():
                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                m = self.m_slots[param]
                v = self.v_slots[param]
                m *= self.betas[0]
                m += (1 - self.betas[0]) * grad
                v *= self.betas[1]
                v += (1 - self.betas[1]) * grad * grad
                delta = (m / (1 - self.betas[0] ** step)) / (
                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
                )
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                )

    cases = [
        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
        {
            "betas": (0.8, 0.9),
            "eps": 1e-04,
            "lr": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ]
    for case in cases:
        _test_optimizer("Adam", case, CheckValue)
        _test_optimizer("Adam", case, CheckValue, update_lr=True)

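# Reference update checked by test_adagrad (as in CheckValue below):
#   s += grad ** 2
#   param -= (lr / (1 + (step - 1) * lr_decay)) * grad / sqrt(s + eps)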
def test_adagrad():
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.s_slots = {}
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                self.s_slots[param] += grad ** 2
                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    cases = [
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
        {
            "lr": 0.01,
            "eps": 1e-06,
            "lr_decay": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ]
    for case in cases:
        _test_optimizer("Adagrad", case, CheckValue)
        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)

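# Reference update checked by test_adadelta (as in CheckValue below):
#   s = rho * s + (1 - rho) * grad ** 2
#   delta = grad * sqrt(a + eps) / sqrt(s + eps)
#   a = rho * a + (1 - rho) * delta ** 2
#   param -= lr * delta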
def test_adadelta():
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.s_slots = {}
            self.a_slots = {}
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.a_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
                    1 - self.rho
                )
                delta = (
                    grad
                    * ((self.a_slots[param] + self.eps) ** 0.5)
                    / (self.s_slots[param] + self.eps) ** 0.5
                )
                self.a_slots[param] = self.a_slots[param] * self.rho + delta ** 2 * (
                    1 - self.rho
                )
                delta *= -self.lr
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    cases = [
        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
    ]
    for case in cases:
        _test_optimizer("Adadelta", case, CheckValue)
        _test_optimizer("Adadelta", case, CheckValue, update_lr=True)
