You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_op.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import numpy as np
  9. import pytest
  10. import megengine as mge
  11. import megengine.functional as F
  12. from megengine.core.tensor import dtype
  13. from megengine.device import get_cuda_compute_capability, get_device_count
  14. from megengine.functional.elemwise import _elemwise_multi_type, _elwise
  15. from megengine.module.quantized.conv import ConvTranspose2d
  16. from megengine.quantization import QuantMode, create_qparams
  17. def quant(x, scale):
  18. x_dtype = dtype.qint8(scale)
  19. return x.astype(x_dtype)
  20. def fake_quant(x, scale):
  21. x = x / scale
  22. x = F.round(x)
  23. x = F.clip(x, -128, 127)
  24. x = x * scale
  25. return x
  26. @pytest.mark.parametrize("kind", ["abs", "sin", "sub", "mul", "fuse_add_tanh"])
  27. def test_elemwise(kind):
  28. x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  29. x1_scale = np.float32(np.random.rand() + 1)
  30. x1 = fake_quant(x1, x1_scale)
  31. x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale))
  32. x1_int8 = quant(x1, x1_scale)
  33. x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
  34. x2_scale = np.float32(np.random.rand() + 1)
  35. x2 = fake_quant(x2, x2_scale)
  36. x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale))
  37. x2_int8 = quant(x2, x2_scale)
  38. output_scale = np.float32(np.random.rand() + 1)
  39. output_dtype = dtype.qint8(output_scale)
  40. quantized_kind = "q" + kind
  41. if kind in ("abs", "sin"):
  42. desired_out = fake_quant(_elwise(x1, mode=kind), output_scale)
  43. actual_out = (
  44. _elemwise_multi_type(
  45. x1_int8, mode=quantized_kind, dtype=output_dtype
  46. ).numpy()
  47. * output_scale
  48. )
  49. else:
  50. desired_out = fake_quant(_elwise(x1, x2, mode=kind), output_scale)
  51. actual_out = (
  52. _elemwise_multi_type(
  53. x1_int8, x2_int8, mode=quantized_kind, dtype=output_dtype
  54. ).numpy()
  55. * output_scale
  56. )
  57. np.testing.assert_allclose(actual_out, desired_out.numpy())
  58. @pytest.mark.skipif(
  59. get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
  60. )
  61. def test_conv_bias():
  62. inp_scale = np.float32(np.random.rand() + 1)
  63. w_scale = np.float32(np.random.rand() + 1)
  64. outp_scale = np.float32(np.random.rand() + 1)
  65. inp_dtype = dtype.qint8(inp_scale)
  66. w_dtype = dtype.qint8(w_scale)
  67. b_dtype = dtype.qint32(inp_scale * w_scale)
  68. out_dtype = dtype.qint8(outp_scale)
  69. def run(
  70. N,
  71. IC,
  72. OC,
  73. IH,
  74. IW,
  75. KH,
  76. KW,
  77. PH,
  78. PW,
  79. SH,
  80. SW,
  81. has_bias=True,
  82. nonlinear_mode="identity",
  83. ):
  84. inp_v = np.random.normal(size=(N, IC, IH, IW))
  85. w_v = np.random.normal(size=(OC, IC, KH, KW))
  86. b_v = np.random.normal(size=(1, OC, 1, 1))
  87. inp_scale = dtype.get_scale(inp_dtype)
  88. w_scale = dtype.get_scale(w_dtype)
  89. b_scale = dtype.get_scale(b_dtype)
  90. inpv = dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
  91. wv = dtype.convert_to_qint8(w_v * w_scale, w_dtype)
  92. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  93. inp_int8 = mge.tensor(inpv, dtype=inp_dtype)
  94. w_int8 = mge.Parameter(wv, dtype=w_dtype)
  95. b_int32 = mge.Parameter(bv, dtype=b_dtype)
  96. inp_fp32 = inp_int8.astype("float32")
  97. w_fp32 = w_int8.astype("float32")
  98. b_fp32 = b_int32.astype("float32")
  99. def convert_to_nchw4(var):
  100. var = F.reshape(
  101. var, (var.shape[0], var.shape[1] // 4, 4, var.shape[2], var.shape[3])
  102. )
  103. var = F.transpose(var, (0, 1, 3, 4, 2))
  104. return var
  105. def run_conv2d(inp, w, b):
  106. O = F.conv2d(
  107. inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
  108. )
  109. if nonlinear_mode == "relu":
  110. return F.relu(O)
  111. else:
  112. return O
  113. def run_conv_bias(inp, w, b, format="NCHW"):
  114. b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
  115. if format == "NCHW4":
  116. inp = convert_to_nchw4(inp)
  117. w = convert_to_nchw4(w)
  118. b = convert_to_nchw4(b)
  119. return F.quantized.conv_bias_activation(
  120. inp,
  121. w,
  122. b,
  123. stride=(SH, SW),
  124. padding=(PH, PW),
  125. dtype=out_dtype,
  126. nonlinear_mode=nonlinear_mode,
  127. )
  128. format = "NCHW4" if mge.is_cuda_available() else "NCHW"
  129. expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
  130. expected = expected.astype(out_dtype).astype("float32")
  131. result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
  132. "float32"
  133. )
  134. if format == "NCHW4":
  135. result = F.transpose(result, (0, 1, 4, 2, 3))
  136. expected = F.flatten(expected)
  137. result = F.flatten(result)
  138. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  139. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
  140. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
  141. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
  142. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
  143. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
  144. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
  145. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
  146. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
  147. @pytest.mark.skip(reason="does not support int4 when cuda version is lower than 10.2")
  148. def test_conv_bias_int4():
  149. inp_scale = 1.5
  150. w_scale = 2.5
  151. outp_scale = 1.5
  152. inp_dtype = dtype.quint4(inp_scale, 0)
  153. w_dtype = dtype.qint4(w_scale)
  154. b_dtype = dtype.qint32(inp_scale * w_scale)
  155. out_dtype = dtype.quint4(outp_scale, 0)
  156. def run(
  157. N,
  158. IC,
  159. OC,
  160. IH,
  161. IW,
  162. KH,
  163. KW,
  164. PH,
  165. PW,
  166. SH,
  167. SW,
  168. has_bias=True,
  169. nonlinear_mode="identity",
  170. ):
  171. inp_v = np.random.normal(size=(N, IC, IH, IW))
  172. w_v = np.random.normal(size=(OC, IC, KH, KW))
  173. b_v = np.random.normal(size=(1, OC, 1, 1))
  174. inp_scale = dtype.get_scale(inp_dtype)
  175. w_scale = dtype.get_scale(w_dtype)
  176. b_scale = dtype.get_scale(b_dtype)
  177. inpv = dtype.convert_to_quint4(inp_v * inp_scale, inp_dtype)
  178. wv = dtype.convert_to_qint4(w_v * w_scale, w_dtype)
  179. bv = dtype.convert_to_qint32(b_v * b_scale, b_dtype)
  180. inp_uint4 = mge.Tensor(inpv, dtype=inp_dtype)
  181. w_int4 = mge.Parameter(wv, dtype=w_dtype)
  182. b_int32 = mge.Parameter(bv, dtype=b_dtype)
  183. inp_fp32 = inp_uint4.astype("float32")
  184. w_fp32 = w_int4.astype("float32")
  185. b_fp32 = b_int32.astype("float32")
  186. def run_conv2d(inp, w, b):
  187. O = F.conv2d(
  188. inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
  189. )
  190. if nonlinear_mode == "relu":
  191. return F.relu(O)
  192. else:
  193. return O
  194. def run_conv_bias(inp, w, b):
  195. b = b if has_bias else mge.Parameter(np.zeros_like(b.numpy()))
  196. return F.quantized.conv_bias_activation(
  197. inp,
  198. w,
  199. b,
  200. stride=(SH, SW),
  201. padding=(PH, PW),
  202. dtype=out_dtype,
  203. nonlinear_mode=nonlinear_mode,
  204. )
  205. expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
  206. expected = expected.astype(out_dtype).astype("float32")
  207. result = run_conv_bias(inp_uint4, w_int4, b_int32).astype("float32")
  208. expected = F.flatten(expected)
  209. result = F.flatten(result)
  210. np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=outp_scale)
  211. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
  212. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
  213. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)
  214. run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
  215. run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
  216. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)
  217. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "relu")
  218. run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
  219. @pytest.mark.skipif(
  220. get_cuda_compute_capability(0) < 61,
  221. reason="does not support int8 when gpu compute capability less than 6.1",
  222. )
  223. def test_conv_transpose2d():
  224. rng = np.random.RandomState(seed=2021)
  225. def test_func(
  226. N,
  227. IC,
  228. IH,
  229. IW,
  230. OC,
  231. KH,
  232. KW,
  233. SH,
  234. SW,
  235. PH,
  236. PW,
  237. DH,
  238. DW,
  239. groups=1,
  240. has_bias=True,
  241. conv_mode: str = "cross_correlation",
  242. compute_mode: str = "default",
  243. ):
  244. inp_scale = np.float32(rng.uniform(low=0.04, high=0.06))
  245. weight_scale = np.float32(rng.uniform(low=0.04, high=0.06))
  246. bias_scale = inp_scale * weight_scale
  247. out_scale = np.float32(rng.uniform(low=0.04, high=0.06))
  248. inp_dtype = dtype.qint8(inp_scale)
  249. weight_dtype = dtype.qint8(weight_scale)
  250. bias_dtype = dtype.qint32(bias_scale)
  251. out_dtype = dtype.qint8(out_scale)
  252. inp_fp32 = rng.uniform(low=-1, high=1, size=(N, IC, IH, IW)).astype(np.float32)
  253. weight_fp32 = rng.uniform(low=-1, high=1, size=(IC, OC, KH, KW)).astype(
  254. np.float32
  255. )
  256. bias_fp32 = rng.uniform(low=-1, high=1, size=(1, OC, 1, 1)).astype(np.float32)
  257. inp_int8 = dtype.convert_to_qint8(inp_fp32, inp_dtype)
  258. weight_int8 = dtype.convert_to_qint8(weight_fp32, weight_dtype)
  259. bias_int32 = dtype.convert_to_qint32(bias_fp32, bias_dtype)
  260. inp_int8 = mge.tensor(inp_int8, dtype=inp_dtype)
  261. weight_int8 = mge.Parameter(weight_int8, dtype=weight_dtype)
  262. bias_int32 = mge.Parameter(bias_int32, dtype=bias_dtype)
  263. inp_fp32 = inp_int8.astype("float32")
  264. weight_fp32 = weight_int8.astype("float32")
  265. bias_fp32 = bias_int32.astype("float32")
  266. expected = F.conv_transpose2d(
  267. inp_fp32,
  268. weight_fp32,
  269. bias_fp32 if has_bias else None,
  270. stride=(SH, SW),
  271. padding=(PH, PW),
  272. dilation=(DH, DW),
  273. groups=groups,
  274. conv_mode=conv_mode,
  275. compute_mode=compute_mode,
  276. )
  277. expected = dtype.convert_to_qint8(expected.numpy(), out_dtype)
  278. expected = dtype.convert_from_qint8(expected)
  279. conv_transpose2d = ConvTranspose2d(
  280. in_channels=IC,
  281. out_channels=OC,
  282. kernel_size=(KH, KW),
  283. stride=(SH, SW),
  284. padding=(PH, PW),
  285. dilation=(DH, DW),
  286. groups=groups,
  287. bias=has_bias,
  288. conv_mode=conv_mode,
  289. compute_mode=compute_mode,
  290. dtype=out_dtype,
  291. )
  292. conv_transpose2d.weight = mge.Parameter(weight_int8)
  293. if has_bias:
  294. conv_transpose2d.bias = mge.Parameter(bias_int32)
  295. result = conv_transpose2d.forward(inp_int8).numpy()
  296. result = dtype.convert_from_qint8(result)
  297. np.testing.assert_allclose(result, expected, atol=out_scale)
  298. test_func(1, 4, 1, 1, 4, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
  299. test_func(2, 4, 3, 1, 8, 1, 1, 1, 1, 0, 0, 1, 1, 1, False)
  300. test_func(4, 4, 16, 16, 8, 3, 3, 1, 1, 1, 1, 1, 1, 1, False)
  301. test_func(32, 64, 36, 28, 16, 3, 2, 1, 3, 1, 0, 1, 1, 1, False)