
test_optimizer.py

# -*- coding: utf-8 -*-
import os

import numpy as np
import pytest

import megengine.autodiff as ad
import megengine.functional as F
from megengine import Parameter, optimizer
from megengine.jit import trace
from megengine.module import Linear, Module
from megengine.tensor import Tensor

class MLP(Module):
    def __init__(self):
        super().__init__()
        self.dense0 = Linear(28, 50)
        self.dense1 = Linear(50, 20)

    def forward(self, x):
        x = self.dense0(x)
        x = F.relu(x)
        x = self.dense1(x)
        return x

class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter(1.23, dtype=np.float32)

    def forward(self, x):
        x = x * self.a
        return x
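
# _test_optimizer below trains Simple rather than MLP, so the only parameter is the
# scalar `a`. With pred = x * a and loss = pred.sum(), the gradient w.r.t. `a` is
# simply x.sum(), which keeps the reference updates in the per-optimizer CheckValue
# classes easy to verify by hand.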

def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
    iter_num = 3
    net = Simple()
    opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
    check_func = check_class(net, **test_case)
    gm = ad.GradManager().attach(net.parameters())

    step = 0
    data_shape = (2, 28)
    for i in range(iter_num):
        if update_lr and i == 1:  # change learning rate
            for group in opt.param_groups:
                group["lr"] += 0.01
            check_func.lr += 0.01
        data = Tensor(np.random.random(data_shape).astype(np.float32))

        opt.clear_grad()
        with gm:
            pred = net(data)
            loss = pred.sum()
            gm.backward(loss)

        ori_params = {}
        ori_grads = {}
        for param in net.parameters():
            assert param._tuple_shape is ()
            ori_params[param] = np.copy(param.numpy())
            ori_grads[param] = np.copy(param.grad.numpy())
        opt.step()
        # check that step() does not change the gradients
        for param in net.parameters():
            assert np.equal(
                ori_grads[param], param.grad.numpy()
            ), "step should not change param.grad"
        step += 1
        check_func(ori_params, net.parameters(), step)

    # static graph
    for symbolic in (False, True):

        @trace(symbolic=symbolic)
        def train_func(data, *, opt=None, gm=None):
            opt.clear_grad()
            with gm:
                pred = net(data)
                loss = pred.sum()
                gm.backward(loss)
            opt.step()

        # reset net and opt
        net = Simple()
        opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
        gm = ad.GradManager().attach(net.parameters())
        check_func = check_class(net, **test_case)
        step = 0
        for i in range(iter_num):
            if update_lr and i == 1:  # change learning rate
                for group in opt.param_groups:
                    group["lr"] += 0.01
                check_func.lr += 0.01

            ori_params = {}
            for param in net.parameters():
                assert param._tuple_shape is ()
                ori_params[param] = np.copy(param.numpy())

            train_func(
                Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
            )
            step += 1
            check_func(ori_params, net.parameters(), step)

        try_state_dict = {
            "net": net.state_dict(),
            "opt": opt.state_dict(),
        }
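
# try_state_dict above is only built, never consumed. A minimal, illustrative sketch of
# how that captured state could be restored (not part of the upstream test; it assumes
# the standard MegEngine Module.load_state_dict / Optimizer.load_state_dict APIs):
def _reload_state(opt_str, test_case, saved):
    net = Simple()
    opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
    net.load_state_dict(saved["net"])  # restore parameter values
    opt.load_state_dict(saved["opt"])  # restore optimizer slots (momentum, m/v, ...)
    return net, opt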

@pytest.mark.parametrize(
    "case",
    [
        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
        {"momentum": 0.9, "lr": 0.01, "nesterov": True},  # with nesterov momentum
        {"lr": 0.01},  # simple SGD
        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_sgd(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.slots = {}
            for param in net.parameters():
                self.slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                    grad = grad + ori_params[param] * self.weight_decay
                if hasattr(self, "momentum") and self.momentum != 0.0:
                    self.slots[param] = grad + self.slots[param] * self.momentum
                    if hasattr(self, "nesterov") and self.nesterov:
                        delta = -self.lr * (grad + self.slots[param] * self.momentum)
                    else:
                        delta = -self.lr * self.slots[param]
                else:
                    delta = -self.lr * grad
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("SGD", case, CheckValue, update_lr=update_lr)
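
# Reference rule mirrored by CheckValue in test_sgd (restated from the code above):
#     grad  <- grad + weight_decay * param          (only if weight_decay is set)
#     slot  <- momentum * slot + grad               (only if momentum is set)
#     delta  = -lr * slot                           (classic momentum)
#     delta  = -lr * (grad + momentum * slot)       (nesterov momentum)
#     delta  = -lr * grad                           (plain SGD)
#     param <- param + delta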

@pytest.mark.parametrize(
    "case",
    [
        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
        {
            "betas": (0.8, 0.9),
            "eps": 1e-04,
            "lr": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adam(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.m_slots = {}
            self.v_slots = {}
            for param in net.parameters():
                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                    grad = grad + ori_params[param] * self.weight_decay
                m = self.m_slots[param]
                v = self.v_slots[param]
                m *= self.betas[0]
                m += (1 - self.betas[0]) * grad
                v *= self.betas[1]
                v += (1 - self.betas[1]) * grad * grad
                delta = (m / (1 - self.betas[0] ** step)) / (
                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
                )
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("Adam", case, CheckValue, update_lr=update_lr)
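
# One hand-worked Adam step for a scalar, matching the CheckValue arithmetic above
# (illustrative numbers only):
#     betas=(0.8, 0.9), eps=1e-4, lr=0.01, grad=2.0, step=1
#     m = 0.8 * 0 + 0.2 * 2.0 = 0.4      m_hat = 0.4 / (1 - 0.8) = 2.0
#     v = 0.9 * 0 + 0.1 * 4.0 = 0.4      v_hat = 0.4 / (1 - 0.9) = 4.0
#     delta = m_hat / (sqrt(v_hat) + eps) = 2.0 / 2.0001 ≈ 0.99995
#     param <- param - lr * delta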

@pytest.mark.parametrize(
    "case",
    [
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
        {
            "lr": 0.01,
            "eps": 1e-06,
            "lr_decay": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.s_slots = {}
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                    grad = grad + ori_params[param] * self.weight_decay
                self.s_slots[param] += grad ** 2
                delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("Adagrad", case, CheckValue, update_lr=update_lr)
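
# Reference rule mirrored by CheckValue in test_adagrad (restated from the code above):
#     grad  <- grad + weight_decay * param               (only if weight_decay is set)
#     s     <- s + grad ** 2
#     delta  = -grad / sqrt(s + eps) * lr / (1 + (step - 1) * lr_decay)
#     param <- param + delta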

@pytest.mark.parametrize(
    "case",
    [
        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.s_slots = {}
            self.a_slots = {}
            for param in net.parameters():
                self.s_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.a_slots[param] = np.zeros(param.shape).astype(np.float32)
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            for param in new_params:
                grad = param.grad.numpy()
                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                    grad = grad + ori_params[param] * self.weight_decay
                self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
                    1 - self.rho
                )
                delta = (
                    grad
                    * ((self.a_slots[param] + self.eps) ** 0.5)
                    / (self.s_slots[param] + self.eps) ** 0.5
                )
                self.a_slots[param] = self.a_slots[param] * self.rho + delta ** 2 * (
                    1 - self.rho
                )
                delta *= -self.lr
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] + delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("Adadelta", case, CheckValue, update_lr=update_lr)
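
# Reference rule mirrored by CheckValue in test_adadelta (restated from the code above):
#     grad  <- grad + weight_decay * param               (only if weight_decay is set)
#     s     <- rho * s + (1 - rho) * grad ** 2           (running average of squared grads)
#     delta  = grad * sqrt(a + eps) / sqrt(s + eps)
#     a     <- rho * a + (1 - rho) * delta ** 2          (running average of squared updates)
#     param <- param - lr * delta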

@pytest.mark.parametrize(
    "case",
    [
        {"betas": (0.8, 0.9), "eps": 1e-08, "lr": 0.01},
        {
            "betas": (0.8, 0.9),
            "eps": 1e-08,
            "lr": 0.01,
            "weight_decay": 0.1,
        },  # with weight_decay
    ],
)
@pytest.mark.parametrize("update_lr", [False, True])
@pytest.mark.parametrize("inplace_mode", [False, True])
def test_adamw(monkeypatch, case, update_lr, inplace_mode):
    class CheckValue:
        def __init__(self, net, **kwarg):
            self.m_slots = {}
            self.v_slots = {}
            for param in net.parameters():
                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
            self.weight_decay = 0.01
            for k, v in kwarg.items():
                setattr(self, k, v)

        def __call__(self, ori_params, new_params, step):
            step = np.array(step).astype(np.float32)
            for param in new_params:
                grad = param.grad.numpy()
                m = self.m_slots[param]
                v = self.v_slots[param]
                m *= self.betas[0]
                m += (1 - self.betas[0]) * grad
                v *= self.betas[1]
                v += (1 - self.betas[1]) * grad * grad
                delta = (m / (1 - self.betas[0] ** step)) / (
                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
                )
                delta += ori_params[param] * self.weight_decay
                np.testing.assert_almost_equal(
                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                )

    with monkeypatch.context() as mk:
        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
        _test_optimizer("AdamW", case, CheckValue, update_lr=update_lr)
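
# Reference rule mirrored by CheckValue in test_adamw: the same bias-corrected Adam
# update as above, except that weight decay is decoupled (added to the update rather
# than to the gradient) and defaults to 0.01 when the case does not override it:
#     m_hat  = m / (1 - beta1 ** step),   v_hat = v / (1 - beta2 ** step)
#     delta  = m_hat / (sqrt(v_hat) + eps) + weight_decay * param
#     param <- param - lr * delta
#
# The whole file is driven by pytest, e.g.:
#     pytest -q test_optimizer.py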