/**
 * \file dnn/test/arm_common/conv_bias_multi_thread.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/dtype.h"
#include "test/arm_common/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/conv_bias.h"
#include "test/arm_common/cpuinfo_help.h"
using namespace megdnn;
using namespace test;
using namespace conv_bias;
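//! Build ConvBias test cases for the int8/quint8 direct kernels: every
//! combination of batch, nonlinearity, channels, spatial size and kernel size
//! yields a bias-free case plus, unless no_bias is set, a 1xOCx1x1 bias case.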
std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
        std::vector<size_t> kernel, size_t stride, bool no_pad, bool no_bias,
        bool no_nonlinemode) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t oc, size_t ic, size_t w, size_t h,
                    size_t kernel, size_t stride, NLMode nlmode) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (!no_pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        args.emplace_back(param, TensorShape{n, ic, h, w},
                          TensorShape{oc, ic, kernel, kernel}, TensorShape{});
        if (!no_bias) {
            args.emplace_back(param, TensorShape{n, ic, h, w},
                              TensorShape{oc, ic, kernel, kernel},
                              TensorShape{1, oc, 1, 1});
        }
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (size_t ic : {1, 3, 7}) {
                for (size_t oc : {1, 3, 7}) {
                    for (size_t size : {4, 6, 8, 14, 16, 18}) {
                        for (size_t kern : kernel) {
                            pack(n, oc, ic, size, size, kern, stride, nlmode);
                        }
                    }
                }
            }
        }
    }
    return args;
}
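//! Build NCHW44 channel-wise (GROUP sparse) test cases with a
//! {group, 1, 1, k, k, 4} filter; besides the bias-free case, a broadcast
//! bias and a full output-shaped bias can be added depending on the flags.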
std::vector<conv_bias::TestArg> get_nchw44_channel_wise_args(
        std::vector<size_t> kernel, size_t stride, bool no_bias,
        bool no_nonlinemode, bool no_full_bias) {
    using namespace conv_bias;
    using Param = param::ConvBias;
    using NLMode = param::ConvBias::NonlineMode;
    std::vector<TestArg> args;
    auto pack = [&](size_t n, size_t group, size_t w, size_t h, size_t kernel,
                    size_t stride, NLMode nlmode, bool pad) {
        Param param;
        param.stride_h = stride;
        param.stride_w = stride;
        if (pad) {
            param.pad_h = kernel / 2;
            param.pad_w = kernel / 2;
        } else {
            param.pad_h = 0;
            param.pad_w = 0;
        }
        param.nonlineMode = nlmode;
        param.format = param::ConvBias::Format::NCHW44;
        param.sparse = param::ConvBias::Sparse::GROUP;
        args.emplace_back(param, TensorShape{n, group, h, w, 4},
                          TensorShape{group, 1, 1, kernel, kernel, 4},
                          TensorShape{});
        if (!no_bias) {
            args.emplace_back(param, TensorShape{n, group, h, w, 4},
                              TensorShape{group, 1, 1, kernel, kernel, 4},
                              TensorShape{1, group, 1, 1, 4});
        }
        if (!no_full_bias) {
            args.emplace_back(
                    param, TensorShape{n, group, h, w, 4},
                    TensorShape{group, 1, 1, kernel, kernel, 4},
                    TensorShape{n, group,
                                (h + 2 * param.pad_w - kernel) / stride + 1,
                                (w + 2 * param.pad_w - kernel) / stride + 1,
                                4});
        }
    };
    std::vector<NLMode> nonlinemode = {NLMode::IDENTITY};
    if (!no_nonlinemode) {
        nonlinemode.emplace_back(NLMode::RELU);
        nonlinemode.emplace_back(NLMode::H_SWISH);
    }
    for (size_t n : {1, 2}) {
        for (auto nlmode : nonlinemode) {
            for (bool pad : {true}) {
                for (size_t group : {1, 2, 4, 7, 128}) {
                    for (size_t size : {4, 6, 7, 9, 15, 40}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode,
                                 pad);
                        }
                    }
                }
            }
            for (bool pad : {false}) {
                for (size_t group : {1, 2, 7, 128}) {
                    for (size_t size : {7, 9, 15, 40}) {
                        for (size_t kern : kernel) {
                            pack(n, group, size, size, kern, stride, nlmode,
                                 pad);
                        }
                    }
                }
            }
        }
    }
    return args;
}
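//! Check qint8 x qint8 -> qint8 conv_bias against the given algorithm. The
//! tensor order for ConvBias is (src, filter, bias, z, dst), so dtype index 4
//! selects the output type; bias/z/dst shapes are left empty and deduced by
//! the checker.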
void checker_conv_bias_qint8x8x8(std::vector<conv_bias::TestArg> args,
                                 Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
#if MEGDNN_ARMV7
    checker.set_epsilon(1);
#endif
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(0.41113496f))
            .set_dtype(1, dtype::QuantizedS8(0.01887994f))
            .set_dtype(2, dtype::QuantizedS32(0.41113496f * 0.01887994f))
            .set_dtype(4, dtype::QuantizedS8(0.49550694f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
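//! Same quantized int8 inputs, but the destination dtype (index 4) is left
//! unset so the checker deduces it, i.e. int32 accumulation per the x8x32
//! naming.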
void checker_conv_bias_qint8x8x32(std::vector<conv_bias::TestArg> args,
                                  Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, {});
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
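//! Asymmetric uint8 (Quantized8Asymm) inputs and output, with an int32 bias
//! whose scale (0.04) is the product of the two input scales.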
void checker_conv_bias_quint8x8x8(std::vector<conv_bias::TestArg> args,
                                  Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    UniformIntRNG rng(0, 255);
    checker.set_dtype(0, dtype::Quantized8Asymm(0.2f, 100))
            .set_dtype(1, dtype::Quantized8Asymm(0.2f, 120))
            .set_dtype(2, dtype::QuantizedS32(0.04f))
            .set_dtype(4, dtype::Quantized8Asymm(1.4f, 110))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
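//! Asymmetric uint8 inputs with the output dtype left unset (int32
//! accumulation); the bias scale 1.2 * 1.3 matches the input scales.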
void checker_conv_bias_quint8x8x32(std::vector<conv_bias::TestArg> args,
                                   Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    NormalRNG rng(128.f);
    checker.set_rng(0, &rng).set_rng(1, &rng);
    checker.set_dtype(0, dtype::Quantized8Asymm(1.2f, (uint8_t)127))
            .set_dtype(1, dtype::Quantized8Asymm(1.3f, (uint8_t)129))
            .set_dtype(2, dtype::QuantizedS32(1.2 * 1.3))
            .set_dtype(4, {});
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
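//! Plain (non-quantized) Int8 x Int8 -> Int32, shared by the direct and
//! dot-product 8x8x32 tests below.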
void checker_conv_bias_int8x8x32_multi(std::vector<conv_bias::TestArg> args,
                                       Handle* handle, const char* algo_name) {
    Checker<ConvBias> checker(handle);
    checker.set_before_exec_callback(
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int32());
    checker.set_dtype(4, dtype::Int32());
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/**********************************F32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32) {
    check_conv_bias(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), "F32DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
    //! k=7 s=1
    check_conv_bias(get_nchw44_conv_bias_args({7}, ONLY_IDENTITY_NLMODE,
                                              BR_AND_NO_BIASMODE, 1),
                    handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
    check_conv_bias(
            get_nchw44_conv_bias_args({2, 3}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
    check_conv_bias(
            get_nchw44_conv_bias_args({5}, FULL_NLMODE, ONLY_BR_BIASMODE, 1),
            handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
    check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, FULL_NLMODE,
                                              ONLY_BR_BIASMODE, 2),
                    handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR1) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 1, false, false, false),
                    handle(), "F32STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2) {
    check_conv_bias(get_conv_bias_args({2, 3, 5, 7}, 2, false, false, false),
                    handle(), "F32STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S2) {
    check_conv_bias(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_BR_BIASMODE, 2, false, true),
            handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32_S1) {
    check_conv_bias(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_BR_BIASMODE, 1, false, true),
            handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
    check_conv_bias(
            get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
            handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
    check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
                    handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
    check_conv_bias(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, false),
            handle(), "F32_CHANNEL_WISE_NCHW44");
}
/**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16) {
    NormalRNG rng(1);
    checker_conv_bias_f16(
            get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
            handle(), rng, "F16DIRECT", 0.03);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP16_STR1) {
    NormalRNG rng(1);
    checker_conv_bias_f16(get_conv_bias_args({2, 3, 5}, 1, false, false, false),
                          handle(), rng, "F16STRD1", 0.03);
}
#endif
/**********************************algo 8816 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_DIRECT) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 1, false, true, true), handle(),
            "I8816DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_STRIDE2) {
    checker_conv_bias_int8x8x16(
            get_conv_bias_args({2, 3, 5}, 2, false, true, true), handle(),
            "I8816STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S2) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_NO_BIASMODE, 2, false, true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT16_NCHW_NCHW44_S1) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_NO_BIASMODE, 1, false, true),
            handle(), "I8816_CONV_NCHW_NCHW44");
}
/**********************************algo 8-8-32 direct************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE1) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_INT8_INT32_STRIDE2) {
    checker_conv_bias_int8x8x32_multi(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true),
            handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT32_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_int8x8x32_multi(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true),
            handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON, CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT1_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 1, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_INT8_INT8_INT16_CHANNEL_WISE_DIRECT2_NCHW44) {
    Checker<ConvBias> checker(handle());
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            "S8x8x16_CHAN_WISE_STRD1_STRD2_NCHW44"));
    checker.set_dtype(0, dtype::Int8());
    checker.set_dtype(1, dtype::Int8());
    checker.set_dtype(2, dtype::Int16());
    checker.set_dtype(4, dtype::Int16());
    auto args = get_nchw44_channel_wise_args({2, 3, 5}, 2, false, true, true);
    for (auto&& arg : args) {
        checker.set_param(arg.param).execs({arg.src, arg.filter, {}, {}, {}});
    }
}
/********************************qint8 direct******************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "S8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "S8STRD2");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                      ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_BR_BIASMODE, 1),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8816) {
    checker_conv_bias_int8x8x16(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_BR_BIASMODE, 2),
            handle(), "S8x8x16_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_BR_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44_8832) {
    checker_conv_bias_qint8x8x32(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      ONLY_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                      BR_AND_NO_BIASMODE, 2),
            handle(), "S8_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT1_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, ONLY_IDENTITY_NLMODE,
                                      BR_AND_NO_BIASMODE, 1),
            handle(), "S8_NCHW44_DIRECT");
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, true),
            handle(), "S8_CHAN_WISE_STRD1_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QS8_CHANNEL_WISE_DIRECT2_NCHW44) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_channel_wise_args({2, 3, 5}, 2, false, false, true),
            handle(), "S8_CHAN_WISE_STRD2_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                      BR_AND_NO_BIASMODE, 1, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S2) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                      BR_AND_NO_BIASMODE, 2, false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_NCHW_NCHW44_S1_F1) {
    checker_conv_bias_qint8x8x8(
            get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1,
                                      false, true),
            handle(), "S8_CONV_NCHW_NCHW44");
}
/*****************************quint8 direct****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "QU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 2, false, false, false),
                                 handle(), "QU8STRD2");
}
/****************************dot qint8 direct*************************/
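//! The ARMDOT* algorithms below use the Arm dot-product instructions
//! (sdot/udot) and are only built when MGB_ENABLE_DOT is enabled.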
#if MGB_ENABLE_DOT
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
    auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                          BR_AND_NO_BIASMODE, 2, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
    args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                     BR_AND_NO_BIASMODE, 1, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44_S1_F1) {
    auto args = get_nchw44_conv_bias_args({1}, QUAN_NLMODE, BR_AND_NO_BIASMODE,
                                          1, false, true);
    for (auto&& arg : args) {
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    }
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 1, false, false, false),
                                handle(), "ARMDOTS8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_INT8_STRIDE2_WITHDOTPROD) {
    checker_conv_bias_qint8x8x8(get_int8_quint8_conv_bias_args(
                                        {2, 3, 5, 7}, 2, false, false, false),
                                handle(), "ARMDOTS8STRD2");
}
/****************************dot 8-8-32 direct*************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD1_WITHDOT) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTS8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_I8832STRD2_WITHDOT) {
    checker_conv_bias_qint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTS8STRD2");
}
/******************************dot quint8*****************************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE1_WITHDOTPROD) {
    checker_conv_bias_quint8x8x8(get_int8_quint8_conv_bias_args(
                                         {2, 3, 5, 7}, 1, false, false, false),
                                 handle(), "ARMDOTU8STRD1");
}
//! TODO: kernel size 3 is not tested here; adding it currently causes a bus
//! error on armv7.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_WITHDOTPROD) {
    checker_conv_bias_quint8x8x8(
            get_int8_quint8_conv_bias_args({2, 5, 7}, 2, false, false, false),
            handle(), "ARMDOTU8STRD2");
}
/******************************dot quint8x8x32***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE1) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 1, false, true, true), handle(),
            "ARMDOTU8STRD1");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_QUINT8_DIRECT_STRIDE2) {
    checker_conv_bias_quint8x8x32(
            get_conv_bias_args({2, 3, 5, 7}, 2, false, true, true), handle(),
            "ARMDOTU8STRD2");
}
/******************************dot int8x8x8 nchw44 ***********************/
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x8) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, QUAN_NLMODE, ONLY_BR_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_Q8x8x32) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S1_8x8x32) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 1);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x8) {
    using namespace conv_bias;
    //! test qint8x8x8
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_Q8x8x32) {
    using namespace conv_bias;
    //! test qint8x8x32
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_qint8x8x32(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_INT8_DIRECT_DOT_NCHW44_S2_8x8x32) {
    using namespace conv_bias;
    //! test int8x8x32
    std::vector<TestArg> args = get_nchw44_conv_bias_args(
            {2, 3, 5, 7}, ONLY_IDENTITY_NLMODE, BR_AND_NO_BIASMODE, 2);
    for (auto&& arg : args)
        arg.param.format = param::ConvBias::Format::NCHW44_DOT;
    checker_conv_bias_int8x8x32_multi(args, handle(), "ARMDOTS8DIRECT_NCHW44");
}
#endif
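//! Note: the "p:m:t" strings passed to check_winograd below appear to encode
//! pack size : output block size : tile size (e.g. "4:2:32" is F(2,3) with
//! MK4 packing); this reading follows the surrounding test names such as
//! F23_4 and F63 and is not taken from the parser itself.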
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          const float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    run(args, dtype::Float32(), dtype::Float32(), dtype::Float32(),
        dtype::Float32(), 1e-3f);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(args, dtype::Float16(), dtype::Float16(), dtype::Float16(),
        dtype::Float16(), 0.35f);
#endif
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F23_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:2:32", checker, args, param::MatrixMul::Format::MK4,
                   param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:6:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F63_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:6:16", checker, args, param::MatrixMul::Format::MK4,
                   param::ConvBias::Format::NCHW44);
}
//! enable these tests once low-precision mode works correctly
#if 0
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
                   param::ConvBias::Format::NCHW44);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F73_4_NCHW44_WEIGHT_PREPROCESS) {
    using namespace conv_bias;
    std::vector<TestArg> args =
            get_nchw44_conv_bias_args({3}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
            handle());
    check_winograd("4:7:16", checker, args, param::MatrixMul::Format::MK4,
                   param::ConvBias::Format::NCHW44);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F54) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(4);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:5:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F45) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    Checker<ConvBiasForward> checker(handle());
    check_winograd("1:4:32", checker, args);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_1) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_first_half(args.begin(),
                                         args.begin() + args.size() / 2);
    run(args_first_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, 1e-3f);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F32_2) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_second_half(args.begin() + args.size() / 2,
                                          args.end());
    run(args_second_half, dtype::Float32{}, dtype::Float32{}, dtype::Float32{},
        dtype::Float32{}, 1e-3f);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_F16) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng);
    run(args, dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
        dtype::Float16{}, 0.25);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_MK_PACKED_INT8) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD:%s:8:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args =
            get_quantized_winograd_mk_packed_args(8);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
        dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
        dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32_GROUPMODE) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_F32_MK4_4x16";
#else
    const char* matmul_name = "ARMV7_F32_MK4_4x8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args =
            get_int8_nchw44_args(3, 4, true, true);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(0.41113496f),
        dtype::QuantizedS8(0.01887994f),
        dtype::QuantizedS32(0.41113496f * 0.01887994f),
        dtype::QuantizedS8(0.49550694f), epsilon);
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_GROUPMODE) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_INT16X16X32_MK8_8X8";
#else
    const char* matmul_name = "ARMV7_INT16X16X32_MK8_4X8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:8:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args =
            get_int8_nchw44_args(3, 4, false, true);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(2.5f), dtype::QuantizedS8(2.5f),
        dtype::QuantizedS32(6.25f), dtype::QuantizedS8(60.25f), 1e-3);
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_WINOGRAD_NCHW44_MK_PACKED_INT8_COMP_F32) {
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle());
    auto run = [&checker](const std::vector<TestArg>& args, DType A_dtype,
                          DType B_dtype, DType C_dtype, DType D_dtype,
                          float eps) {
        for (auto&& arg : args) {
            checker.set_dtype(0, A_dtype)
                    .set_dtype(1, B_dtype)
                    .set_dtype(2, C_dtype)
                    .set_dtype(4, D_dtype)
                    .set_epsilon(eps)
                    .set_param(arg.param)
                    .execs({arg.src, arg.filter, arg.bias, {}, {}});
        }
    };
    float epsilon = 0.001;
#if MEGDNN_AARCH64
    const char* matmul_name = "AARCH64_F32_MK4_4x16";
#else
    const char* matmul_name = "ARMV7_F32_MK4_4x8";
#endif
    checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
            ssprintf("WINOGRAD_NCHW44:%s:4:2:32", matmul_name).c_str()));
    std::vector<TestArg> quantized_args = get_int8_nchw44_args(3, 4, true);
    UniformIntRNG int_rng{-50, 50};
    checker.set_rng(0, &int_rng).set_rng(1, &int_rng).set_rng(2, &int_rng);
    run(quantized_args, dtype::QuantizedS8(0.41113496f),
        dtype::QuantizedS8(0.01887994f),
        dtype::QuantizedS32(0.41113496f * 0.01887994f),
        dtype::QuantizedS8(0.49550694f), epsilon);
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F23) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args();
    Checker<ConvBiasForward> checker(handle());
    check_winograd_fp16("1:2:32", checker, args, NULL, 0.08);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_1) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    std::vector<TestArg> args_head_half(args.begin(),
                                        args.begin() + args.size() / 2);
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:4:32", checker, args_head_half, rng, 0.25);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F45_2) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(5);
    std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
                                        args.end());
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:4:32", checker, args_back_half, rng, 0.25);
}
//! FIXME: This test may fail when the whole `ARM_COMMON.CONV_BIAS_WINOGRAD*`
//! group is run, but passes when run as a single testcase.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_F63) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_args(3);
    Checker<ConvBiasForward> checker(handle());
    //! fp16 range -1.0 ~ 1.0
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("1:6:32", checker, args, rng, 0.3);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_1) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_head_half(args.begin(),
                                        args.begin() + args.size() / 2);
    Checker<ConvBiasForward> checker(handle());
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("8:2:32", checker, args_head_half, rng, 0.25,
                        param::MatrixMul::Format::MK8);
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_F16_8x8_2) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_winograd_mk_packed_args(8);
    std::vector<TestArg> args_back_half(args.begin() + args.size() / 2,
                                        args.end());
    Checker<ConvBiasForward> checker(handle());
    Float16PeriodicalRNG* rng = new Float16PeriodicalRNG(0x3c00);
    check_winograd_fp16("8:2:32", checker, args_back_half, rng, 0.25,
                        param::MatrixMul::Format::MK8);
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_WINOGRAD_INT8_8X8) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
    Checker<ConvBiasForward> checker(handle());
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
}
TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_WINOGRAD_INT8_8X8_WEIGHT_PREPROCESS) {
    using namespace conv_bias;
    std::vector<TestArg> args = get_quantized_winograd_mk_packed_args(8);
    Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
            handle());
    UniformIntRNG rng{-50, 50};
    checker.set_dtype(0, dtype::QuantizedS8(2.5f))
            .set_dtype(1, dtype::QuantizedS8(2.5f))
            .set_dtype(2, dtype::QuantizedS32(6.25f))
            .set_dtype(4, dtype::QuantizedS8(60.25f))
            .set_rng(0, &rng)
            .set_rng(1, &rng)
            .set_rng(2, &rng);
    check_winograd("8:2:32", checker, args, param::MatrixMul::Format::MK8);
}
// clang-format on
// vim: syntax=cpp.doxygen
