You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_convert_format.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import numpy as np
  2. import pytest
  3. import megengine.functional as F
  4. import megengine.module as M
  5. from megengine import Parameter, Tensor, amp
  6. from megengine.core._config import set_auto_format_convert
  7. from megengine.core._trace_option import use_symbolic_shape
  8. class MyModule(M.Module):
  9. class InnerModule(M.Module):
  10. def __init__(self):
  11. super().__init__()
  12. self.bn = M.BatchNorm2d(4)
  13. def forward(self, x):
  14. return self.bn(x)
  15. def __init__(self):
  16. super().__init__()
  17. self.i = self.InnerModule()
  18. self.conv = M.Conv2d(4, 4, 4, groups=2)
  19. self.bn = M.BatchNorm2d(4)
  20. self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
  21. self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
  22. def forward(self, x):
  23. x = self.i(x)
  24. x = self.bn(x)
  25. return x
  26. @pytest.mark.parametrize("is_inplace", [False, True])
  27. def test_convert_module(is_inplace):
  28. m = MyModule()
  29. expected_shape = {
  30. "i.bn.weight": (1, 4, 1, 1),
  31. "i.bn.bias": (1, 4, 1, 1),
  32. "i.bn.running_mean": (1, 4, 1, 1),
  33. "i.bn.running_var": (1, 4, 1, 1),
  34. "conv.weight": (2, 2, 2, 4, 4),
  35. "conv.bias": (1, 4, 1, 1),
  36. "bn.weight": (1, 4, 1, 1),
  37. "bn.bias": (1, 4, 1, 1),
  38. "bn.running_mean": (1, 4, 1, 1),
  39. "bn.running_var": (1, 4, 1, 1),
  40. "param": (1, 3, 1, 1),
  41. "buff": (1, 3, 1, 1),
  42. }
  43. m = amp.convert_module_format(m, is_inplace)
  44. for name, param in m.named_tensors():
  45. assert param.format == "nhwc"
  46. if use_symbolic_shape():
  47. np.testing.assert_array_equal(
  48. param.shape.numpy(), expected_shape[name], name
  49. )
  50. else:
  51. assert param.shape == expected_shape[name], name