You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_convert_format.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import numpy as np
  9. import pytest
  10. import megengine.functional as F
  11. import megengine.module as M
  12. from megengine import Parameter, Tensor, amp
  13. from megengine.core._config import set_auto_format_convert
  14. from megengine.core._trace_option import use_symbolic_shape
  15. class MyModule(M.Module):
  16. class InnerModule(M.Module):
  17. def __init__(self):
  18. super().__init__()
  19. self.bn = M.BatchNorm2d(4)
  20. def forward(self, x):
  21. return self.bn(x)
  22. def __init__(self):
  23. super().__init__()
  24. self.i = self.InnerModule()
  25. self.conv = M.Conv2d(4, 4, 4, groups=2)
  26. self.bn = M.BatchNorm2d(4)
  27. self.param = Parameter(np.ones((1, 3, 1, 1), dtype=np.float32))
  28. self.buff = Tensor(np.ones((1, 3, 1, 1), dtype=np.float32))
  29. def forward(self, x):
  30. x = self.i(x)
  31. x = self.bn(x)
  32. return x
  33. @pytest.mark.parametrize("is_inplace", [False, True])
  34. def test_convert_module(is_inplace):
  35. m = MyModule()
  36. expected_shape = {
  37. "i.bn.weight": (1, 4, 1, 1),
  38. "i.bn.bias": (1, 4, 1, 1),
  39. "i.bn.running_mean": (1, 4, 1, 1),
  40. "i.bn.running_var": (1, 4, 1, 1),
  41. "conv.weight": (2, 2, 2, 4, 4),
  42. "conv.bias": (1, 4, 1, 1),
  43. "bn.weight": (1, 4, 1, 1),
  44. "bn.bias": (1, 4, 1, 1),
  45. "bn.running_mean": (1, 4, 1, 1),
  46. "bn.running_var": (1, 4, 1, 1),
  47. "param": (1, 3, 1, 1),
  48. "buff": (1, 3, 1, 1),
  49. }
  50. m = amp.convert_module_format(m, is_inplace)
  51. for name, param in m.named_tensors():
  52. assert param.format == "nhwc"
  53. if use_symbolic_shape():
  54. np.testing.assert_array_equal(
  55. param.shape.numpy(), expected_shape[name], name
  56. )
  57. else:
  58. assert param.shape == expected_shape[name], name