You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_qat.py 10 kB

  1. import io
  2. from itertools import product
  3. import numpy as np
  4. import pytest
  5. import megengine.utils.comp_graph_tools as cgtools
  6. from megengine import jit
  7. from megengine import module as M
  8. from megengine import tensor
  9. from megengine.device import get_device_count
  10. from megengine.functional import expand_dims
  11. from megengine.module import (
  12. BatchMatMulActivation,
  13. Conv2d,
  14. ConvBn2d,
  15. ConvRelu2d,
  16. ConvTranspose2d,
  17. ConvTransposeBn2d,
  18. ConvTransposeRelu2d,
  19. DequantStub,
  20. Module,
  21. QuantStub,
  22. )
  23. from megengine.quantization.quantize import (
  24. disable_fake_quant,
  25. enable_fake_quant,
  26. quantize,
  27. quantize_qat,
  28. )
  29. def test_qat_convbn2d():
  30. in_channels = 32
  31. out_channels = 64
  32. kernel_size = 3
  33. class TestNet(Module):
  34. def __init__(self, groups, bias):
  35. super().__init__()
  36. self.quant = QuantStub()
  37. self.dequant = DequantStub()
  38. self.conv_bn = ConvBn2d(
  39. in_channels, out_channels, kernel_size, groups=groups, bias=bias,
  40. )
  41. def forward(self, inp):
  42. out = self.quant(inp)
  43. out = self.conv_bn(out)
  44. out = self.dequant(out)
  45. return out
  46. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  47. for groups, bias in product([1, 4], [True, False]):
  48. net = TestNet(groups, bias)
  49. net.train()
  50. qat_net = quantize_qat(net, inplace=False)
  51. disable_fake_quant(qat_net)
  52. normal_outputs = net(inputs)
  53. qat_outputs = qat_net(inputs)
  54. np.testing.assert_allclose(
  55. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
  56. )
  57. np.testing.assert_allclose(
  58. net.conv_bn.bn.running_mean.numpy(),
  59. qat_net.conv_bn.bn.running_mean.numpy(),
  60. atol=5e-8,
  61. )
  62. np.testing.assert_allclose(
  63. net.conv_bn.bn.running_var.numpy(),
  64. qat_net.conv_bn.bn.running_var.numpy(),
  65. atol=5e-7,
  66. )
  67. net.eval()
  68. normal_outputs = net(inputs)
  69. qat_net.eval()
  70. qat_outputs = qat_net(inputs)
  71. np.testing.assert_allclose(
  72. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
  73. )
  74. def test_qat_convtransposebn2d():
  75. in_channels = 32
  76. out_channels = 64
  77. kernel_size = 3
  78. class TestNet(Module):
  79. def __init__(self, groups, bias):
  80. super().__init__()
  81. self.quant = QuantStub()
  82. self.dequant = DequantStub()
  83. self.conv_transpose_bn = ConvTransposeBn2d(
  84. in_channels, out_channels, kernel_size, groups=groups, bias=bias,
  85. )
  86. def forward(self, inp):
  87. out = self.quant(inp)
  88. out = self.conv_transpose_bn(out)
  89. out = self.dequant(out)
  90. return out
  91. for groups, bias in product([1, 4], [True, False]):
  92. net = TestNet(groups, bias)
  93. net.train()
  94. qat_net = quantize_qat(net, inplace=False)
  95. disable_fake_quant(qat_net)
  96. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  97. normal_outputs = net(inputs)
  98. qat_outputs = qat_net(inputs)
  99. np.testing.assert_allclose(
  100. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
  101. )
  102. np.testing.assert_allclose(
  103. net.conv_transpose_bn.bn.running_var.numpy(),
  104. qat_net.conv_transpose_bn.bn.running_var.numpy(),
  105. atol=5e-7,
  106. )
  107. net.eval()
  108. normal_outputs = net(inputs)
  109. qat_net.eval()
  110. qat_outputs = qat_net(inputs)
  111. np.testing.assert_allclose(
  112. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
  113. )
  114. @pytest.mark.parametrize(
  115. "padding, padding_mode",
  116. [
  117. (0, "zeros"),
  118. ((1, 2), "zeros"),
  119. (3, "reflect"),
  120. ((1, 2), "reflect"),
  121. (4, "replicate"),
  122. ((1, 2), "replicate"),
  123. ],
  124. )
  125. def test_qat_conv(padding, padding_mode):
  126. in_channels = 32
  127. out_channels = 64
  128. kernel_size = 3
  129. class TestNet(Module):
  130. def __init__(self, groups, bias):
  131. super().__init__()
  132. self.quant = QuantStub()
  133. self.dequant = DequantStub()
  134. self.conv = Conv2d(
  135. in_channels,
  136. out_channels,
  137. kernel_size,
  138. groups=groups,
  139. bias=bias,
  140. padding=padding,
  141. padding_mode=padding_mode,
  142. )
  143. self.conv_relu = ConvRelu2d(
  144. out_channels, in_channels, kernel_size, groups=groups, bias=bias
  145. )
  146. def forward(self, inp):
  147. out = self.quant(inp)
  148. out = self.conv(out)
  149. out = self.conv_relu(out)
  150. out = self.dequant(out)
  151. return out
  152. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  153. for groups, bias in product([1, 4], [True, False]):
  154. net = TestNet(groups, bias)
  155. net.train()
  156. qat_net = quantize_qat(net, inplace=False)
  157. disable_fake_quant(qat_net)
  158. normal_outputs = net(inputs)
  159. qat_outputs = qat_net(inputs)
  160. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  161. net.eval()
  162. normal_outputs = net(inputs)
  163. qat_net.eval()
  164. qat_outputs = qat_net(inputs)
  165. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  166. @pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
  167. def test_qat_batchmatmul_activation():
  168. batch = 4
  169. in_features = 8
  170. out_features = 4
  171. class TestNet(Module):
  172. def __init__(self, bias):
  173. super().__init__()
  174. self.quant = QuantStub()
  175. self.dequant = DequantStub()
  176. self.batch_mm = BatchMatMulActivation(
  177. batch, in_features, out_features, bias=bias
  178. )
  179. def forward(self, inp):
  180. out = self.quant(inp)
  181. out = self.batch_mm(out)
  182. out = self.dequant(out)
  183. return out
  184. inputs = tensor(
  185. np.random.randn(batch, in_features, out_features).astype(np.float32)
  186. )
  187. for bias in (True, False):
  188. net = TestNet(bias)
  189. net.train()
  190. qat_net = quantize_qat(net, inplace=False)
  191. disable_fake_quant(qat_net)
  192. normal_outputs = net(inputs)
  193. qat_outputs = qat_net(inputs)
  194. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
  195. net.eval()
  196. normal_outputs = net(inputs)
  197. qat_net.eval()
  198. qat_outputs = qat_net(inputs)
  199. np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
@pytest.mark.skip(reason="FIXME: abnormal exit")
def test_quantize_batchmatmul_activation():
    # End-to-end check for BatchMatMulActivation: float vs QAT with
    # fake-quant off, then QAT with fake-quant on vs the fully quantized
    # net, and finally vs a traced-and-dumped graph reloaded through
    # GraphInference.
    batch = 4
    in_features = 8
    out_features = 4

    class TestNet(Module):
        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.batch_mm(out)
            # presumably the extra trailing axis is needed for the nchw4
            # dump path below — TODO confirm
            out = expand_dims(out, -1)
            out = self.dequant(out)
            return out

    inputs = tensor(
        np.random.randn(batch, in_features, out_features).astype(np.float32)
    )
    for bias in (True, False):
        net = TestNet(bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        # Fake-quant off: QAT output should equal the float output.
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
        # Re-enable fake-quant and compare QAT against the truly
        # quantized network.
        enable_fake_quant(qat_net)
        qat_outputs = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantize_outputs = qnet(inputs)
        np.testing.assert_allclose(
            qat_outputs.numpy(), quantize_outputs.numpy(), atol=1e-6
        )

        # Trace the quantized net, dump it with nchw4 layout enabled, and
        # re-run it through the graph-inference path.
        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        f(inputs)
        file = io.BytesIO()
        f.dump(file, enable_nchw4=True)
        file.seek(0)
        # NOTE(review): the [0] subscript on GraphInference looks suspicious
        # (this test is skipped with "abnormal exit") — verify the intended
        # cgtools API before un-skipping.
        infer_cg = cgtools.GraphInference(file)[0]
        dumped_outputs = list(infer_cg.run(inputs.numpy()).values())[0]
        np.testing.assert_allclose(quantize_outputs.numpy(), dumped_outputs, atol=1e-6)
  254. def test_qat_conv_transpose2d():
  255. in_channels = 32
  256. out_channels = 64
  257. kernel_size = 3
  258. class TestNet(Module):
  259. def __init__(self, bias):
  260. super().__init__()
  261. self.quant = QuantStub()
  262. self.dequant = DequantStub()
  263. self.conv = ConvTranspose2d(
  264. in_channels, out_channels, kernel_size, bias=bias
  265. )
  266. self.conv_transpose2d_relu = ConvTransposeRelu2d(
  267. out_channels, in_channels, kernel_size, bias=bias
  268. )
  269. def forward(self, inp):
  270. out = self.quant(inp)
  271. out = self.conv(out)
  272. out = self.conv_transpose2d_relu(out)
  273. out = self.dequant(out)
  274. return out
  275. inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
  276. for bias in [True, False]:
  277. net = TestNet(bias)
  278. net.train()
  279. qat_net = quantize_qat(net, inplace=False)
  280. disable_fake_quant(qat_net)
  281. normal_outputs = net(inputs)
  282. qat_outputs = qat_net(inputs)
  283. np.testing.assert_allclose(
  284. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
  285. )
  286. net.eval()
  287. normal_outputs = net(inputs)
  288. qat_net.eval()
  289. qat_outputs = qat_net(inputs)
  290. np.testing.assert_allclose(
  291. normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6
  292. )