You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_convert_format.py 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import numpy as np
  9. import pytest
  10. import megengine.functional as F
  11. import megengine.module as M
  12. from megengine import Parameter, Tensor, amp
  13. from megengine.core._config import set_auto_format_convert
  14. class MyModule(M.Module):
  15. class InnerModule(M.Module):
  16. def __init__(self):
  17. super().__init__()
  18. self.bn = M.BatchNorm2d(4)
  19. def forward(self, x):
  20. return self.bn(x)
  21. def __init__(self):
  22. super().__init__()
  23. self.i = self.InnerModule()
  24. self.conv = M.Conv2d(4, 4, 4, groups=2)
  25. self.bn = M.BatchNorm2d(4)
  26. self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
  27. self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
  28. def forward(self, x):
  29. x = self.i(x)
  30. x = self.bn(x)
  31. return x
  32. @pytest.mark.parametrize("is_inplace", [False, True])
  33. def test_convert_module(is_inplace):
  34. m = MyModule()
  35. expected_shape = {
  36. "i.bn.weight": (1, 1, 1, 4),
  37. "i.bn.bias": (1, 1, 1, 4),
  38. "i.bn.running_mean": (1, 1, 1, 4),
  39. "i.bn.running_var": (1, 1, 1, 4),
  40. "conv.weight": (2, 2, 4, 4, 2),
  41. "conv.bias": (1, 1, 1, 4),
  42. "bn.weight": (1, 1, 1, 4),
  43. "bn.bias": (1, 1, 1, 4),
  44. "bn.running_mean": (1, 1, 1, 4),
  45. "bn.running_var": (1, 1, 1, 4),
  46. "param": (1, 1, 1, 3),
  47. "buff": (1, 1, 1, 3),
  48. }
  49. m = amp.convert_module_format(m, is_inplace)
  50. for name, param in m.named_tensors():
  51. assert param.format == "nhwc"
  52. set_auto_format_convert(False)
  53. assert param.shape == expected_shape[name], name
  54. set_auto_format_convert(True)