You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_convert_format.py 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import numpy as np
  2. import pytest
  3. import megengine as mge
  4. import megengine.autodiff as autodiff
  5. import megengine.functional as F
  6. import megengine.module as M
  7. import megengine.optimizer as optim
  8. from megengine import Parameter, Tensor, amp
  9. from megengine.core._config import set_auto_format_convert
  10. from megengine.core._trace_option import use_symbolic_shape
  11. class MyModule(M.Module):
  12. class InnerModule(M.Module):
  13. def __init__(self):
  14. super().__init__()
  15. self.bn = M.BatchNorm2d(4)
  16. def forward(self, x):
  17. return self.bn(x)
  18. def __init__(self):
  19. super().__init__()
  20. self.i = self.InnerModule()
  21. self.conv = M.Conv2d(4, 4, 4, groups=2)
  22. self.bn = M.BatchNorm2d(4)
  23. self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
  24. self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
  25. def forward(self, x):
  26. x = self.i(x)
  27. x = self.bn(x)
  28. return x
  29. @pytest.mark.parametrize("is_inplace", [False, True])
  30. def test_convert_module(is_inplace):
  31. m = MyModule()
  32. expected_shape = {
  33. "i.bn.weight": (1, 4, 1, 1),
  34. "i.bn.bias": (1, 4, 1, 1),
  35. "i.bn.running_mean": (1, 4, 1, 1),
  36. "i.bn.running_var": (1, 4, 1, 1),
  37. "conv.weight": (2, 2, 2, 4, 4),
  38. "conv.bias": (1, 4, 1, 1),
  39. "bn.weight": (1, 4, 1, 1),
  40. "bn.bias": (1, 4, 1, 1),
  41. "bn.running_mean": (1, 4, 1, 1),
  42. "bn.running_var": (1, 4, 1, 1),
  43. "param": (1, 3, 1, 1),
  44. "buff": (1, 3, 1, 1),
  45. }
  46. m = amp.convert_module_format(m, is_inplace)
  47. for name, param in m.named_tensors():
  48. assert param.format == "nhwc"
  49. if use_symbolic_shape():
  50. np.testing.assert_array_equal(
  51. param.shape.numpy(), expected_shape[name], name
  52. )
  53. else:
  54. assert param.shape == expected_shape[name], name
  55. class Module(M.Module):
  56. def __init__(self):
  57. super().__init__()
  58. self.conv = M.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
  59. self.bn = M.BatchNorm2d(16)
  60. def forward(self, x):
  61. out = F.relu(self.bn(self.conv(x)))
  62. return out
  63. def test_format_remained():
  64. m = Module()
  65. m = amp.convert_module_format(m)
  66. gm = autodiff.GradManager().attach(m.parameters())
  67. opt = optim.SGD(m.parameters(), lr=0.01)
  68. scaler = amp.GradScaler()
  69. image = mge.tensor(np.random.normal(size=(1, 3, 224, 224)), dtype="float32")
  70. label = mge.tensor(np.ones((1, 224, 224)), dtype="int32")
  71. image = amp.convert_tensor_format(image)
  72. @amp.autocast(enabled=True)
  73. def train_step(image):
  74. with gm:
  75. logits = m(image)
  76. loss = F.nn.cross_entropy(logits, label)
  77. scaler.backward(gm, loss)
  78. opt.step().clear_grad()
  79. return logits
  80. for _ in range(5):
  81. res = train_step(image)
  82. assert res.format == "nhwc"