You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_qat.py 6.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import io
  2. from itertools import product
  3. import numpy as np
  4. import pytest
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine import jit, tensor
  7. from megengine.distributed.helper import get_device_count_by_fork
  8. from megengine.functional import expand_dims
  9. from megengine.module import (
  10. BatchMatMulActivation,
  11. Conv2d,
  12. ConvBn2d,
  13. ConvRelu2d,
  14. DequantStub,
  15. Module,
  16. QuantStub,
  17. )
  18. from megengine.quantization.quantize import (
  19. disable_fake_quant,
  20. enable_fake_quant,
  21. quantize,
  22. quantize_qat,
  23. )
  24. def test_qat_convbn2d():
  25. in_channels = 32
  26. out_channels = 64
  27. kernel_size = 3
  28. for groups, bias in product([1, 4], [True, False]):
  29. module = ConvBn2d(
  30. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  31. )
  32. module.train()
  33. qat_module = quantize_qat(module, inplace=False)
  34. disable_fake_quant(qat_module)
  35. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  36. normal_outputs = module(inputs)
  37. qat_outputs = qat_module(inputs)
  38. np.testing.assert_allclose(
  39. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  40. )
  41. np.testing.assert_allclose(
  42. module.bn.running_mean.numpy(),
  43. qat_module.bn.running_mean.numpy(),
  44. atol=5e-8,
  45. )
  46. np.testing.assert_allclose(
  47. module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
  48. )
  49. module.eval()
  50. normal_outputs = module(inputs)
  51. qat_module.eval()
  52. qat_outputs = qat_module(inputs)
  53. np.testing.assert_allclose(
  54. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  55. )
  56. def test_qat_conv():
  57. in_channels = 32
  58. out_channels = 64
  59. kernel_size = 3
  60. class TestNet(Module):
  61. def __init__(self, groups, bias):
  62. super().__init__()
  63. self.quant = QuantStub()
  64. self.dequant = DequantStub()
  65. self.conv = Conv2d(
  66. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  67. )
  68. self.conv_relu = ConvRelu2d(
  69. out_channels, in_channels, kernel_size, groups=groups, bias=bias
  70. )
  71. def forward(self, inp):
  72. out = self.quant(inp)
  73. out = self.conv(out)
  74. out = self.conv_relu(out)
  75. out = self.dequant(out)
  76. return out
  77. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  78. for groups, bias in product([1, 4], [True, False]):
  79. net = TestNet(groups, bias)
  80. net.train()
  81. qat_net = quantize_qat(net, inplace=False)
  82. disable_fake_quant(qat_net)
  83. normal_outputs = net(inputs)
  84. qat_outputs = qat_net(inputs)
  85. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  86. net.eval()
  87. normal_outputs = net(inputs)
  88. qat_net.eval()
  89. qat_outputs = qat_net(inputs)
  90. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  91. @pytest.mark.skipif(
  92. get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
  93. )
  94. def test_qat_batchmatmul_activation():
  95. batch = 4
  96. in_features = 8
  97. out_features = 4
  98. class TestNet(Module):
  99. def __init__(self, bias):
  100. super().__init__()
  101. self.quant = QuantStub()
  102. self.dequant = DequantStub()
  103. self.batch_mm = BatchMatMulActivation(
  104. batch, in_features, out_features, bias=bias
  105. )
  106. def forward(self, inp):
  107. out = self.quant(inp)
  108. out = self.batch_mm(out)
  109. out = self.dequant(out)
  110. return out
  111. inputs = tensor(
  112. np.random.randn(batch, in_features, out_features).astype(np.float32)
  113. )
  114. for bias in (True, False):
  115. net = TestNet(bias)
  116. net.train()
  117. qat_net = quantize_qat(net, inplace=False)
  118. disable_fake_quant(qat_net)
  119. normal_outputs = net(inputs)
  120. qat_outputs = qat_net(inputs)
  121. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  122. net.eval()
  123. normal_outputs = net(inputs)
  124. qat_net.eval()
  125. qat_outputs = qat_net(inputs)
  126. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  127. @pytest.mark.skip(reason="FIXME: abnormal exit")
  128. def test_quantize_batchmatmul_activation():
  129. batch = 4
  130. in_features = 8
  131. out_features = 4
  132. class TestNet(Module):
  133. def __init__(self, bias):
  134. super().__init__()
  135. self.quant = QuantStub()
  136. self.dequant = DequantStub()
  137. self.batch_mm = BatchMatMulActivation(
  138. batch, in_features, out_features, bias=bias
  139. )
  140. def forward(self, inp):
  141. out = self.quant(inp)
  142. out = self.batch_mm(out)
  143. out = expand_dims(out, -1)
  144. out = self.dequant(out)
  145. return out
  146. inputs = tensor(
  147. np.random.randn(batch, in_features, out_features).astype(np.float32)
  148. )
  149. for bias in (True, False):
  150. net = TestNet(bias)
  151. net.train()
  152. qat_net = quantize_qat(net, inplace=False)
  153. disable_fake_quant(qat_net)
  154. normal_outputs = net(inputs)
  155. qat_outputs = qat_net(inputs)
  156. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  157. net.eval()
  158. normal_outputs = net(inputs)
  159. qat_net.eval()
  160. qat_outputs = qat_net(inputs)
  161. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  162. enable_fake_quant(qat_net)
  163. qat_outputs = qat_net(inputs)
  164. qnet = quantize(qat_net, inplace=False)
  165. qnet.eval()
  166. quantize_outputs = qnet(inputs)
  167. np.testing.assert_allclose(
  168. qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
  169. )
  170. @jit.trace(capture_as_const=True)
  171. def f(x):
  172. qnet.eval()
  173. return qnet(x)
  174. f(inputs)
  175. file = io.BytesIO()
  176. f.dump(file, enable_nchw4=True)
  177. file.seek(0)
  178. infer_cg = cgtools.GraphInference(file)[0]
  179. dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
  180. np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台