
test_op.py

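
# Unit tests for MegEngine quantized operators: elementwise ops, conv_bias,
# int4 conv_bias, quantized ConvTranspose2d, and quantized matmul.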
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.device import get_cuda_compute_capability, get_device_count
from megengine.functional.elemwise import _elemwise_multi_type, _elwise
from megengine.module.quantized.conv import ConvTranspose2d
from megengine.quantization import QuantMode, create_qparams
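

# Quantize a float tensor to qint8 with the given scale.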
def quant(x, scale):
    x_dtype = dtype.qint8(scale)
    return x.astype(x_dtype)
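

# Simulate int8 quantization in float arithmetic: scale down, round,
# clamp to [-128, 127], then rescale back.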
def fake_quant(x, scale):
    x = x / scale
    x = F.round(x)
    x = F.clip(x, -128, 127)
    x = x * scale
    return x
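

# Each quantized elementwise mode ("q" + kind) should match fake-quantized
# float math. (QuantMode.SYMMERTIC is the spelling used by MegEngine's enum.)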
@pytest.mark.parametrize("kind", ["abs", "sin", "sub", "mul", "fuse_add_tanh"])
def test_elemwise(kind):
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = fake_quant(x1, x1_scale)
    x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale))
    x1_int8 = quant(x1, x1_scale)

    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2_scale = np.float32(np.random.rand() + 1)
    x2 = fake_quant(x2, x2_scale)
    x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale))
    x2_int8 = quant(x2, x2_scale)

    output_scale = np.float32(np.random.rand() + 1)
    output_dtype = dtype.qint8(output_scale)

    quantized_kind = "q" + kind
    if kind in ("abs", "sin"):
        desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
        actual_out = (
            _elemwise_multi_type(
                x1_int8, mode=quantized_kind, dtype=output_dtype
            ).numpy()
            * output_scale
        )
    else:
        desired_out = fake_quant(_elwise(x1, x2, mode=kind), output_scale)
        actual_out = (
            _elemwise_multi_type(
                x1_int8, x2_int8, mode=quantized_kind, dtype=output_dtype
            ).numpy()
            * output_scale
        )
    np.testing.assert_allclose(actual_out, desired_out.numpy())
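

# int8 conv_bias_activation should match float conv2d (+ optional ReLU) after
# fake quantization of the output; the NCHW4 layout is used on CUDA devices.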
@pytest.mark.skipif(
    get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
)
def test_conv_bias():
    inp_scale = np.float32(np.random.rand() + 1)
    w_scale = np.float32(np.random.rand() + 1)
    outp_scale = np.float32(np.random.rand() + 1)
    inp_dtype = dtype.qint8(inp_scale)
    w_dtype = dtype.qint8(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="identity",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = mge.tensor(inpv, dtype=inp_dtype)
        w_int8 = mge.Parameter(wv, dtype=w_dtype)
        b_int32 = mge.Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def convert_to_nchw4(var):
            var = F.reshape(
                var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
            )
            var = F.transpose(var, (0, 1, 3, 4, 2))
            return var

        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "relu":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = convert_to_nchw4(b)
            return F.quantized.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if mge.is_cuda_available() else "NCHW"
        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = F.transpose(result, (0, 1, 4, 2, 3))
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
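

# Same check as test_conv_bias, but with quint4 activations and qint4 weights.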
@pytest.mark.skip(reason="does not support int4 when cuda version is lower than 10.2")
def test_conv_bias_int4():
    inp_scale = 1.5
    w_scale = 2.5
    outp_scale = 1.5
    inp_dtype = dtype.quint4(inp_scale, 0)
    w_dtype = dtype.qint4(w_scale)
    b_dtype = dtype.qint32(inp_scale * w_scale)
    out_dtype = dtype.quint4(outp_scale, 0)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="identity",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = dtype.get_scale(inp_dtype)
        w_scale = dtype.get_scale(w_dtype)
        b_scale = dtype.get_scale(b_dtype)

        inpv = dtype.convert_to_quint4(inp_v * inp_scale, inp_dtype)
        wv = dtype.convert_to_qint4(w_v * w_scale, w_dtype)
        bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_uint4 = mge.Tensor(inpv, dtype=inp_dtype)
        w_int4 = mge.Parameter(wv, dtype=w_dtype)
        b_int32 = mge.Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_uint4.astype("float32")
        w_fp32 = w_int4.astype("float32")
        b_fp32 = b_int32.astype("float32")

        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "relu":
                return F.relu(O)
            else:
                return O

        def run_conv_bias(inp, w, b):
            b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
            return F.quantized.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_uint4, w_int4, b_int32).astype("float32")
        expected = F.flatten(expected)
        result = F.flatten(result)
        np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
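

# The quantized ConvTranspose2d module should match float conv_transpose2d
# after quantizing the output to qint8.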
@pytest.mark.skipif(
    get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61,
    reason="does not support int8 when gpu compute capability less than 6.1",
)
def test_conv_transpose2d():
    rng = np.random.RandomState(seed=2021)

    def test_func(
        N,
        IC,
        IH,
        IW,
        OC,
        KH,
        KW,
        SH,
        SW,
        PH,
        PW,
        DH,
        DW,
        groups=1,
        has_bias=True,
        conv_mode: str = "cross_correlation",
        compute_mode: str = "default",
    ):
        inp_scale = np.float32(rng.uniform(low=0.04, high=0.06))
        weight_scale = np.float32(rng.uniform(low=0.04, high=0.06))
        bias_scale = inp_scale * weight_scale
        out_scale = np.float32(rng.uniform(low=0.04, high=0.06))

        inp_dtype = dtype.qint8(inp_scale)
        weight_dtype = dtype.qint8(weight_scale)
        bias_dtype = dtype.qint32(bias_scale)
        out_dtype = dtype.qint8(out_scale)

        inp_fp32 = rng.uniform(low=-1, high=1, size=(N, IC, IH, IW)).astype(np.float32)
        weight_fp32 = rng.uniform(low=-1, high=1, size=(IC, OC, KH, KW)).astype(
            np.float32
        )
        bias_fp32 = rng.uniform(low=-1, high=1, size=(1, OC, 1, 1)).astype(np.float32)

        inp_int8 = dtype.convert_to_qint8(inp_fp32, inp_dtype)
        weight_int8 = dtype.convert_to_qint8(weight_fp32, weight_dtype)
        bias_int32 = dtype.convert_to_qint32(bias_fp32, bias_dtype)

        inp_int8 = mge.tensor(inp_int8, dtype=inp_dtype)
        weight_int8 = mge.Parameter(weight_int8, dtype=weight_dtype)
        bias_int32 = mge.Parameter(bias_int32, dtype=bias_dtype)

        inp_fp32 = inp_int8.astype("float32")
        weight_fp32 = weight_int8.astype("float32")
        bias_fp32 = bias_int32.astype("float32")

        expected = F.conv_transpose2d(
            inp_fp32,
            weight_fp32,
            bias_fp32 if has_bias else None,
            stride=(SH, SW),
            padding=(PH, PW),
            dilation=(DH, DW),
            groups=groups,
            conv_mode=conv_mode,
            compute_mode=compute_mode,
        )
        expected = dtype.convert_to_qint8(expected.numpy(), out_dtype)
        expected = dtype.convert_from_qint8(expected)

        conv_transpose2d = ConvTranspose2d(
            in_channels=IC,
            out_channels=OC,
            kernel_size=(KH, KW),
            stride=(SH, SW),
            padding=(PH, PW),
            dilation=(DH, DW),
            groups=groups,
            bias=has_bias,
            conv_mode=conv_mode,
            compute_mode=compute_mode,
            dtype=out_dtype,
        )
        conv_transpose2d.weight = mge.Parameter(weight_int8)
        if has_bias:
            conv_transpose2d.bias = mge.Parameter(bias_int32)
        result = conv_transpose2d.forward(inp_int8).numpy()
        result = dtype.convert_from_qint8(result)
        np.testing.assert_allclose(result, expected, atol=out_scale)

    test_func(1, 4, 1, 1, 4, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
    test_func(2, 4, 3, 1, 8, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
    test_func(4, 4, 16, 16, 8, 3, 3, 1, 1, 1, 1, 1, 1, 1, False)
    test_func(32, 64, 36, 28, 16, 3, 2, 1, 3, 1, 0, 1, 1, 1, False)
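

# The output scale of an int8 x int8 matmul should be the product of the
# two input scales.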
def test_matmul():
    inp_scale = np.float32(np.random.rand())
    weight_scale = np.float32(np.random.rand())
    inp_dtype = dtype.qint8(inp_scale)
    weight_dtype = dtype.qint8(weight_scale)
    inp_data = np.random.random((3, 12))
    weight_data = np.random.random((5, 12))
    inp_int8 = mge.tensor(dtype.convert_to_qint8(inp_data, inp_dtype))
    weight_int8 = mge.tensor(dtype.convert_to_qint8(weight_data, weight_dtype))
    res = F.matmul(inp_int8, weight_int8, transpose_b=True)
    res_scale = dtype.get_scale(res.dtype)
    np.testing.assert_allclose(inp_scale * weight_scale, res_scale)