You cannot select more than 25 topics. A topic must start with a Chinese character, a letter, or a number; it can include dashes ('-') and can be up to 35 characters long.

test_grad_scaler.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import numpy as np
  2. import pytest
  3. import megengine as mge
  4. from megengine.amp import GradScaler
  5. from megengine.autodiff import GradManager
  6. from megengine.jit import trace
  7. @pytest.mark.parametrize(
  8. "is_trace", [False, True],
  9. )
  10. def test_grad_scaler(is_trace):
  11. gm = GradManager()
  12. scaler = GradScaler()
  13. def f(idx, data, calc):
  14. x = mge.tensor(data, no_cache=True)
  15. y = mge.tensor(data, no_cache=True)
  16. if is_trace:
  17. calc = trace(calc)
  18. gm.attach([x, y])
  19. with gm:
  20. loss = calc(x, y)
  21. scaler.backward(gm, loss, unscale_grad=False)
  22. np.testing.assert_equal(x.grad.numpy(), 2 * scaler.scale_factor)
  23. scaler.unscale(filter(lambda t: t.grad is not None, gm.attached_tensors()))
  24. # scaler.unscale(gm.attached_tensors())
  25. np.testing.assert_equal(x.grad.numpy(), 2)
  26. def double_variables(x, y):
  27. z = x + 2 * y
  28. loss = 2 * z + 1
  29. return loss
  30. def single_variable(x, y):
  31. z = x + 1
  32. loss = 2 * z + 1
  33. return loss
  34. # need grad being unique storage or not inplace modifying grad
  35. def double_variables_with_same_grad(x, y):
  36. z = x + y
  37. loss = 2 * z + 1
  38. return loss
  39. for data in [np.random.random((1, 2, 3, 4)), 1.0]:
  40. for calc in [
  41. double_variables,
  42. single_variable,
  43. double_variables_with_same_grad,
  44. ]:
  45. for idx in range(3):
  46. f(idx, data, calc)