You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_optimizer.py 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. import numpy as np
  11. import pytest
  12. import megengine.autodiff as ad
  13. import megengine.functional as F
  14. from megengine import Parameter, optimizer
  15. from megengine.jit import trace
  16. from megengine.module import Linear, Module
  17. from megengine.tensor import Tensor
  class MLP(Module):
      """Two-layer perceptron: Linear(28 -> 50), ReLU, then Linear(50 -> 20).

      The optimizer tests in this file build ``Simple`` rather than this class.
      """

      def __init__(self):
          super().__init__()
          self.dense0 = Linear(28, 50)
          self.dense1 = Linear(50, 20)

      def forward(self, x):
          hidden = F.relu(self.dense0(x))
          return self.dense1(hidden)
  class Simple(Module):
      """Module with a single float32 parameter ``a``; ``forward`` returns ``x * a``."""

      def __init__(self):
          super().__init__()
          # one scalar parameter keeps the per-parameter reference checks trivial
          self.a = Parameter(1.23, dtype=np.float32)

      def forward(self, x):
          return x * self.a
  def _bump_lr(opt, check_func, delta=0.01):
      """Add ``delta`` to the lr of every param group in ``opt`` and to ``check_func.lr``.

      Keeping both in sync lets the reference checker compute the expected
      update with the same learning rate the optimizer uses.
      """
      for group in opt.param_groups:
          group["lr"] += delta
      check_func.lr += delta


  def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
      """Train ``Simple`` with ``optimizer.<opt_str>`` and verify every update.

      The same 3-step schedule runs eagerly, then under ``trace(symbolic=False)``
      and ``trace(symbolic=True)``. Each traced run starts from a freshly built
      net, optimizer and GradManager. After each step the checker built by
      ``check_class(net, **test_case)`` is called as
      ``check(ori_params, net.parameters(), step)``. It compares the new
      parameter values with a reference update computed from ``ori_params``,
      which holds the values from before the step.

      Args:
          opt_str: name of the optimizer class looked up on ``megengine.optimizer``.
          test_case: keyword arguments passed to both the optimizer and ``check_class``.
          check_class: reference-checker factory. Its instances must expose a
              mutable ``lr`` attribute when ``update_lr`` is True.
          update_lr: if True, raise the learning rate by 0.01 before the second
              step of every run.
      """
      iter_num = 3
      data_shape = (2, 28)

      # ---- eager (imperative) mode ----
      net = Simple()
      opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
      check_func = check_class(net, **test_case)
      gm = ad.GradManager().attach(net.parameters())

      step = 0
      for i in range(iter_num):
          if update_lr and i == 1:  # change learning rate mid-run
              _bump_lr(opt, check_func)
          data = Tensor(np.random.random(data_shape).astype(np.float32))

          opt.clear_grad()
          with gm:
              pred = net(data)
              loss = pred.sum()
              gm.backward(loss)

          ori_params = {}
          ori_grads = {}
          for param in net.parameters():
              # Compare by value. `is ()` relies on CPython interning the empty
              # tuple and triggers a SyntaxWarning on Python 3.8+.
              assert param._tuple_shape == ()
              ori_params[param] = np.copy(param.numpy())
              ori_grads[param] = np.copy(param.grad.numpy())
          opt.step()
          # step() must not modify the gradients it consumed.
          for param in net.parameters():
              # array_equal reduces to a single bool for any shape. A bare
              # np.equal result is ambiguous in `assert` for >1 element.
              assert np.array_equal(
                  ori_grads[param], param.grad.numpy()
              ), "step should not change param.grad"
          step += 1
          check_func(ori_params, net.parameters(), step)

      # ---- traced (static graph) mode ----
      for symbolic in (False, True):

          @trace(symbolic=symbolic)
          def train_func(data, *, opt=None, gm=None):
              opt.clear_grad()
              with gm:
                  pred = net(data)
                  loss = pred.sum()
                  gm.backward(loss)
              opt.step()

          # Reset net and opt. train_func looks up `net` from the enclosing
          # scope when it runs, so it trains the module rebound here.
          net = Simple()
          opt = getattr(optimizer, opt_str)(net.parameters(), **test_case)
          gm = ad.GradManager().attach(net.parameters())
          check_func = check_class(net, **test_case)
          step = 0
          for i in range(iter_num):
              if update_lr and i == 1:  # change learning rate mid-run
                  _bump_lr(opt, check_func)

              ori_params = {}
              for param in net.parameters():
                  assert param._tuple_shape == ()
                  ori_params[param] = np.copy(param.numpy())

              train_func(
                  Tensor(np.random.random(data_shape).astype(np.float32)), opt=opt, gm=gm
              )
              step += 1
              check_func(ori_params, net.parameters(), step)

      # Exercise state_dict() after the last traced run; the result is not inspected.
      net.state_dict()
      opt.state_dict()
  @pytest.mark.parametrize(
      "case",
      [
          {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
          {"momentum": 0.9, "lr": 0.01, "nesterov": True},  # with nesterov momentum
          {"lr": 0.01},  # simple SGD
          {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
      ],
  )
  @pytest.mark.parametrize("update_lr", [False, True])
  @pytest.mark.parametrize("inplace_mode", [False, True])
  def test_sgd(monkeypatch, case, update_lr, inplace_mode):
      """Check ``optimizer.SGD`` against a NumPy reference.

      MEGENGINE_INPLACE_UPDATE is set to "0" or "1" according to
      ``inplace_mode`` while the run is in progress.
      """

      class CheckValue:
          """NumPy reference for one SGD step (weight decay, momentum, Nesterov)."""

          def __init__(self, net, **kwarg):
              # one zero-initialised momentum buffer per parameter
              self.slots = {
                  p: np.zeros(p.shape, dtype=np.float32) for p in net.parameters()
              }
              # expose every hyper-parameter of the case as an attribute
              for name, value in kwarg.items():
                  setattr(self, name, value)

          def __call__(self, ori_params, new_params, step):
              weight_decay = getattr(self, "weight_decay", 0.0)
              momentum = getattr(self, "momentum", 0.0)
              for param in new_params:
                  grad = param.grad.numpy()
                  if weight_decay != 0.0:
                      grad = grad + ori_params[param] * weight_decay
                  if momentum == 0.0:
                      delta = -self.lr * grad
                  else:
                      self.slots[param] = grad + self.slots[param] * momentum
                      if getattr(self, "nesterov", False):
                          delta = -self.lr * (grad + self.slots[param] * momentum)
                      else:
                          delta = -self.lr * self.slots[param]
                  np.testing.assert_almost_equal(
                      param.numpy(), ori_params[param] + delta, decimal=6
                  )

      with monkeypatch.context() as patcher:
          patcher.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
          _test_optimizer("SGD", case, CheckValue, update_lr=update_lr)
  @pytest.mark.parametrize(
      "case",
      [
          {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
          {
              "betas": (0.8, 0.9),
              "eps": 1e-04,
              "lr": 0.01,
              "weight_decay": 0.1,
          },  # with weight_decay
      ],
  )
  @pytest.mark.parametrize("update_lr", [False, True])
  @pytest.mark.parametrize("inplace_mode", [False, True])
  def test_adam(monkeypatch, case, update_lr, inplace_mode):
      """Check ``optimizer.Adam`` (with bias correction) against a NumPy reference.

      MEGENGINE_INPLACE_UPDATE is set to "0" or "1" according to
      ``inplace_mode`` while the run is in progress.
      """

      class CheckValue:
          """NumPy reference for one Adam step; weight decay is folded into the gradient."""

          def __init__(self, net, **kwarg):
              self.m_slots, self.v_slots = {}, {}
              for p in net.parameters():
                  shape = p.shape
                  self.m_slots[p] = np.zeros(shape, dtype=np.float32)
                  self.v_slots[p] = np.zeros(shape, dtype=np.float32)
              for name, value in kwarg.items():
                  setattr(self, name, value)

          def __call__(self, ori_params, new_params, step):
              beta0 = self.betas[0]
              beta1 = self.betas[1]
              weight_decay = getattr(self, "weight_decay", 0.0)
              for param in new_params:
                  grad = param.grad.numpy()
                  if weight_decay != 0.0:
                      grad = grad + ori_params[param] * weight_decay
                  m = self.m_slots[param]
                  v = self.v_slots[param]
                  # in-place updates so the stored moment buffers carry over to the next step
                  m *= beta0
                  m += (1 - beta0) * grad
                  v *= beta1
                  v += (1 - beta1) * grad * grad
                  m_hat = m / (1 - beta0 ** step)
                  v_hat = v / (1 - beta1 ** step)
                  delta = m_hat / (np.sqrt(v_hat) + self.eps)
                  np.testing.assert_almost_equal(
                      param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                  )

      with monkeypatch.context() as patcher:
          patcher.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
          _test_optimizer("Adam", case, CheckValue, update_lr=update_lr)
  @pytest.mark.parametrize(
      "case",
      [
          {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
          {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
          {
              "lr": 0.01,
              "eps": 1e-06,
              "lr_decay": 0.01,
              "weight_decay": 0.1,
          },  # with weight_decay
      ],
  )
  @pytest.mark.parametrize("update_lr", [False, True])
  @pytest.mark.parametrize("inplace_mode", [False, True])
  def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
      """Check ``optimizer.Adagrad`` (with lr decay) against a NumPy reference.

      MEGENGINE_INPLACE_UPDATE is set to "0" or "1" according to
      ``inplace_mode`` while the run is in progress.
      """

      class CheckValue:
          """NumPy reference for one Adagrad step."""

          def __init__(self, net, **kwarg):
              # running sum of squared gradients, one buffer per parameter
              self.s_slots = {
                  p: np.zeros(p.shape, dtype=np.float32) for p in net.parameters()
              }
              for name, value in kwarg.items():
                  setattr(self, name, value)

          def __call__(self, ori_params, new_params, step):
              weight_decay = getattr(self, "weight_decay", 0.0)
              # learning rate decayed by the number of steps taken before this one
              decayed_lr = self.lr / (1 + (step - 1) * self.lr_decay)
              for param in new_params:
                  grad = param.grad.numpy()
                  if weight_decay != 0.0:
                      grad = grad + ori_params[param] * weight_decay
                  self.s_slots[param] += grad ** 2
                  delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                  delta *= -decayed_lr
                  np.testing.assert_almost_equal(
                      param.numpy(), ori_params[param] + delta, decimal=6
                  )

      with monkeypatch.context() as patcher:
          patcher.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
          _test_optimizer("Adagrad", case, CheckValue, update_lr=update_lr)
  @pytest.mark.parametrize(
      "case",
      [
          {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
          {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
      ],
  )
  @pytest.mark.parametrize("update_lr", [False, True])
  @pytest.mark.parametrize("inplace_mode", [False, True])
  def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
      """Check ``optimizer.Adadelta`` against a NumPy reference.

      MEGENGINE_INPLACE_UPDATE is set to "0" or "1" according to
      ``inplace_mode`` while the run is in progress.
      """

      class CheckValue:
          """NumPy reference for one Adadelta step."""

          def __init__(self, net, **kwarg):
              # s: running average of squared grads; a: running average of squared deltas
              self.s_slots, self.a_slots = {}, {}
              for p in net.parameters():
                  shape = p.shape
                  self.s_slots[p] = np.zeros(shape, dtype=np.float32)
                  self.a_slots[p] = np.zeros(shape, dtype=np.float32)
              for name, value in kwarg.items():
                  setattr(self, name, value)

          def __call__(self, ori_params, new_params, step):
              weight_decay = getattr(self, "weight_decay", 0.0)
              rho = self.rho
              for param in new_params:
                  grad = param.grad.numpy()
                  if weight_decay != 0.0:
                      grad = grad + ori_params[param] * weight_decay
                  s_new = self.s_slots[param] * rho + grad ** 2 * (1 - rho)
                  self.s_slots[param] = s_new
                  delta = (
                      grad
                      * ((self.a_slots[param] + self.eps) ** 0.5)
                      / (s_new + self.eps) ** 0.5
                  )
                  # the delta accumulator is updated with the un-scaled delta
                  self.a_slots[param] = self.a_slots[param] * rho + delta ** 2 * (1 - rho)
                  delta *= -self.lr
                  np.testing.assert_almost_equal(
                      param.numpy(), ori_params[param] + delta, decimal=6
                  )

      with monkeypatch.context() as patcher:
          patcher.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
          _test_optimizer("Adadelta", case, CheckValue, update_lr=update_lr)
  @pytest.mark.parametrize(
      "case",
      [
          {"betas": (0.8, 0.9), "eps": 1e-08, "lr": 0.01},
          {
              "betas": (0.8, 0.9),
              "eps": 1e-08,
              "lr": 0.01,
              "weight_decay": 0.1,
          },  # with weight_decay
      ],
  )
  @pytest.mark.parametrize("update_lr", [False, True])
  @pytest.mark.parametrize("inplace_mode", [False, True])
  def test_adamw(monkeypatch, case, update_lr, inplace_mode):
      """Check ``optimizer.AdamW`` (decoupled weight decay) against a NumPy reference.

      MEGENGINE_INPLACE_UPDATE is set to "0" or "1" according to
      ``inplace_mode`` while the run is in progress.
      """

      class CheckValue:
          """NumPy reference for one AdamW step."""

          def __init__(self, net, **kwarg):
              self.m_slots, self.v_slots = {}, {}
              for p in net.parameters():
                  shape = p.shape
                  self.m_slots[p] = np.zeros(shape, dtype=np.float32)
                  self.v_slots[p] = np.zeros(shape, dtype=np.float32)
              # Used when the case does not pass weight_decay.
              # NOTE(review): presumably mirrors AdamW's default; confirm.
              self.weight_decay = 0.01
              for name, value in kwarg.items():
                  setattr(self, name, value)

          def __call__(self, ori_params, new_params, step):
              step = np.array(step).astype(np.float32)
              beta0 = self.betas[0]
              beta1 = self.betas[1]
              for param in new_params:
                  grad = param.grad.numpy()
                  m = self.m_slots[param]
                  v = self.v_slots[param]
                  # in-place updates so the stored moment buffers carry over to the next step
                  m *= beta0
                  m += (1 - beta0) * grad
                  v *= beta1
                  v += (1 - beta1) * grad * grad
                  m_hat = m / (1 - beta0 ** step)
                  v_hat = v / (1 - beta1 ** step)
                  delta = m_hat / (np.sqrt(v_hat) + self.eps)
                  # decoupled weight decay: added to the update, not to the gradient
                  delta += ori_params[param] * self.weight_decay
                  np.testing.assert_almost_equal(
                      param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                  )

      with monkeypatch.context() as patcher:
          patcher.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
          _test_optimizer("AdamW", case, CheckValue, update_lr=update_lr)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，无需区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发，欢迎访问 MegStudio 平台。