You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_qat.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import io
  2. from itertools import product
  3. import numpy as np
  4. import pytest
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine import jit, tensor
  7. from megengine.device import get_device_count
  8. from megengine.functional import expand_dims
  9. from megengine.module import (
  10. BatchMatMulActivation,
  11. Conv2d,
  12. ConvBn2d,
  13. ConvRelu2d,
  14. ConvTranspose2d,
  15. DequantStub,
  16. Module,
  17. QuantStub,
  18. )
  19. from megengine.quantization.quantize import (
  20. disable_fake_quant,
  21. enable_fake_quant,
  22. quantize,
  23. quantize_qat,
  24. )
  25. def test_qat_convbn2d():
  26. in_channels = 32
  27. out_channels = 64
  28. kernel_size = 3
  29. for groups, bias in product([1, 4], [True, False]):
  30. module = ConvBn2d(
  31. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  32. )
  33. module.train()
  34. qat_module = quantize_qat(module, inplace=False)
  35. disable_fake_quant(qat_module)
  36. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  37. normal_outputs = module(inputs)
  38. qat_outputs = qat_module(inputs)
  39. np.testing.assert_allclose(
  40. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  41. )
  42. np.testing.assert_allclose(
  43. module.bn.running_mean.numpy(),
  44. qat_module.bn.running_mean.numpy(),
  45. atol=5e-8,
  46. )
  47. np.testing.assert_allclose(
  48. module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
  49. )
  50. module.eval()
  51. normal_outputs = module(inputs)
  52. qat_module.eval()
  53. qat_outputs = qat_module(inputs)
  54. np.testing.assert_allclose(
  55. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  56. )
  57. @pytest.mark.parametrize(
  58. "padding, padding_mode",
  59. [
  60. (0, "zeros"),
  61. ((1, 2), "zeros"),
  62. (3, "reflect"),
  63. ((1, 2), "reflect"),
  64. (4, "replicate"),
  65. ((1, 2), "replicate"),
  66. ],
  67. )
  68. def test_qat_conv(padding, padding_mode):
  69. in_channels = 32
  70. out_channels = 64
  71. kernel_size = 3
  72. class TestNet(Module):
  73. def __init__(self, groups, bias):
  74. super().__init__()
  75. self.quant = QuantStub()
  76. self.dequant = DequantStub()
  77. self.conv = Conv2d(
  78. in_channels,
  79. out_channels,
  80. kernel_size,
  81. groups=groups,
  82. bias=bias,
  83. padding=padding,
  84. padding_mode=padding_mode,
  85. )
  86. self.conv_relu = ConvRelu2d(
  87. out_channels, in_channels, kernel_size, groups=groups, bias=bias
  88. )
  89. def forward(self, inp):
  90. out = self.quant(inp)
  91. out = self.conv(out)
  92. out = self.conv_relu(out)
  93. out = self.dequant(out)
  94. return out
  95. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  96. for groups, bias in product([1, 4], [True, False]):
  97. net = TestNet(groups, bias)
  98. net.train()
  99. qat_net = quantize_qat(net, inplace=False)
  100. disable_fake_quant(qat_net)
  101. normal_outputs = net(inputs)
  102. qat_outputs = qat_net(inputs)
  103. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  104. net.eval()
  105. normal_outputs = net(inputs)
  106. qat_net.eval()
  107. qat_outputs = qat_net(inputs)
  108. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  109. @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
  110. def test_qat_batchmatmul_activation():
  111. batch = 4
  112. in_features = 8
  113. out_features = 4
  114. class TestNet(Module):
  115. def __init__(self, bias):
  116. super().__init__()
  117. self.quant = QuantStub()
  118. self.dequant = DequantStub()
  119. self.batch_mm = BatchMatMulActivation(
  120. batch, in_features, out_features, bias=bias
  121. )
  122. def forward(self, inp):
  123. out = self.quant(inp)
  124. out = self.batch_mm(out)
  125. out = self.dequant(out)
  126. return out
  127. inputs = tensor(
  128. np.random.randn(batch, in_features, out_features).astype(np.float32)
  129. )
  130. for bias in (True, False):
  131. net = TestNet(bias)
  132. net.train()
  133. qat_net = quantize_qat(net, inplace=False)
  134. disable_fake_quant(qat_net)
  135. normal_outputs = net(inputs)
  136. qat_outputs = qat_net(inputs)
  137. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  138. net.eval()
  139. normal_outputs = net(inputs)
  140. qat_net.eval()
  141. qat_outputs = qat_net(inputs)
  142. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  143. @pytest.mark.skip(reason="FIXME: abnormal exit")
  144. def test_quantize_batchmatmul_activation():
  145. batch = 4
  146. in_features = 8
  147. out_features = 4
  148. class TestNet(Module):
  149. def __init__(self, bias):
  150. super().__init__()
  151. self.quant = QuantStub()
  152. self.dequant = DequantStub()
  153. self.batch_mm = BatchMatMulActivation(
  154. batch, in_features, out_features, bias=bias
  155. )
  156. def forward(self, inp):
  157. out = self.quant(inp)
  158. out = self.batch_mm(out)
  159. out = expand_dims(out, -1)
  160. out = self.dequant(out)
  161. return out
  162. inputs = tensor(
  163. np.random.randn(batch, in_features, out_features).astype(np.float32)
  164. )
  165. for bias in (True, False):
  166. net = TestNet(bias)
  167. net.train()
  168. qat_net = quantize_qat(net, inplace=False)
  169. disable_fake_quant(qat_net)
  170. normal_outputs = net(inputs)
  171. qat_outputs = qat_net(inputs)
  172. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  173. net.eval()
  174. normal_outputs = net(inputs)
  175. qat_net.eval()
  176. qat_outputs = qat_net(inputs)
  177. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  178. enable_fake_quant(qat_net)
  179. qat_outputs = qat_net(inputs)
  180. qnet = quantize(qat_net, inplace=False)
  181. qnet.eval()
  182. quantize_outputs = qnet(inputs)
  183. np.testing.assert_allclose(
  184. qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
  185. )
  186. @jit.trace(capture_as_const=True)
  187. def f(x):
  188. qnet.eval()
  189. return qnet(x)
  190. f(inputs)
  191. file = io.BytesIO()
  192. f.dump(file, enable_nchw4=True)
  193. file.seek(0)
  194. infer_cg = cgtools.GraphInference(file)[0]
  195. dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
  196. np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)
  197. def test_qat_conv_transpose2d():
  198. in_channels = 32
  199. out_channels = 64
  200. kernel_size = 3
  201. class TestNet(Module):
  202. def __init__(self, bias):
  203. super().__init__()
  204. self.quant = QuantStub()
  205. self.dequant = DequantStub()
  206. self.conv = ConvTranspose2d(
  207. in_channels, out_channels, kernel_size, bias=bias
  208. )
  209. def forward(self, inp):
  210. out = self.quant(inp)
  211. out = self.conv(out)
  212. out = self.dequant(out)
  213. return out
  214. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  215. for bias in [True, False]:
  216. net = TestNet(bias)
  217. net.train()
  218. qat_net = quantize_qat(net, inplace=False)
  219. disable_fake_quant(qat_net)
  220. normal_outputs = net(inputs)
  221. qat_outputs = qat_net(inputs)
  222. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  223. net.eval()
  224. normal_outputs = net(inputs)
  225. qat_net.eval()
  226. qat_outputs = qat_net(inputs)
  227. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())