
test_clip_grad.py 2.1 kB

import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim


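# Small convolutional network used as a fixture by both gradient-clipping tests below.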
class Net(M.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = M.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = M.BatchNorm2d(64)
        self.avgpool = M.AvgPool2d(kernel_size=5, stride=5, padding=0)
        self.fc = M.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.avgpool(x)
        x = F.avg_pool2d(x, 22)
        x = F.flatten(x, 1)
        x = self.fc(x)
        return x


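# Keep a NumPy copy of every parameter's gradient so the tests can compare the
# clipped gradients against the original, unclipped values.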
def save_grad_value(net):
    for param in net.parameters():
        param.grad_backup = param.grad.numpy().copy()


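# optim.clip_grad_norm rescales all gradients in place so their global L2 norm
# does not exceed max_norm, and returns the norm measured before clipping.
# Each clipped gradient should therefore equal its backup times
# max_norm / original_norm (this check assumes the original norm exceeds max_norm,
# which holds for this randomly initialized network).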
def test_clip_grad_norm():
    net = Net()
    x = mge.tensor(np.random.randn(10, 3, 224, 224))
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    with gm:
        loss = net(x).sum()
        gm.backward(loss)
    save_grad_value(net)
    max_norm = 1.0
    original_norm = optim.clip_grad_norm(net.parameters(), max_norm=max_norm, ord=2)
    scale = max_norm / original_norm
    for param in net.parameters():
        np.testing.assert_almost_equal(param.grad.numpy(), param.grad_backup * scale)
    opt.step().clear_grad()


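# optim.clip_grad_value clamps every gradient element into [lower, upper], so the
# result should match an element-wise min/max applied to the backed-up gradients.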
def test_clip_grad_value():
    net = Net()
    x = np.random.randn(10, 3, 224, 224).astype("float32")
    gm = ad.GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), 1e-3, momentum=0.9)
    with gm:
        y = net(mge.tensor(x))
        y = y.mean()
        gm.backward(y)
    save_grad_value(net)
    max_val = 5
    min_val = -2
    optim.clip_grad_value(net.parameters(), lower=min_val, upper=max_val)
    for param in net.parameters():
        np.testing.assert_almost_equal(
            param.grad.numpy(),
            np.maximum(np.minimum(param.grad_backup, max_val), min_val),
        )
    opt.step().clear_grad()