You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_qat.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. import io
  2. from itertools import product
  3. import numpy as np
  4. import pytest
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine import jit, tensor
  7. from megengine.device import get_device_count
  8. from megengine.functional import expand_dims
  9. from megengine.module import (
  10. BatchMatMulActivation,
  11. Conv2d,
  12. ConvBn2d,
  13. ConvRelu2d,
  14. DequantStub,
  15. Module,
  16. QuantStub,
  17. )
  18. from megengine.quantization.quantize import (
  19. disable_fake_quant,
  20. enable_fake_quant,
  21. quantize,
  22. quantize_qat,
  23. )
  24. def test_qat_convbn2d():
  25. in_channels = 32
  26. out_channels = 64
  27. kernel_size = 3
  28. for groups, bias in product([1, 4], [True, False]):
  29. module = ConvBn2d(
  30. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  31. )
  32. module.train()
  33. qat_module = quantize_qat(module, inplace=False)
  34. disable_fake_quant(qat_module)
  35. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  36. normal_outputs = module(inputs)
  37. qat_outputs = qat_module(inputs)
  38. np.testing.assert_allclose(
  39. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  40. )
  41. np.testing.assert_allclose(
  42. module.bn.running_mean.numpy(),
  43. qat_module.bn.running_mean.numpy(),
  44. atol=5e-8,
  45. )
  46. np.testing.assert_allclose(
  47. module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
  48. )
  49. module.eval()
  50. normal_outputs = module(inputs)
  51. qat_module.eval()
  52. qat_outputs = qat_module(inputs)
  53. np.testing.assert_allclose(
  54. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  55. )
  56. def test_qat_conv():
  57. in_channels = 32
  58. out_channels = 64
  59. kernel_size = 3
  60. class TestNet(Module):
  61. def __init__(self, groups, bias):
  62. super().__init__()
  63. self.quant = QuantStub()
  64. self.dequant = DequantStub()
  65. self.conv = Conv2d(
  66. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  67. )
  68. self.conv_relu = ConvRelu2d(
  69. out_channels, in_channels, kernel_size, groups=groups, bias=bias
  70. )
  71. def forward(self, inp):
  72. out = self.quant(inp)
  73. out = self.conv(out)
  74. out = self.conv_relu(out)
  75. out = self.dequant(out)
  76. return out
  77. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  78. for groups, bias in product([1, 4], [True, False]):
  79. net = TestNet(groups, bias)
  80. net.train()
  81. qat_net = quantize_qat(net, inplace=False)
  82. disable_fake_quant(qat_net)
  83. normal_outputs = net(inputs)
  84. qat_outputs = qat_net(inputs)
  85. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  86. net.eval()
  87. normal_outputs = net(inputs)
  88. qat_net.eval()
  89. qat_outputs = qat_net(inputs)
  90. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  91. @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
  92. def test_qat_batchmatmul_activation():
  93. batch = 4
  94. in_features = 8
  95. out_features = 4
  96. class TestNet(Module):
  97. def __init__(self, bias):
  98. super().__init__()
  99. self.quant = QuantStub()
  100. self.dequant = DequantStub()
  101. self.batch_mm = BatchMatMulActivation(
  102. batch, in_features, out_features, bias=bias
  103. )
  104. def forward(self, inp):
  105. out = self.quant(inp)
  106. out = self.batch_mm(out)
  107. out = self.dequant(out)
  108. return out
  109. inputs = tensor(
  110. np.random.randn(batch, in_features, out_features).astype(np.float32)
  111. )
  112. for bias in (True, False):
  113. net = TestNet(bias)
  114. net.train()
  115. qat_net = quantize_qat(net, inplace=False)
  116. disable_fake_quant(qat_net)
  117. normal_outputs = net(inputs)
  118. qat_outputs = qat_net(inputs)
  119. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  120. net.eval()
  121. normal_outputs = net(inputs)
  122. qat_net.eval()
  123. qat_outputs = qat_net(inputs)
  124. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  125. @pytest.mark.skip(reason="FIXME: abnormal exit")
  126. def test_quantize_batchmatmul_activation():
  127. batch = 4
  128. in_features = 8
  129. out_features = 4
  130. class TestNet(Module):
  131. def __init__(self, bias):
  132. super().__init__()
  133. self.quant = QuantStub()
  134. self.dequant = DequantStub()
  135. self.batch_mm = BatchMatMulActivation(
  136. batch, in_features, out_features, bias=bias
  137. )
  138. def forward(self, inp):
  139. out = self.quant(inp)
  140. out = self.batch_mm(out)
  141. out = expand_dims(out, -1)
  142. out = self.dequant(out)
  143. return out
  144. inputs = tensor(
  145. np.random.randn(batch, in_features, out_features).astype(np.float32)
  146. )
  147. for bias in (True, False):
  148. net = TestNet(bias)
  149. net.train()
  150. qat_net = quantize_qat(net, inplace=False)
  151. disable_fake_quant(qat_net)
  152. normal_outputs = net(inputs)
  153. qat_outputs = qat_net(inputs)
  154. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  155. net.eval()
  156. normal_outputs = net(inputs)
  157. qat_net.eval()
  158. qat_outputs = qat_net(inputs)
  159. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  160. enable_fake_quant(qat_net)
  161. qat_outputs = qat_net(inputs)
  162. qnet = quantize(qat_net, inplace=False)
  163. qnet.eval()
  164. quantize_outputs = qnet(inputs)
  165. np.testing.assert_allclose(
  166. qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
  167. )
  168. @jit.trace(capture_as_const=True)
  169. def f(x):
  170. qnet.eval()
  171. return qnet(x)
  172. f(inputs)
  173. file = io.BytesIO()
  174. f.dump(file, enable_nchw4=True)
  175. file.seek(0)
  176. infer_cg = cgtools.GraphInference(file)[0]
  177. dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
  178. np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台