You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_qat.py 9.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import io
  2. from itertools import product
  3. import numpy as np
  4. import pytest
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine import jit
  7. from megengine import module as M
  8. from megengine import tensor
  9. from megengine.device import get_device_count
  10. from megengine.functional import expand_dims
  11. from megengine.module import (
  12. BatchMatMulActivation,
  13. Conv2d,
  14. ConvBn2d,
  15. ConvRelu2d,
  16. ConvTranspose2d,
  17. ConvTransposeBn2d,
  18. ConvTransposeRelu2d,
  19. DequantStub,
  20. Module,
  21. QuantStub,
  22. )
  23. from megengine.quantization.quantize import (
  24. disable_fake_quant,
  25. enable_fake_quant,
  26. quantize,
  27. quantize_qat,
  28. )
  29. def test_qat_convbn2d():
  30. in_channels = 32
  31. out_channels = 64
  32. kernel_size = 3
  33. for groups, bias in product([1, 4], [True, False]):
  34. module = ConvBn2d(
  35. in_channels, out_channels, kernel_size, groups=groups, bias=bias
  36. )
  37. M.init.normal_(module.bn.weight)
  38. M.init.normal_(module.bn.bias)
  39. module.train()
  40. qat_module = quantize_qat(module, inplace=False)
  41. disable_fake_quant(qat_module)
  42. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  43. normal_outputs = module(inputs)
  44. qat_outputs = qat_module(inputs)
  45. np.testing.assert_allclose(
  46. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  47. )
  48. np.testing.assert_allclose(
  49. module.bn.running_mean.numpy(),
  50. qat_module.bn.running_mean.numpy(),
  51. atol=5e-8,
  52. )
  53. np.testing.assert_allclose(
  54. module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
  55. )
  56. module.eval()
  57. normal_outputs = module(inputs)
  58. qat_module.eval()
  59. qat_outputs = qat_module(inputs)
  60. np.testing.assert_allclose(
  61. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  62. )
  63. def test_qat_convtransposebn2d():
  64. in_channels = 32
  65. out_channels = 64
  66. kernel_size = 3
  67. for groups, bias in product([1, 4], [True, False]):
  68. module = ConvTransposeBn2d(
  69. in_channels=in_channels,
  70. out_channels=out_channels,
  71. kernel_size=kernel_size,
  72. output_padding=0,
  73. groups=groups,
  74. bias=bias,
  75. )
  76. M.init.normal_(module.bn.weight)
  77. M.init.normal_(module.bn.bias)
  78. module.train()
  79. qat_module = quantize_qat(module, inplace=False)
  80. disable_fake_quant(qat_module)
  81. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  82. normal_outputs = module(inputs)
  83. qat_outputs = qat_module(inputs)
  84. np.testing.assert_allclose(
  85. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  86. )
  87. np.testing.assert_allclose(
  88. module.bn.running_mean.numpy(),
  89. qat_module.bn.running_mean.numpy(),
  90. atol=5e-8,
  91. )
  92. np.testing.assert_allclose(
  93. module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
  94. )
  95. module.eval()
  96. normal_outputs = module(inputs)
  97. qat_module.eval()
  98. qat_outputs = qat_module(inputs)
  99. np.testing.assert_allclose(
  100. normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
  101. )
  102. @pytest.mark.parametrize(
  103. "padding, padding_mode",
  104. [
  105. (0, "zeros"),
  106. ((1, 2), "zeros"),
  107. (3, "reflect"),
  108. ((1, 2), "reflect"),
  109. (4, "replicate"),
  110. ((1, 2), "replicate"),
  111. ],
  112. )
  113. def test_qat_conv(padding, padding_mode):
  114. in_channels = 32
  115. out_channels = 64
  116. kernel_size = 3
  117. class TestNet(Module):
  118. def __init__(self, groups, bias):
  119. super().__init__()
  120. self.quant = QuantStub()
  121. self.dequant = DequantStub()
  122. self.conv = Conv2d(
  123. in_channels,
  124. out_channels,
  125. kernel_size,
  126. groups=groups,
  127. bias=bias,
  128. padding=padding,
  129. padding_mode=padding_mode,
  130. )
  131. self.conv_relu = ConvRelu2d(
  132. out_channels, in_channels, kernel_size, groups=groups, bias=bias
  133. )
  134. def forward(self, inp):
  135. out = self.quant(inp)
  136. out = self.conv(out)
  137. out = self.conv_relu(out)
  138. out = self.dequant(out)
  139. return out
  140. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  141. for groups, bias in product([1, 4], [True, False]):
  142. net = TestNet(groups, bias)
  143. net.train()
  144. qat_net = quantize_qat(net, inplace=False)
  145. disable_fake_quant(qat_net)
  146. normal_outputs = net(inputs)
  147. qat_outputs = qat_net(inputs)
  148. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  149. net.eval()
  150. normal_outputs = net(inputs)
  151. qat_net.eval()
  152. qat_outputs = qat_net(inputs)
  153. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  154. @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
  155. def test_qat_batchmatmul_activation():
  156. batch = 4
  157. in_features = 8
  158. out_features = 4
  159. class TestNet(Module):
  160. def __init__(self, bias):
  161. super().__init__()
  162. self.quant = QuantStub()
  163. self.dequant = DequantStub()
  164. self.batch_mm = BatchMatMulActivation(
  165. batch, in_features, out_features, bias=bias
  166. )
  167. def forward(self, inp):
  168. out = self.quant(inp)
  169. out = self.batch_mm(out)
  170. out = self.dequant(out)
  171. return out
  172. inputs = tensor(
  173. np.random.randn(batch, in_features, out_features).astype(np.float32)
  174. )
  175. for bias in (True, False):
  176. net = TestNet(bias)
  177. net.train()
  178. qat_net = quantize_qat(net, inplace=False)
  179. disable_fake_quant(qat_net)
  180. normal_outputs = net(inputs)
  181. qat_outputs = qat_net(inputs)
  182. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  183. net.eval()
  184. normal_outputs = net(inputs)
  185. qat_net.eval()
  186. qat_outputs = qat_net(inputs)
  187. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  188. @pytest.mark.skip(reason="FIXME: abnormal exit")
  189. def test_quantize_batchmatmul_activation():
  190. batch = 4
  191. in_features = 8
  192. out_features = 4
  193. class TestNet(Module):
  194. def __init__(self, bias):
  195. super().__init__()
  196. self.quant = QuantStub()
  197. self.dequant = DequantStub()
  198. self.batch_mm = BatchMatMulActivation(
  199. batch, in_features, out_features, bias=bias
  200. )
  201. def forward(self, inp):
  202. out = self.quant(inp)
  203. out = self.batch_mm(out)
  204. out = expand_dims(out, -1)
  205. out = self.dequant(out)
  206. return out
  207. inputs = tensor(
  208. np.random.randn(batch, in_features, out_features).astype(np.float32)
  209. )
  210. for bias in (True, False):
  211. net = TestNet(bias)
  212. net.train()
  213. qat_net = quantize_qat(net, inplace=False)
  214. disable_fake_quant(qat_net)
  215. normal_outputs = net(inputs)
  216. qat_outputs = qat_net(inputs)
  217. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  218. net.eval()
  219. normal_outputs = net(inputs)
  220. qat_net.eval()
  221. qat_outputs = qat_net(inputs)
  222. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  223. enable_fake_quant(qat_net)
  224. qat_outputs = qat_net(inputs)
  225. qnet = quantize(qat_net, inplace=False)
  226. qnet.eval()
  227. quantize_outputs = qnet(inputs)
  228. np.testing.assert_allclose(
  229. qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
  230. )
  231. @jit.trace(capture_as_const=True)
  232. def f(x):
  233. qnet.eval()
  234. return qnet(x)
  235. f(inputs)
  236. file = io.BytesIO()
  237. f.dump(file, enable_nchw4=True)
  238. file.seek(0)
  239. infer_cg = cgtools.GraphInference(file)[0]
  240. dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
  241. np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)
  242. def test_qat_conv_transpose2d():
  243. in_channels = 32
  244. out_channels = 64
  245. kernel_size = 3
  246. class TestNet(Module):
  247. def __init__(self, bias):
  248. super().__init__()
  249. self.quant = QuantStub()
  250. self.dequant = DequantStub()
  251. self.conv = ConvTranspose2d(
  252. in_channels, out_channels, kernel_size, bias=bias
  253. )
  254. self.conv_transpose2d_relu = ConvTransposeRelu2d(
  255. out_channels, in_channels, kernel_size, bias=bias
  256. )
  257. def forward(self, inp):
  258. out = self.quant(inp)
  259. out = self.conv(out)
  260. out = self.conv_transpose2d_relu(out)
  261. out = self.dequant(out)
  262. return out
  263. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  264. for bias in [True, False]:
  265. net = TestNet(bias)
  266. net.train()
  267. qat_net = quantize_qat(net, inplace=False)
  268. disable_fake_quant(qat_net)
  269. normal_outputs = net(inputs)
  270. qat_outputs = qat_net(inputs)
  271. np.testing.assert_allclose(
  272. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
  273. )
  274. net.eval()
  275. normal_outputs = net(inputs)
  276. qat_net.eval()
  277. qat_outputs = qat_net(inputs)
  278. np.testing.assert_allclose(
  279. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
  280. )