You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_passes.py 2.7 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import types
  2. import numpy as np
  3. import pytest
  4. import megengine as mge
  5. import megengine.functional as F
  6. import megengine.module as M
  7. import megengine.traced_module as tm
  8. class myconv(M.Conv2d):
  9. pass
  10. class mybn(M.BatchNorm2d):
  11. pass
  12. class MyBlock(M.Module):
  13. def __init__(self, conv_cls, bn_cls):
  14. super().__init__()
  15. self.conv = conv_cls(3, 3, 1, 1, 0)
  16. self.bn = bn_cls(3)
  17. self.conv2 = conv_cls(3, 3, 1, 1, 0)
  18. self.bn2 = bn_cls(3)
  19. self.scale = mge.Tensor([3, 4])
  20. def forward(self, x):
  21. x1 = self.conv(x)
  22. x1 = self.bn(x1)
  23. x1 = F.relu(x1)
  24. x1 = x1 * self.scale[0]
  25. x2 = self.conv2(x)
  26. x2 = self.bn2(x2)
  27. x2 = F.relu(x2)
  28. x2 = x2 * self.scale[1]
  29. y = x1 + x2
  30. y = y + 4
  31. y = self.scale[0] + y
  32. y = F.relu(y) * 3
  33. return y
  34. class MyModule(M.Module):
  35. def __init__(self, conv_cls, bn_cls):
  36. super().__init__()
  37. self.block_0 = MyBlock(conv_cls, bn_cls)
  38. self.block_1 = MyBlock(conv_cls, bn_cls)
  39. def forward(self, x):
  40. x1 = self.block_0(x)
  41. x2 = self.block_1(x)
  42. y = x1 + x2
  43. y = F.reshape(y, (-1))
  44. y = y * 3
  45. return y
  46. @pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
  47. @pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
  48. def test_backward_fold_scale(conv_cls, bn_cls):
  49. module = MyModule(conv_cls, bn_cls)
  50. module.eval()
  51. inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
  52. desired = module(inp)
  53. traced_net = tm.trace_module(module, inp)
  54. traced_net = traced_net.flatten()
  55. optimized_net = tm.optimize(traced_net, "BackwardFoldScale")
  56. actual = optimized_net(inp)
  57. np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
  58. # fuse all mul to conv
  59. mul_list = optimized_net.graph.get_method_by_type("__mul__").as_list()
  60. assert len(mul_list) == 0
  61. @pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
  62. @pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
  63. def test_fuse_bn(conv_cls, bn_cls):
  64. module = MyModule(conv_cls, bn_cls)
  65. module.eval()
  66. inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
  67. desired = module(inp)
  68. traced_net = tm.trace_module(module, inp)
  69. traced_net = traced_net.flatten()
  70. optimized_net = tm.optimize(traced_net, "FuseConvBn")
  71. actual = optimized_net(inp)
  72. np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
  73. # fuse all mul to conv
  74. bn_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list()
  75. assert len(bn_list) == 0
  76. bn_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list()
  77. assert len(bn_list) == 0