You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_qat.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. import io
  2. from itertools import product
  3. import numpy as np
  4. import pytest
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine import jit, tensor
  7. from megengine.device import get_device_count
  8. from megengine.functional import expand_dims
  9. from megengine.module import (
  10. BatchMatMulActivation,
  11. Conv2d,
  12. ConvBn2d,
  13. ConvRelu2d,
  14. ConvTranspose2d,
  15. DequantStub,
  16. Module,
  17. QuantStub,
  18. )
  19. from megengine.quantization.quantize import (
  20. disable_fake_quant,
  21. enable_fake_quant,
  22. quantize,
  23. quantize_qat,
  24. )
  25. def test_qat_convbn2d():
  26. in_channels = 32
  27. out_channels = 64
  28. kernel_size = 3
  29. for groups, bias in product([1, 4], [True, False]):
  30. module = ConvBn2d(
  31. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  32. )
  33. module.train()
  34. qat_module = quantize_qat(module, inplace=False)
  35. disable_fake_quant(qat_module)
  36. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  37. normal_outputs = module(inputs)
  38. qat_outputs = qat_module(inputs)
  39. np.testing.assert_allclose(
  40. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  41. )
  42. np.testing.assert_allclose(
  43. module.bn.running_mean.numpy(),
  44. qat_module.bn.running_mean.numpy(),
  45. atol=5e-8,
  46. )
  47. np.testing.assert_allclose(
  48. module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
  49. )
  50. module.eval()
  51. normal_outputs = module(inputs)
  52. qat_module.eval()
  53. qat_outputs = qat_module(inputs)
  54. np.testing.assert_allclose(
  55. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  56. )
  57. def test_qat_conv():
  58. in_channels = 32
  59. out_channels = 64
  60. kernel_size = 3
  61. class TestNet(Module):
  62. def __init__(self, groups, bias):
  63. super().__init__()
  64. self.quant = QuantStub()
  65. self.dequant = DequantStub()
  66. self.conv = Conv2d(
  67. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  68. )
  69. self.conv_relu = ConvRelu2d(
  70. out_channels, in_channels, kernel_size, groups=groups, bias=bias
  71. )
  72. def forward(self, inp):
  73. out = self.quant(inp)
  74. out = self.conv(out)
  75. out = self.conv_relu(out)
  76. out = self.dequant(out)
  77. return out
  78. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  79. for groups, bias in product([1, 4], [True, False]):
  80. net = TestNet(groups, bias)
  81. net.train()
  82. qat_net = quantize_qat(net, inplace=False)
  83. disable_fake_quant(qat_net)
  84. normal_outputs = net(inputs)
  85. qat_outputs = qat_net(inputs)
  86. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  87. net.eval()
  88. normal_outputs = net(inputs)
  89. qat_net.eval()
  90. qat_outputs = qat_net(inputs)
  91. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  92. @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
  93. def test_qat_batchmatmul_activation():
  94. batch = 4
  95. in_features = 8
  96. out_features = 4
  97. class TestNet(Module):
  98. def __init__(self, bias):
  99. super().__init__()
  100. self.quant = QuantStub()
  101. self.dequant = DequantStub()
  102. self.batch_mm = BatchMatMulActivation(
  103. batch, in_features, out_features, bias=bias
  104. )
  105. def forward(self, inp):
  106. out = self.quant(inp)
  107. out = self.batch_mm(out)
  108. out = self.dequant(out)
  109. return out
  110. inputs = tensor(
  111. np.random.randn(batch, in_features, out_features).astype(np.float32)
  112. )
  113. for bias in (True, False):
  114. net = TestNet(bias)
  115. net.train()
  116. qat_net = quantize_qat(net, inplace=False)
  117. disable_fake_quant(qat_net)
  118. normal_outputs = net(inputs)
  119. qat_outputs = qat_net(inputs)
  120. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  121. net.eval()
  122. normal_outputs = net(inputs)
  123. qat_net.eval()
  124. qat_outputs = qat_net(inputs)
  125. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
@pytest.mark.skip(reason="FIXME: abnormal exit")
def test_quantize_batchmatmul_activation():
    """End-to-end QAT -> quantize flow for BatchMatMulActivation.

    Verifies float vs. QAT equality with fake-quant disabled, then enables
    fake-quant, converts to a fully quantized net, traces and dumps it
    (nchw4 layout), and checks the dumped graph's inference result against
    the in-process quantized output.
    """
    batch = 4
    in_features = 8
    out_features = 4

    class TestNet(Module):
        # QuantStub -> BatchMatMulActivation -> expand_dims -> DequantStub
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            # Adds a trailing axis before dequant; presumably so the dumped
            # graph carries a 4-d tensor for the nchw4 dump -- TODO confirm.
            out = expand_dims(out, -1)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        # Stage 1: with fake-quant disabled the QAT net must match the
        # float net exactly, in both train and eval mode.
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
        # Stage 2: enable fake-quant and compare the QAT net against its
        # fully quantized conversion.
        enable_fake_quant(qat_net)
        qat_outputs = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantize_outputs = qnet(inputs)
        np.testing.assert_allclose(
            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
        )

        # Stage 3: trace + dump the quantized net, then run the dumped graph
        # and compare against the in-process quantized output.
        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        f(inputs)
        file = io.BytesIO()
        f.dump(file, enable_nchw4=True)
        file.seek(0)
        # NOTE(review): indexing GraphInference(file)[0] looks unusual --
        # verify against the cgtools API (possibly related to the FIXME skip).
        infer_cg = cgtools.GraphInference(file)[0]
        dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
        np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)
  180. def test_qat_conv_transpose2d():
  181. in_channels = 32
  182. out_channels = 64
  183. kernel_size = 3
  184. class TestNet(Module):
  185. def __init__(self, bias):
  186. super().__init__()
  187. self.quant = QuantStub()
  188. self.dequant = DequantStub()
  189. self.conv = ConvTranspose2d(
  190. in_channels, out_channels, kernel_size, bias=bias
  191. )
  192. def forward(self, inp):
  193. out = self.quant(inp)
  194. out = self.conv(out)
  195. out = self.dequant(out)
  196. return out
  197. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  198. for bias in [True, False]:
  199. net = TestNet(bias)
  200. net.train()
  201. qat_net = quantize_qat(net, inplace=False)
  202. disable_fake_quant(qat_net)
  203. normal_outputs = net(inputs)
  204. qat_outputs = qat_net(inputs)
  205. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  206. net.eval()
  207. normal_outputs = net(inputs)
  208. qat_net.eval()
  209. qat_outputs = qat_net(inputs)
  210. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台